Skip to content

Commit c014c40

Browse files
[store] sanitize txtai metadata persistence (closes #3)
1 parent 74789ff commit c014c40

File tree

4 files changed

+278
-23
lines changed

4 files changed

+278
-23
lines changed

pave/stores/txtai_store.py

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,22 @@
44
from __future__ import annotations
55
import os, json, operator
66
from datetime import datetime
7-
from typing import Dict, Iterable, List, Any
7+
from typing import Any, Dict, Iterable, List, Optional
88
from threading import Lock
99
from contextlib import contextmanager
1010
from txtai.embeddings import Embeddings
1111
from pave.stores.base import BaseStore, Record
1212
from pave.config import CFG as c, LOG as log
1313

1414
# One lock per store key, created lazily by get_lock().
_LOCKS: dict[str, Lock] = {}

# Translation table used by TxtaiStore._sanit_sql: statement separators and
# quoting characters are blanked to spaces, NUL bytes are dropped outright.
_SQL_TRANS = str.maketrans(';"`\\', "    ", "\x00")
22+
1523
def get_lock(key: str) -> Lock:
1624
if key not in _LOCKS:
1725
_LOCKS[key] = Lock()
@@ -236,20 +244,22 @@ def index_records(self, tenant: str, collection: str, docid: str,
236244

237245
md["docid"] = docid
238246
try:
239-
meta_json = json.dumps(md, ensure_ascii=False)
240-
md = json.loads(meta_json)
241-
except:
242-
md = {}
247+
safe_meta = self._sanit_meta_dict(md)
248+
meta_json = json.dumps(safe_meta, ensure_ascii=False)
249+
except Exception:
250+
safe_meta = {}
243251
meta_json = ""
244252

245253
rid = str(rid)
246254
txt = str(txt)
247255
if not rid.startswith(f"{docid}::"):
248256
rid = f"{docid}::{rid}"
249257

250-
meta_side[rid] = md
258+
md_for_index = {k: v for k, v in safe_meta.items() if k != "text"}
259+
260+
meta_side[rid] = safe_meta
251261
record_ids.append(rid)
252-
prepared.append((rid, {"text":txt, **md}, meta_json))
262+
prepared.append((rid, {"text": txt, **md_for_index}, meta_json))
253263

254264
self._save_chunk_text(tenant, collection, rid, txt)
255265
assert txt == (self._load_chunk_text(tenant, collection, rid) or "")
@@ -280,10 +290,15 @@ def _matches_filters(m: Dict[str, Any],
280290
if not filters:
281291
return True
282292

283-
def match(have: Any, cond: str) -> bool:
293+
def match(have: Any, cond: Any) -> bool:
284294
if have is None:
285295
return False
286-
s = str(cond)
296+
if isinstance(have, (list, tuple, set)):
297+
return any(match(item, cond) for item in have)
298+
if isinstance(cond, str):
299+
s = TxtaiStore._sanit_sql(cond)
300+
else:
301+
s = str(cond)
287302
hv = str(have)
288303
# Numeric/date ops
289304
for op in (">=", "<=", "!=", ">", "<"):
@@ -313,7 +328,7 @@ def match(have: Any, cond: str) -> bool:
313328
return hv == s
314329

315330
for k, vals in filters.items():
316-
if not any(match(m.get(k), v) for v in vals):
331+
if not any(match(TxtaiStore._lookup_meta(m, k), v) for v in vals):
317332
return False
318333
return True
319334

@@ -325,6 +340,9 @@ def _split_filters(filters: dict[str, Any] | None) -> tuple[dict, dict]:
325340

326341
pre_f, pos_f = {}, {}
327342
for key, vals in (filters or {}).items():
343+
safe_key = TxtaiStore._sanit_field(key)
344+
if not safe_key:
345+
continue
328346
if not isinstance(vals, list):
329347
vals = [vals]
330348
exacts, extended = [], []
@@ -338,12 +356,68 @@ def _split_filters(filters: dict[str, Any] | None) -> tuple[dict, dict]:
338356
else:
339357
exacts.append(v)
340358
if exacts:
341-
pre_f[key] = exacts
359+
pre_f[safe_key] = exacts
342360
if extended:
343-
pos_f[key] = extended
361+
pos_f[safe_key] = extended
344362
log.debug(f"after split: PRE {pre_f} POS {pos_f}")
345363
return pre_f, pos_f
346364

365+
@staticmethod
def _lookup_meta(meta: Dict[str, Any] | None, key: str) -> Any:
    """Fetch *key* from *meta*, also matching raw keys by sanitized form.

    Returns None when the metadata dict is empty/missing, or when no raw
    key sanitizes to *key* via TxtaiStore._sanit_field.
    """
    if not meta:
        return None
    try:
        # Fast path: the sanitized key is stored verbatim.
        return meta[key]
    except KeyError:
        pass
    # Slow path: a raw (unsanitized) key may map onto the requested one.
    for candidate, stored in meta.items():
        if TxtaiStore._sanit_field(candidate) == key:
            return stored
    return None
375+
376+
@staticmethod
def _sanit_meta_value(value: Any) -> Any:
    """Recursively sanitize a metadata value for safe SQL/JSON persistence.

    Dicts and sequences are walked; plain scalars (numbers, bools, None)
    pass through untouched; anything else is stringified and defanged by
    _sanit_sql.  Tuples and sets come back as lists (JSON-friendly).
    """
    if isinstance(value, dict):
        return TxtaiStore._sanit_meta_dict(value)
    if isinstance(value, (list, tuple, set)):
        return [TxtaiStore._sanit_meta_value(item) for item in value]
    if value is None or isinstance(value, (int, float, bool)):
        return value
    return TxtaiStore._sanit_sql(value)
385+
386+
@staticmethod
def _sanit_meta_dict(meta: Dict[str, Any] | None) -> Dict[str, Any]:
    """Return a sanitized copy of *meta* keyed by safe field names.

    Non-dict input yields {}.  Keys that sanitize to the empty string, or
    that would collide with the reserved "text" column, are dropped.
    """
    if not isinstance(meta, dict):
        return {}
    result: Dict[str, Any] = {}
    for raw_key, raw_value in meta.items():
        clean_key = TxtaiStore._sanit_field(raw_key)
        if clean_key and clean_key != "text":
            result[clean_key] = TxtaiStore._sanit_meta_value(raw_value)
    return result
397+
398+
@staticmethod
def _sanit_sql(value: Any, *, max_len: Optional[int] = None) -> str:
    """Neutralize SQL-injection primitives in *value*; return safe text.

    Dangerous single characters (;, ", `, backslash, NUL) are removed via
    _SQL_TRANS, everything from the first comment token onward is cut,
    the result is trimmed and optionally truncated to *max_len*, and
    single quotes are doubled for embedding in a SQL string literal.
    NOTE: quote doubling happens after truncation, so the returned string
    may be longer than max_len.
    """
    if value is None:
        return ""
    cleaned = str(value).translate(_SQL_TRANS)
    for comment in ("--", "/*", "*/"):
        cut = cleaned.find(comment)
        if cut != -1:
            cleaned = cleaned[:cut]
    cleaned = cleaned.strip()
    if max_len is not None and 0 < max_len < len(cleaned):
        cleaned = cleaned[:max_len]
    return cleaned.replace("'", "''")
410+
411+
@staticmethod
def _sanit_field(name: Any) -> str:
    """Reduce *name* to a SQL-safe column identifier.

    Keeps only alphanumerics, underscores and hyphens; brackets, quotes,
    whitespace and every other character are stripped.  May return ""
    when nothing survives — callers treat that as "drop this field".
    """
    text = name if isinstance(name, str) else str(name)
    return "".join(ch for ch in text if ch.isalnum() or ch in "_-")
420+
347421
@staticmethod
348422
def _build_sql(query: str, k: int, filters: dict[str, Any], columns: list[str],
349423
with_similarity: bool = True, avoid_duplicates = True) -> str:
@@ -356,14 +430,23 @@ def _build_sql(query: str, k: int, filters: dict[str, Any], columns: list[str],
356430

357431
wheres = []
358432
if with_similarity and query:
359-
q_safe = query.replace("'", "''")
433+
max_len_cfg = c.get("vector_store.txtai.max_query_chars", 512)
434+
try:
435+
max_len = int(max_len_cfg)
436+
except (TypeError, ValueError):
437+
max_len = 512
438+
limit = max_len if max_len > 0 else None
439+
q_safe = TxtaiStore._sanit_sql(query, max_len=limit)
360440
wheres.append(f"similar('{q_safe}')")
361441

362442
for key, vals in filters.items():
443+
safe_key = TxtaiStore._sanit_field(key)
444+
if not safe_key:
445+
continue
363446
ors = []
364447
for v in vals:
365-
safe_v = str(v).replace("'", "''")
366-
ors.append(f"[{key}] = '{safe_v}'")
448+
safe_v = TxtaiStore._sanit_sql(v)
449+
ors.append(f"[{safe_key}] = '{safe_v}'")
367450
or_safe = " OR ".join(ors)
368451
wheres.append(f"({or_safe})")
369452

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
11
# (C) 2025 Rodrigo Rodrigues da Silva <rodrigopitanga@posteo.net>
22
# SPDX-License-Identifier: GPL-3.0-or-later
33

4+
import sys
import types

# The test suite must be importable even when the optional txtai dependency
# is not installed: register a minimal stand-in before anything executes
# `from txtai.embeddings import Embeddings`.
if "txtai.embeddings" not in sys.modules:

    class _StubEmbeddings:  # pragma: no cover - stub for optional dependency
        """Accept-anything placeholder for txtai.embeddings.Embeddings."""

        def __init__(self, *args, **kwargs):
            pass

    embeddings_stub = types.ModuleType("txtai.embeddings")
    embeddings_stub.Embeddings = _StubEmbeddings
    txtai_stub = types.ModuleType("txtai")
    txtai_stub.embeddings = embeddings_stub
    sys.modules.setdefault("txtai", txtai_stub)
    sys.modules.setdefault("txtai.embeddings", embeddings_stub)
19+
420
import pytest
521
from fastapi.testclient import TestClient
622
from pave.config import get_cfg, reload_cfg
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# (C) 2025 Rodrigo Rodrigues da Silva <rodrigopitanga@posteo.net>
2+
# SPDX-License-Identifier: GPL-3.0-or-later
3+
4+
import json
5+
6+
import pytest
7+
8+
from pave.stores import txtai_store as store_mod
9+
from pave.stores.txtai_store import TxtaiStore
10+
from pave.config import get_cfg
11+
from utils import FakeEmbeddings
12+
13+
14+
@pytest.fixture(autouse=True)
def _fake_embeddings(monkeypatch):
    """Replace the real txtai Embeddings class with the in-memory fake."""
    monkeypatch.setattr(store_mod, "Embeddings", FakeEmbeddings, raising=True)
17+
18+
19+
@pytest.fixture()
def store():
    """Provide a fresh TxtaiStore instance for each test."""
    return TxtaiStore()
22+
23+
24+
def _extract_similarity_term(sql: str) -> str:
25+
marker = "similar('"
26+
if marker not in sql:
27+
raise AssertionError(f"similar() clause missing in SQL: {sql!r}")
28+
rest = sql.split(marker, 1)[1]
29+
return rest.split("')", 1)[0]
30+
31+
32+
def test_build_sql_sanitizes_similarity_term(store):
    """similar() payloads must not carry statement or comment tokens."""
    payload = "foo'; DROP TABLE users; -- comment"
    sql = store._build_sql(payload, 5, {}, ["id", "text"])
    term = _extract_similarity_term(sql)

    # injection primitives are stripped or neutralised
    assert ";" not in term
    assert "--" not in term
    # original alpha characters remain so search still works
    assert "foo" in term
42+
43+
44+
def test_build_sql_sanitizes_filter_values(store):
    """Filter values carrying injection primitives are defanged in the SQL."""
    dirty = {"lang": ["en'; DELETE FROM x;"], "tags": ['alpha"beta']}
    sql = store._build_sql("foo", 5, dirty, ["id", "text"])

    # no dangerous characters may leak into the filter clause
    for forbidden in (";", '"', "--"):
        assert forbidden not in sql
52+
53+
54+
def test_build_sql_normalises_filter_keys(store):
    """Hostile or non-string filter keys collapse to safe identifiers."""
    hostile = {"lang]; DROP": ["en"], 123: ["x"]}
    sql = store._build_sql("foo", 5, hostile, ["id"])

    assert "[langDROP]" in sql
    assert "[123]" in sql
59+
60+
61+
def test_build_sql_applies_query_length_limit(store):
    """vector_store.txtai.max_query_chars caps the similar() payload."""
    cfg = get_cfg()
    saved = cfg.snapshot()
    try:
        cfg.set("vector_store.txtai.max_query_chars", 8)
        sql = store._build_sql("abcdefghijklmno", 5, {}, ["id"])
        term = _extract_similarity_term(sql)

        # collapse the doubled quotes to measure the original payload length
        assert len(term.replace("''", "'")) == 8
    finally:
        cfg.replace(data=saved)
74+
75+
76+
def test_search_handles_special_characters(store):
    """A query full of SQL noise still matches indexed content."""
    tenant, coll = "tenant", "coll"
    store.load_or_init(tenant, coll)
    store.index_records(tenant, coll, "doc",
                        [("r1", "hello world", {"lang": "en"})])

    results = store.search(tenant, coll, "world; -- comment", k=5)
    assert results
    assert results[0]["id"].endswith("::r1")
86+
87+
88+
def test_round_trip_with_weird_metadata_field(store):
    """Hostile metadata keys/values survive indexing, search and storage."""
    tenant, coll = "tenant", "coll"
    store.load_or_init(tenant, coll)

    raw_key = "meta;`DROP"
    raw_value = "val'u"
    store.index_records(tenant, coll, "doc2",
                        [("r2", "strange world", {raw_key: raw_value})])

    hits = store.search(tenant, coll, "strange", k=5,
                        filters={raw_key: raw_value})
    assert hits
    assert hits[0]["id"].endswith("::r2")

    clean_key = TxtaiStore._sanit_field(raw_key)
    clean_value = TxtaiStore._sanit_sql(raw_value)

    # the generated SQL must reference the sanitized column name
    emb = store._emb[(tenant, coll)]
    assert emb.last_sql and f"[{clean_key}]" in emb.last_sql

    # side-channel metadata store holds the sanitized key/value pair
    rid = hits[0]["id"]
    persisted = store._load_meta(tenant, coll).get(rid) or {}
    assert clean_key in persisted
    assert persisted[clean_key] == clean_value

    # the embeddings index and its JSON serialization agree as well
    doc = emb._docs[rid]
    assert doc["meta"].get(clean_key) == clean_value
    serialized = json.loads(doc["meta_json"]) if doc.get("meta_json") else {}
    assert serialized.get(clean_key) == clean_value
    assert hits[0]["meta"].get(clean_key) == clean_value

0 commit comments

Comments
 (0)