3030from dataclasses import dataclass
3131from typing import TYPE_CHECKING , Any
3232
33- from sqlalchemy import and_ , column , func , inspect , select , table , text
33+ from sqlalchemy import Select , and_ , column , func , inspect , select , table , text
3434from sqlalchemy .exc import OperationalError , ProgrammingError
3535from sqlalchemy .ext .compiler import compiles
3636from sqlalchemy .orm import aliased
4747
4848if TYPE_CHECKING :
4949 from pendulum import DateTime
50- from sqlalchemy .orm import Query , Session
50+ from sqlalchemy .orm import Session
5151
5252 from airflow .models import Base
5353
@@ -163,17 +163,17 @@ def readable_config(self):
163163config_dict : dict [str , _TableConfig ] = {x .orm_model .name : x for x in sorted (config_list )}
164164
165165
166- def _check_for_rows (* , query : Query , print_rows : bool = False ) -> int :
167- num_entities = query . count ()
166+ def _check_for_rows (* , query : Select , print_rows : bool = False , session : Session ) -> int :
167+ num_entities = session . execute ( select ( func . count ()). select_from ( query . subquery ())). scalar ()
168168 print (f"Found { num_entities } rows meeting deletion criteria." )
169169 if not print_rows or num_entities == 0 :
170170 return num_entities
171171
172172 max_rows_to_print = 100
173173 print (f"Printing first { max_rows_to_print } rows." )
174174 logger .debug ("print entities query: %s" , query )
175- for entry in query .limit (max_rows_to_print ):
176- print (entry .__dict__ )
175+ for entry in session . execute ( query .limit (max_rows_to_print ) ):
176+ print (entry ._asdict () )
177177 return num_entities
178178
179179
@@ -193,7 +193,7 @@ def _dump_table_to_file(*, target_table: str, file_path: str, export_format: str
193193
194194
195195def _do_delete (
196- * , query : Query , orm_model : Base , skip_archive : bool , session : Session , batch_size : int | None
196+ * , query : Select , orm_model : Base , skip_archive : bool , session : Session , batch_size : int | None
197197) -> None :
198198 import itertools
199199 import re
@@ -204,7 +204,7 @@ def _do_delete(
204204
205205 while True :
206206 limited_query = query .limit (batch_size ) if batch_size else query
207- if limited_query . count () == 0 : # nothing left to delete
207+ if session . execute ( select ( func . count ()). select_from ( limited_query . subquery ())). scalar () == 0 :
208208 break
209209
210210 batch_no = next (batch_counter )
@@ -233,7 +233,7 @@ def _do_delete(
233233 logger .debug ("insert statement:\n %s" , insert_stm .compile ())
234234 session .execute (insert_stm )
235235 else :
236- stmt = CreateTableAs (target_table_name , limited_query . selectable )
236+ stmt = CreateTableAs (target_table_name , limited_query )
237237 logger .debug ("ctas query:\n %s" , stmt .compile ())
238238 session .execute (stmt )
239239 session .commit ()
@@ -309,10 +309,10 @@ def _build_query(
309309 clean_before_timestamp : DateTime ,
310310 session : Session ,
311311 ** kwargs ,
312- ) -> Query :
312+ ) -> Select :
313313 base_table_alias = "base"
314314 base_table = aliased (orm_model , name = base_table_alias )
315- query = session . query ( base_table ). with_entities ( text (f"{ base_table_alias } .*" ))
315+ query = select ( text (f"{ base_table_alias } .*" )). select_from ( base_table )
316316 base_table_recency_col = base_table .c [recency_column .name ]
317317 conditions = [base_table_recency_col < clean_before_timestamp ]
318318 if keep_last :
@@ -325,7 +325,7 @@ def _build_query(
325325 max_date_colname = max_date_col_name ,
326326 session = session ,
327327 )
328- query = query .select_from ( base_table ). outerjoin (
328+ query = query .outerjoin (
329329 subquery ,
330330 and_ (
331331 * [base_table .c [x ] == subquery .c [x ] for x in keep_last_group_by ], # type: ignore[attr-defined]
@@ -364,9 +364,9 @@ def _cleanup_table(
364364 clean_before_timestamp = clean_before_timestamp ,
365365 session = session ,
366366 )
367- logger .debug ("old rows query:\n %s" , query .selectable . compile ())
367+ logger .debug ("old rows query:\n %s" , query .compile ())
368368 print (f"Checking table { orm_model .name } " )
369- num_rows = _check_for_rows (query = query , print_rows = False )
369+ num_rows = _check_for_rows (query = query , print_rows = False , session = session )
370370
371371 if num_rows and not dry_run :
372372 _do_delete (
0 commit comments