Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions domino/airflow/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, List, Optional

from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

from domino import Domino

Expand All @@ -29,7 +28,6 @@ class DominoOperator(BaseOperator):
template_fields = ("command", "title")
ui_color = "#5188c7"

@apply_defaults
def __init__(
self,
project: str,
Expand Down Expand Up @@ -180,7 +178,6 @@ class DominoSparkOperator(BaseOperator):
template_fields = ("command",)
ui_color = "#6C50AD"

@apply_defaults
def __init__(
self,
project: str,
Expand Down
10 changes: 3 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,18 @@ def get_version():
keywords=["Domino Data Lab", "API"],
python_requires=">=3.10.0",
install_requires=[
"packaging==23.2",
"packaging>=23.2,<27",
"requests>=2.4.2",
"beautifulsoup4~=4.11",
"polling2~=0.5.0",
"urllib3>=1.26.19,<3",
"typing-extensions~=4.13.0",
"frozendict~=2.3",
"frozendict~=2.4.6",
"python-dateutil~=2.8.2",
"retry==0.9.2",
],
extras_require={
"airflow": ["apache-airflow==2.2.4"],
"airflow": ["apache-airflow~=2.11"],
"data": ["dominodatalab-data>=0.1.0"],
"agents": [
"semver>=3.0.4",
Expand All @@ -75,10 +75,6 @@ def get_version():
"ai-mock>=0.3.1", # used in agent tracing tests
"black==22.3.0",
"flake8==4.0.1",
"Jinja2==2.11.3",
"nbconvert==6.3.0",
"packaging==23.2",
"polling2==0.5.0",
"pre-commit==2.19.0",
"pyspark==3.3.0",
"pytest==7.4.3",
Expand Down
61 changes: 37 additions & 24 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
`airflow db init`
"""
import os
from datetime import datetime
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest
from airflow.operators.dummy import DummyOperator
from airflow.operators.empty import EmptyOperator

from domino.airflow import DominoOperator
from domino.exceptions import RunFailedException
from airflow import DAG
from airflow.models import TaskInstance


TEST_PROJECT = os.environ.get("DOMINO_TEST_PROJECT")
Expand All @@ -24,25 +24,28 @@ def test_airflow_dags():
pytest.importorskip("airflow")

from airflow import DAG
from airflow.models import TaskInstance

start_time = datetime.now()

dag = DAG(dag_id, start_date=start_time)
task = DummyOperator(
dag = DAG(dag_id, start_date=start_time, schedule=timedelta(days=1))
task = EmptyOperator(
dag=dag,
task_id='test_airflow_dags',
)

task.run()
ti = TaskInstance(task=task, execution_date=start_time)
task.execute(ti.get_template_context())
task.execute(context={})


def test_operator():
@patch("domino.airflow._operator.Domino")
def test_operator(mock_domino):
mock_client = MagicMock()
mock_client.runs_start_blocking.return_value = {"runId": "abc123", "status": "Succeeded"}
mock_client.get_run_log.return_value = []
mock_domino.return_value = mock_client

start_time = datetime.now()

dag = DAG(dag_id, start_date=start_time)
dag = DAG(dag_id, start_date=start_time, schedule=timedelta(days=1))
task = DominoOperator(
dag=dag,
task_id="test_operator",
Expand All @@ -51,15 +54,21 @@ def test_operator():
command=["python -V"],
)

task.run()
ti = TaskInstance(task=task, execution_date=start_time)
task.execute(ti.get_template_context())
task.execute(context={})

mock_domino.assert_called_once()
mock_client.runs_start_blocking.assert_called_once()


def test_operator_fail(caplog):
@patch("domino.airflow._operator.Domino")
def test_operator_fail(mock_domino):
mock_client = MagicMock()
mock_client.runs_start_blocking.side_effect = RunFailedException("Run failed")
mock_domino.return_value = mock_client

execution_dt = datetime.now()

dag = DAG(dag_id, start_date=execution_dt)
dag = DAG(dag_id, start_date=execution_dt, schedule=timedelta(days=1))
task = DominoOperator(
dag=dag,
task_id="test_operator_fail",
Expand All @@ -69,15 +78,21 @@ def test_operator_fail(caplog):
)

with pytest.raises(RunFailedException):
task.run()
ti = TaskInstance(task=task, execution_date=execution_dt)
task.execute(ti.get_template_context())
task.execute(context={})


@patch("domino.airflow._operator.Domino")
def test_operator_fail_invalid_tier(mock_domino):
mock_client = MagicMock()
mock_client.hardware_tiers_list.return_value = [
{"hardwareTier": {"name": "small"}},
{"hardwareTier": {"name": "medium"}},
]
mock_domino.return_value = mock_client

def test_operator_fail_invalid_tier(caplog):
execution_dt = datetime.now()

dag = DAG(dag_id, start_date=execution_dt)
dag = DAG(dag_id, start_date=execution_dt, schedule=timedelta(days=1))
task = DominoOperator(
dag=dag,
task_id="test_operator_fail_invalid_tier",
Expand All @@ -88,6 +103,4 @@ def test_operator_fail_invalid_tier(caplog):
)

with pytest.raises(ValueError):
task.run()
ti = TaskInstance(task=task, execution_date=execution_dt)
task.execute(ti.get_template_context())
task.execute(context={})