Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions dingolytics/api/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hmac
import json
import logging
import re
from datetime import datetime, date
from typing import Dict, Any

from flask import request, url_for
from flask_restful import abort
from flask import abort, request, url_for

from redash import models
from redash.settings import get_settings
Expand Down Expand Up @@ -32,6 +34,60 @@ def _serialize(o: models.Query) -> dict[str, object]:
}


def _escape_string(value: str) -> str:
"""Escape special characters in [ClickHouse] SQL string literals."""
escaped = value.replace("\\", "\\\\")
escaped = escaped.replace("'", "\\'")
escaped = escaped.replace("\n", "\\n")
escaped = escaped.replace("\t", "\\t")
escaped = escaped.replace("\b", "\\b")
escaped = escaped.replace("\f", "\\f")
escaped = escaped.replace("\r", "\\r")
escaped = escaped.replace("\0", "\\0")
return "'%s'" % escaped


def _format_value(value: Any) -> str:
"""Format different types of values for [ClickHouse] SQL."""
if value is None:
return "NULL"
elif isinstance(value, (int, float)):
return str(value)
elif isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, (datetime, date)):
return f"'{value.isoformat()}'"
elif isinstance(value, (list, tuple)):
return f"[{','.join(_format_value(v) for v in value)}]"
else:
return _escape_string(str(value))


def _validate_identifier(identifier: str) -> bool:
"""Validate that the identifier is safe to use in ClickHouse SQL."""
return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier))


def _parameters_from_request(args: Dict[str, Any]) -> Dict[str, str]:
"""
Extract and escape parameters from Flask request arguments.
Parameters must start with 'p_' prefix. Raises `ValueError`
if parameter name contains invalid characters.
"""
parameters = {}

for key, value in args.items():
if key.startswith("p_"):
param_name = key[2:] # Remove 'p_' prefix

if not _validate_identifier(param_name):
raise ValueError(f"Invalid parameter name: {param_name}")

parameters[param_name] = _format_value(value)

return parameters


class EndpointDetailsResource(BaseResource):
@require_permission("list_data_sources")
def get(self, endpoint_id):
Expand Down Expand Up @@ -62,7 +118,7 @@ def get(self) -> dict[str, object]:


class EndpointPublicResultsResource(BaseResource):
decorators = BaseResource.decorators + [csp_allows_embeding]
decorators = [csp_allows_embeding]

def get(self, endpoint_id: int, token: str) -> dict[str, object]:
endpoint: models.Query = get_object_or_404(
Expand All @@ -75,12 +131,17 @@ def get(self, endpoint_id: int, token: str) -> dict[str, object]:
if not hmac.compare_digest(token, endpoint.api_key):
abort(403)

try:
args = _parameters_from_request(request.args)
except ValueError as exc:
abort(400, "Failed to parse parameters: %s" % str(exc))

parameterized = endpoint.parameterized
parameterized.apply(collect_parameters_from_request(request.args))
parameterized.apply(args)
sql = parameterized.text

query_runner: BaseSQLQueryRunner = endpoint.data_source.query_runner
query_sql = parameterized.text
result_str, error = query_runner.run_query(query_sql, user=None)
result_str, error = query_runner.run_query(sql, user=None)

# Handle query error
if error:
Expand All @@ -103,6 +164,7 @@ def get(self, endpoint_id: int, token: str) -> dict[str, object]:
if len(rows) > 1:
logger.warning(
"One row expected from query for endpoint_id=%s, "
"but received %d rows", endpoint_id, len(rows))
"but received %d rows", endpoint_id, len(rows)
)

return rows[0]
2 changes: 1 addition & 1 deletion redash/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def writerows(self, rows):
self.writerow(row)


def collect_parameters_from_request(args):
def collect_parameters_from_request(args: dict[str, str]) -> dict[str, str]:
parameters = {}
for k, v in args.items():
if k.startswith("p_"):
Expand Down
Loading