From 879fa46f802bbc330dd6a2fc4618d1a92b423c86 Mon Sep 17 00:00:00 2001 From: Alexey Kinev Date: Fri, 27 Dec 2024 13:34:13 +0100 Subject: [PATCH] Fix: proper escape parameters for Endpoints --- dingolytics/api/endpoints.py | 76 ++++++++++++++++++++++++++++++++---- redash/utils/__init__.py | 2 +- 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/dingolytics/api/endpoints.py b/dingolytics/api/endpoints.py index 3a715ba1..6efa2b40 100644 --- a/dingolytics/api/endpoints.py +++ b/dingolytics/api/endpoints.py @@ -1,9 +1,11 @@ import hmac import json import logging +import re +from datetime import datetime, date +from typing import Dict, Any -from flask import request, url_for -from flask_restful import abort +from flask import abort, request, url_for from redash import models from redash.settings import get_settings @@ -32,6 +34,60 @@ def _serialize(o: models.Query) -> dict[str, object]: } +def _escape_string(value: str) -> str: + """Escape special characters in [ClickHouse] SQL string literals.""" + escaped = value.replace("\\", "\\\\") + escaped = escaped.replace("'", "\\'") + escaped = escaped.replace("\n", "\\n") + escaped = escaped.replace("\t", "\\t") + escaped = escaped.replace("\b", "\\b") + escaped = escaped.replace("\f", "\\f") + escaped = escaped.replace("\r", "\\r") + escaped = escaped.replace("\0", "\\0") + return "'%s'" % escaped + + +def _format_value(value: Any) -> str: + """Format different types of values for [ClickHouse] SQL.""" + if value is None: + return "NULL" + elif isinstance(value, (int, float)): + return str(value) + elif isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, (datetime, date)): + return f"'{value.isoformat()}'" + elif isinstance(value, (list, tuple)): + return f"[{','.join(_format_value(v) for v in value)}]" + else: + return _escape_string(str(value)) + + +def _validate_identifier(identifier: str) -> bool: + """Validate that the identifier is safe to use in ClickHouse SQL.""" + return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier)) + + +def _parameters_from_request(args: Dict[str, Any]) -> Dict[str, str]: + """ + Extract and escape parameters from Flask request arguments. + Parameters must start with 'p_' prefix. Raises `ValueError` + if parameter name contains invalid characters. + """ + parameters = {} + + for key, value in args.items(): + if key.startswith("p_"): + param_name = key[2:] # Remove 'p_' prefix + + if not _validate_identifier(param_name): + raise ValueError(f"Invalid parameter name: {param_name}") + + parameters[param_name] = _format_value(value) + + return parameters + + class EndpointDetailsResource(BaseResource): @require_permission("list_data_sources") def get(self, endpoint_id): @@ -62,7 +118,7 @@ def get(self) -> dict[str, object]: class EndpointPublicResultsResource(BaseResource): - decorators = BaseResource.decorators + [csp_allows_embeding] + decorators = [csp_allows_embeding] def get(self, endpoint_id: int, token: str) -> dict[str, object]: endpoint: models.Query = get_object_or_404( @@ -75,12 +131,17 @@ def get(self, endpoint_id: int, token: str) -> dict[str, object]: if not hmac.compare_digest(token, endpoint.api_key): abort(403) + try: + args = _parameters_from_request(request.args) + except ValueError as exc: + abort(400, "Failed to parse parameters: %s" % str(exc)) + parameterized = endpoint.parameterized - parameterized.apply(collect_parameters_from_request(request.args)) + parameterized.apply(args) + sql = parameterized.text query_runner: BaseSQLQueryRunner = endpoint.data_source.query_runner - query_sql = parameterized.text - result_str, error = query_runner.run_query(query_sql, user=None) + result_str, error = query_runner.run_query(sql, user=None) # Handle query error if error: @@ -103,6 +164,7 @@ def get(self, endpoint_id: int, token: str) -> dict[str, object]: if len(rows) > 1: logger.warning( "One row expected from query for endpoint_id=%s, " - "but received %d rows", endpoint_id, len(rows)) + "but received %d rows", endpoint_id, len(rows) + ) return rows[0] diff --git a/redash/utils/__init__.py b/redash/utils/__init__.py index d9b82311..4942e386 100644 --- a/redash/utils/__init__.py +++ b/redash/utils/__init__.py @@ -238,7 +238,7 @@ def writerows(self, rows): self.writerow(row) -def collect_parameters_from_request(args): +def collect_parameters_from_request(args: dict[str, str]) -> dict[str, str]: parameters = {} for k, v in args.items(): if k.startswith("p_"):