Skip to content

Commit da3781a

Browse files
committed
Refactor airflow-core/src to use SQLA2
1 parent b97197a commit da3781a

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,9 @@ repos:
424424
(?x)
425425
^airflow-ctl.*\.py$|
426426
^airflow-core/src/airflow/models/.*\.py$|
427+
^airflow-core/src/airflow/dag_processing/.*\.py$|
428+
^airflow-core/src/airflow/migrations/.*\.py$|
429+
^airflow-core/src/airflow/utils/.*\.py$|
427430
^airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py$|
428431
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
429432
^airflow-core/tests/unit/models/test_serialized_dag.py$|

airflow-core/src/airflow/dag_processing/bundles/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def _extract_and_sign_template(bundle_name: str) -> tuple[str | None, dict]:
230230
self.log.debug("Signed URL template for bundle %s", bundle_name)
231231
return new_template_, new_params_
232232

233-
stored = {b.name: b for b in session.query(DagBundleModel).all()}
233+
stored = {b.name: b for b in session.scalars(select(DagBundleModel)).all()}
234234
bundle_to_team = {
235235
bundle.name: bundle.teams[0].name if len(bundle.teams) == 1 else None
236236
for bundle in stored.values()

airflow-core/src/airflow/migrations/versions/0015_2_9_0_update_trigger_kwargs_type.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import sqlalchemy as sa
3434
from alembic import context, op
35+
from sqlalchemy import select
3536
from sqlalchemy.orm import lazyload
3637

3738
from airflow.models.trigger import Trigger
@@ -60,7 +61,7 @@ def upgrade():
6061
if not context.is_offline_mode():
6162
session = get_session()
6263
try:
63-
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
64+
for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
6465
trigger.kwargs = trigger.kwargs
6566
session.commit()
6667
finally:
@@ -81,7 +82,7 @@ def downgrade():
8182
else:
8283
session = get_session()
8384
try:
84-
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
85+
for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
8586
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
8687
session.commit()
8788
finally:

airflow-core/src/airflow/utils/db_cleanup.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from dataclasses import dataclass
3131
from typing import TYPE_CHECKING, Any
3232

33-
from sqlalchemy import and_, column, func, inspect, select, table, text
33+
from sqlalchemy import Select, and_, column, func, inspect, select, table, text
3434
from sqlalchemy.exc import OperationalError, ProgrammingError
3535
from sqlalchemy.ext.compiler import compiles
3636
from sqlalchemy.orm import aliased
@@ -47,7 +47,7 @@
4747

4848
if TYPE_CHECKING:
4949
from pendulum import DateTime
50-
from sqlalchemy.orm import Query, Session
50+
from sqlalchemy.orm import Session
5151

5252
from airflow.models import Base
5353

@@ -163,17 +163,17 @@ def readable_config(self):
163163
config_dict: dict[str, _TableConfig] = {x.orm_model.name: x for x in sorted(config_list)}
164164

165165

166-
def _check_for_rows(*, query: Query, print_rows: bool = False) -> int:
167-
num_entities = query.count()
166+
def _check_for_rows(*, query: Select, print_rows: bool = False, session: Session) -> int:
167+
num_entities = session.execute(select(func.count()).select_from(query.subquery())).scalar()
168168
print(f"Found {num_entities} rows meeting deletion criteria.")
169169
if not print_rows or num_entities == 0:
170170
return num_entities
171171

172172
max_rows_to_print = 100
173173
print(f"Printing first {max_rows_to_print} rows.")
174174
logger.debug("print entities query: %s", query)
175-
for entry in query.limit(max_rows_to_print):
176-
print(entry.__dict__)
175+
for entry in session.execute(query.limit(max_rows_to_print)):
176+
print(entry._asdict())
177177
return num_entities
178178

179179

@@ -193,7 +193,7 @@ def _dump_table_to_file(*, target_table: str, file_path: str, export_format: str
193193

194194

195195
def _do_delete(
196-
*, query: Query, orm_model: Base, skip_archive: bool, session: Session, batch_size: int | None
196+
*, query: Select, orm_model: Base, skip_archive: bool, session: Session, batch_size: int | None
197197
) -> None:
198198
import itertools
199199
import re
@@ -204,7 +204,7 @@ def _do_delete(
204204

205205
while True:
206206
limited_query = query.limit(batch_size) if batch_size else query
207-
if limited_query.count() == 0: # nothing left to delete
207+
if session.execute(select(func.count()).select_from(limited_query.subquery())).scalar() == 0:
208208
break
209209

210210
batch_no = next(batch_counter)
@@ -233,7 +233,7 @@ def _do_delete(
233233
logger.debug("insert statement:\n%s", insert_stm.compile())
234234
session.execute(insert_stm)
235235
else:
236-
stmt = CreateTableAs(target_table_name, limited_query.selectable)
236+
stmt = CreateTableAs(target_table_name, limited_query)
237237
logger.debug("ctas query:\n%s", stmt.compile())
238238
session.execute(stmt)
239239
session.commit()
@@ -309,10 +309,10 @@ def _build_query(
309309
clean_before_timestamp: DateTime,
310310
session: Session,
311311
**kwargs,
312-
) -> Query:
312+
) -> Select:
313313
base_table_alias = "base"
314314
base_table = aliased(orm_model, name=base_table_alias)
315-
query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
315+
query = select(text(f"{base_table_alias}.*")).select_from(base_table)
316316
base_table_recency_col = base_table.c[recency_column.name]
317317
conditions = [base_table_recency_col < clean_before_timestamp]
318318
if keep_last:
@@ -325,7 +325,7 @@ def _build_query(
325325
max_date_colname=max_date_col_name,
326326
session=session,
327327
)
328-
query = query.select_from(base_table).outerjoin(
328+
query = query.outerjoin(
329329
subquery,
330330
and_(
331331
*[base_table.c[x] == subquery.c[x] for x in keep_last_group_by], # type: ignore[attr-defined]
@@ -364,9 +364,9 @@ def _cleanup_table(
364364
clean_before_timestamp=clean_before_timestamp,
365365
session=session,
366366
)
367-
logger.debug("old rows query:\n%s", query.selectable.compile())
367+
logger.debug("old rows query:\n%s", query.compile())
368368
print(f"Checking table {orm_model.name}")
369-
num_rows = _check_for_rows(query=query, print_rows=False)
369+
num_rows = _check_for_rows(query=query, print_rows=False, session=session)
370370

371371
if num_rows and not dry_run:
372372
_do_delete(

0 commit comments

Comments
 (0)