diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 7b81e7e..9261133 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -115,6 +115,42 @@ def uuid_formatter(guid): return "i2Q", (16, (guid.int >> 64) & MAX_INT64, guid.int & MAX_INT64) +def tsvector_position_parser(position): + try: + return int(position) + except ValueError: + # No extra validation, just imagine that we have a weight + weight = position[-1].upper() + + # 68 is 'D', it goes increasingly by step of 0x40 from 'D' to 'A' + offset = 0x4000 * (68 - ord(weight)) + + return int(position[:-1]) + offset + + + +def tsvector_formatter(vector): + """ + See https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/adt/tsvector.c;h=650be842f28febbd7c54d17955895bc3e4f108db;hb=HEAD#l397 + """ + fmt = ["I"] # Number of lexemes + data = [len(vector)] + + for lexeme, positions in vector: + # Lexeme as null-terminated (so length + 1) + # Number of positions (1 short) + # For each position: value (n shorts) + fmt.append('%ss%sH' % (len(lexeme) + 1, len(positions) + 1)) + data.extend([ + lexeme, + len(positions), + *[tsvector_position_parser(position) for position in positions] + ]) + fmt = "".join(fmt) + size = struct.calcsize(">" + fmt) + return "I" + fmt, (size, *data) + + type_formatters = { "bool": simple_formatter("?"), "int2": simple_formatter("h"), @@ -134,6 +170,7 @@ def uuid_formatter(guid): "timestamptz": timestamp, "numeric": numeric, "uuid": uuid_formatter, + "tsvector": tsvector_formatter, } @@ -202,16 +239,23 @@ def _maxsize(v): def encode(att, encoding, formatter): is_text_type = att.type_name in ("varchar", "text", "json") is_enum_type = att.type_category == "E" - if not (is_text_type or is_enum_type): - return formatter - - def _encode(v): - try: - encf = v.encode - except AttributeError: + if att.type_name == "tsvector": + def _encode(v): + try: + v = [(lexeme.encode(encoding), position) for (lexeme, position) in v] + except AttributeError: + pass return formatter(v) - else: - return formatter(encf(encoding)) + elif not (is_text_type or is_enum_type): + return formatter + else: + def _encode(v): + try: + encf = v.encode + except AttributeError: + return formatter(v) + else: + return formatter(encf(encoding)) return _encode