From 4af545f98cd88471037068e61080c6a643867985 Mon Sep 17 00:00:00 2001 From: Glandos Date: Wed, 26 Feb 2025 23:42:23 +0100 Subject: [PATCH 1/2] first support for tsvector --- pgcopy/copy.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 7b81e7e..3ac914d 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -115,6 +115,43 @@ def uuid_formatter(guid): return "i2Q", (16, (guid.int >> 64) & MAX_INT64, guid.int & MAX_INT64) +def tsvector_position_parser(position): + try: + return int(position) + except ValueError: + # No extra validation, just imagine that we have a weight + weight = position[-1].upper() + + # 68 is 'D', it goes increasingly by step of 0x40 from 'D' to 'A' + offset = 0x4000 * (68 - ord(weight)) + + return int(position[:-1]) + offset + + + +def tsvector_formatter(vector): + """ + See https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/adt/tsvector.c;h=650be842f28febbd7c54d17955895bc3e4f108db;hb=HEAD#l397 + """ + fmt = ["I"] # Number of lexemes + data = [len(vector)] + + for lexeme, positions in vector: + # Lexeme as null-terminated (so length + 1) + # Number of positions (1 short) + # For each position: value (n shorts) + lexeme = lexeme.encode() + fmt.append('%ss%sH' % (len(lexeme) + 1, len(positions) + 1)) + data.extend([ + lexeme, + len(positions), + *[tsvector_position_parser(position) for position in positions] + ]) + fmt = "".join(fmt) + size = struct.calcsize(">" + fmt) + return "I" + fmt, (size, *data) + + type_formatters = { "bool": simple_formatter("?"), "int2": simple_formatter("h"), @@ -134,6 +171,7 @@ def uuid_formatter(guid): "timestamptz": timestamp, "numeric": numeric, "uuid": uuid_formatter, + "tsvector": tsvector_formatter, } From 1a352ee15abd5c18589ebe7f093695b35daa017e Mon Sep 17 00:00:00 2001 From: Glandos Date: Sat, 1 Mar 2025 22:48:33 +0100 Subject: [PATCH 2/2] encode lexemes with the correct encoding --- pgcopy/copy.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 3ac914d..9261133 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -140,7 +140,6 @@ def tsvector_formatter(vector): # Lexeme as null-terminated (so length + 1) # Number of positions (1 short) # For each position: value (n shorts) - lexeme = lexeme.encode() fmt.append('%ss%sH' % (len(lexeme) + 1, len(positions) + 1)) data.extend([ lexeme, @@ -240,16 +239,23 @@ def _maxsize(v): def encode(att, encoding, formatter): is_text_type = att.type_name in ("varchar", "text", "json") is_enum_type = att.type_category == "E" - if not (is_text_type or is_enum_type): - return formatter - - def _encode(v): - try: - encf = v.encode - except AttributeError: + if att.type_name == "tsvector": + def _encode(v): + try: + v = [(lexeme.encode(encoding), position) for (lexeme, position) in v] + except AttributeError: + pass return formatter(v) - else: - return formatter(encf(encoding)) + elif not (is_text_type or is_enum_type): + return formatter + else: + def _encode(v): + try: + encf = v.encode + except AttributeError: + return formatter(v) + else: + return formatter(encf(encoding)) return _encode