Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 74 additions & 51 deletions app/datatables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from flask import request
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import text, func, or_, cast, String, select
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.sql import Select

# `db` will be provided by the application using this module.
Expand Down Expand Up @@ -37,13 +38,27 @@ def datatables_response(query):
query: sqlalchemy.sql.Select | str
Query to execute. May be a SQLAlchemy Select object or a raw SQL string.
"""
values = request.values

def empty_response(columns):
return {
"draw": int(values.get("draw", 1)),
"recordsTotal": 0,
"recordsFiltered": 0,
"data": [],
"columns": list(columns),
}

if isinstance(query, str):
base_sql = query.rstrip(";")
values = request.values
start = values.get("start", 0, type=int)
length = values.get("length", 20, type=int)

columns = query_columns(query)
try:
columns = query_columns(query)
except (OperationalError, ProgrammingError):
return empty_response([])

quoted_cols = [f'"{c}"' for c in columns]

search_value = values.get("search[value]")
Expand Down Expand Up @@ -79,31 +94,35 @@ def datatables_response(query):
filters.append(f"{col_expr} {like_op} :col_{idx}")

base_select = f"SELECT * FROM ({base_sql}) AS q"
total_records = db.session.scalar(
text(f"SELECT COUNT(*) FROM ({base_sql}) AS q")
)

filtered_sql = base_select
if filters:
filtered_sql += " WHERE " + " AND ".join(filters)

records_filtered = db.session.scalar(
text(f"SELECT COUNT(*) FROM ({filtered_sql}) AS sq"),
params,
)

order_idx = values.get("order[0][column]", type=int)
if order_idx is not None and 0 <= order_idx < len(columns):
col = columns[order_idx]
direction = (
"DESC" if values.get("order[0][dir]", "asc") == "desc" else "ASC"
try:
total_records = db.session.scalar(
text(f"SELECT COUNT(*) FROM ({base_sql}) AS q")
)

filtered_sql = base_select
if filters:
filtered_sql += " WHERE " + " AND ".join(filters)

records_filtered = db.session.scalar(
text(f"SELECT COUNT(*) FROM ({filtered_sql}) AS sq"),
params,
)
filtered_sql += f' ORDER BY "{col}" {direction}'

paginated_sql = filtered_sql + " LIMIT :limit OFFSET :offset"
params.update({"limit": length, "offset": start})
order_idx = values.get("order[0][column]", type=int)
if order_idx is not None and 0 <= order_idx < len(columns):
col = columns[order_idx]
direction = (
"DESC" if values.get("order[0][dir]", "asc") == "desc" else "ASC"
)
filtered_sql += f' ORDER BY "{col}" {direction}'

paginated_sql = filtered_sql + " LIMIT :limit OFFSET :offset"
params.update({"limit": length, "offset": start})

rows = db.session.execute(text(paginated_sql), params).mappings().all()
except (OperationalError, ProgrammingError):
return empty_response(columns)

rows = db.session.execute(text(paginated_sql), params).mappings().all()
data = [dict(r) for r in rows]

return {
Expand All @@ -120,34 +139,38 @@ def datatables_response(query):
# Ensure ORM selections return column mappings instead of model objects
columns = [c.key for c in query.selected_columns]
base_query = query.with_only_columns(query.selected_columns)
Comment on lines 140 to 141
Copy link

Copilot AI Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lines that access query.selected_columns (lines 140-141) are executed before the try/except block begins at line 142. If the query references a non-existent table or has other structural issues, accessing selected_columns or calling with_only_columns() could raise an OperationalError or ProgrammingError before error handling is active. Move lines 140-141 inside the try block to ensure all operations that could fail due to database issues are protected.

Copilot uses AI. Check for mistakes.
total_records = _execute_count(base_query)
try:
total_records = _execute_count(base_query)

search_value = values.get("search[value]")
if search_value:
filters = [
cast(c, String).ilike(f"%{search_value}%")
for c in query.selected_columns
]
base_query = base_query.where(or_(*filters))

for idx, col in enumerate(query.selected_columns):
val = values.get(f"columns[{idx}][search][value]")
if val:
base_query = base_query.where(cast(col, String).ilike(f"%{val}%"))

order_idx = values.get("order[0][column]")
if order_idx is not None:
col = query.selected_columns[int(order_idx)]
if values.get("order[0][dir]", "asc") == "desc":
col = col.desc()
base_query = base_query.order_by(col)

records_filtered = _execute_count(base_query)

start = values.get("start", 0, type=int)
length = values.get("length", 20, type=int)
paginated = base_query.offset(start).limit(length)
rows = db.session.execute(paginated).all()
except (OperationalError, ProgrammingError):
return empty_response(columns)

values = request.values
search_value = values.get("search[value]")
if search_value:
filters = [
cast(c, String).ilike(f"%{search_value}%") for c in query.selected_columns
]
base_query = base_query.where(or_(*filters))

for idx, col in enumerate(query.selected_columns):
val = values.get(f"columns[{idx}][search][value]")
if val:
base_query = base_query.where(cast(col, String).ilike(f"%{val}%"))

order_idx = values.get("order[0][column]")
if order_idx is not None:
col = query.selected_columns[int(order_idx)]
if values.get("order[0][dir]", "asc") == "desc":
col = col.desc()
base_query = base_query.order_by(col)

records_filtered = _execute_count(base_query)

start = values.get("start", 0, type=int)
length = values.get("length", 20, type=int)
paginated = base_query.offset(start).limit(length)
rows = db.session.execute(paginated).all()
data = [dict(r._mapping) for r in rows]

return {
Expand Down
Loading