Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions pgcopy/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,42 @@ def uuid_formatter(guid):
return "i2Q", (16, (guid.int >> 64) & MAX_INT64, guid.int & MAX_INT64)


def tsvector_position_parser(position):
try:
return int(position)
except ValueError:
# No extra validation, just imagine that we have a weight
weight = position[-1].upper()

# 68 is 'D', it goes increasingly by step of 0x40 from 'D' to 'A'
offset = 0x4000 * (68 - ord(weight))

return int(position[:-1]) + offset



def tsvector_formatter(vector):
"""
See https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/adt/tsvector.c;h=650be842f28febbd7c54d17955895bc3e4f108db;hb=HEAD#l397
"""
fmt = ["I"] # Number of lexemes
data = [len(vector)]

for lexeme, positions in vector:
# Lexeme as null-terminated (so length + 1)
# Number of positions (1 short)
# For each position: value (n shorts)
fmt.append('%ss%sH' % (len(lexeme) + 1, len(positions) + 1))
data.extend([
lexeme,
len(positions),
*[tsvector_position_parser(position) for position in positions]
])
fmt = "".join(fmt)
size = struct.calcsize(">" + fmt)
return "I" + fmt, (size, *data)


type_formatters = {
"bool": simple_formatter("?"),
"int2": simple_formatter("h"),
Expand All @@ -134,6 +170,7 @@ def uuid_formatter(guid):
"timestamptz": timestamp,
"numeric": numeric,
"uuid": uuid_formatter,
"tsvector": tsvector_formatter,
}


Expand Down Expand Up @@ -202,16 +239,23 @@ def _maxsize(v):
def encode(att, encoding, formatter):
is_text_type = att.type_name in ("varchar", "text", "json")
is_enum_type = att.type_category == "E"
if not (is_text_type or is_enum_type):
return formatter

def _encode(v):
try:
encf = v.encode
except AttributeError:
if att.type_name == "tsvector":
def _encode(v):
try:
v = [(lexeme.encode(encoding), position) for (lexeme, position) in v]
except AttributeError:
pass
return formatter(v)
else:
return formatter(encf(encoding))
elif not (is_text_type or is_enum_type):
return formatter
else:
def _encode(v):
try:
encf = v.encode
except AttributeError:
return formatter(v)
else:
return formatter(encf(encoding))

return _encode

Expand Down