Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,3 +1112,114 @@ def test_reflection(self):
metadata.reflect(reflection_engine)

assert "nested_dir/table" in metadata.tables


class TestAsTable(TablesTest):
__backend__ = True

@classmethod
def define_tables(cls, metadata):
Table(
"test_as_table",
metadata,
Column("id", Integer, primary_key=True),
Column("val_int", Integer, nullable=True),
Column("val_str", String, nullable=True),
)

def test_upsert_as_table(self, connection):
table = self.tables.test_as_table

input_data = [
{"id": 1, "val_int": 10, "val_str": "a"},
{"id": 2, "val_int": None, "val_str": "b"},
{"id": 3, "val_int": 30, "val_str": None},
]

struct_type = types.StructType(
{
"id": Integer,
"val_int": types.Optional(Integer),
"val_str": types.Optional(String),
}
)
list_type = types.ListType(struct_type)

bind_param = sa.bindparam("data", type_=list_type)

upsert_stm = ydb_sa.upsert(table).from_select(
["id", "val_int", "val_str"],
sa.select(
sa.column("id", type_=Integer), sa.column("val_int", type_=Integer), sa.column("val_str", type_=String)
).select_from(sa.func.AS_TABLE(bind_param)),
)

connection.execute(upsert_stm, {"data": input_data})

rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall()
assert rows == [
(1, 10, "a"),
(2, None, "b"),
(3, 30, None),
]

def test_insert_as_table(self, connection):
table = self.tables.test_as_table

input_data = [
{"id": 4, "val_int": 40, "val_str": "d"},
{"id": 5, "val_int": None, "val_str": "e"},
]

struct_type = types.StructType(
{
"id": Integer,
"val_int": types.Optional(Integer),
"val_str": types.Optional(String),
}
)
list_type = types.ListType(struct_type)

bind_param = sa.bindparam("data", type_=list_type)

insert_stm = sa.insert(table).from_select(
["id", "val_int", "val_str"],
sa.select(
sa.column("id", type_=Integer), sa.column("val_int", type_=Integer), sa.column("val_str", type_=String)
).select_from(sa.func.AS_TABLE(bind_param)),
)

connection.execute(insert_stm, {"data": input_data})

rows = connection.execute(sa.select(table).where(table.c.id >= 4).order_by(table.c.id)).fetchall()
assert rows == [
(4, 40, "d"),
(5, None, "e"),
]

def test_upsert_from_table_reflection(self, connection):
table = self.tables.test_as_table

input_data = [
{"id": 1, "val_int": 10, "val_str": "a"},
{"id": 2, "val_int": None, "val_str": "b"},
]

struct_type = types.StructType.from_table(table)
list_type = types.ListType(struct_type)

bind_param = sa.bindparam("data", type_=list_type)

cols = [sa.column(c.name, type_=c.type) for c in table.columns]
upsert_stm = ydb_sa.upsert(table).from_select(
[c.name for c in table.columns],
sa.select(*cols).select_from(sa.func.AS_TABLE(bind_param)),
)

connection.execute(upsert_stm, {"data": input_data})

rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall()
assert rows == [
(1, 10, "a"),
(2, None, "b"),
]
27 changes: 24 additions & 3 deletions ydb_sqlalchemy/sqlalchemy/compiler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,20 @@ def visit_ARRAY(self, type_: sa.ARRAY, **kw):
inner = self.process(type_.item_type, **kw)
return f"List<{inner}>"

def visit_optional(self, type_: types.Optional, **kw):
el = type_.element_type
if isinstance(el, type):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Let's use the helper function from the sqlaclhemy:

from sqlalchemy.sql.type_api import to_instance

The same below

el = el()
inner = self.process(el, **kw)
return f"Optional<{inner}>"

def visit_struct_type(self, type_: types.StructType, **kw):
text = "Struct<"
for field, field_type in type_.fields_types:
text += f"{field}:{self.process(field_type, **kw)}"
for field, field_type in type_.fields_types.items():
type_str = self.process(field_type, **kw)
text += f"{field}:{type_str},"
if text.endswith(","):
text = text[:-1]
Comment on lines +164 to +168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (non-blocking): Let's use join here, like

return f"Struct<{','.join(...)}>"

return text + ">"

def get_ydb_type(
Expand All @@ -167,6 +177,13 @@ def get_ydb_type(
if isinstance(type_, (sa.Text, sa.String)):
ydb_type = ydb.PrimitiveType.Utf8

elif isinstance(type_, types.Optional):
if isinstance(type_.element_type, type):
inner = type_.element_type()
else:
inner = type_.element_type
return self.get_ydb_type(inner, is_optional=True)

# Integers
elif isinstance(type_, types.UInt64):
ydb_type = ydb.PrimitiveType.Uint64
Expand Down Expand Up @@ -235,7 +252,11 @@ def get_ydb_type(
elif isinstance(type_, types.StructType):
ydb_type = ydb.StructType()
for field, field_type in type_.fields_types.items():
ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False))
if isinstance(field_type, type):
inner_type = field_type()
else:
inner_type = field_type
ydb_type.add_member(field, self.get_ydb_type(inner_type, is_optional=False))
else:
raise NotSupportedError(f"{type_} bind variables not supported")

Expand Down
80 changes: 80 additions & 0 deletions ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,83 @@ def test_ydb_types():
compiled = query.compile(dialect=dialect, compile_kwargs={"literal_binds": True})

assert str(compiled) == "Date('1996-11-19')"


def test_struct_type_generation():
dialect = YqlDialect()
type_compiler = dialect.type_compiler

# Test default (non-optional)
struct_type = types.StructType(
{
"id": sa.Integer,
"val_int": sa.Integer,
}
)
ydb_type = type_compiler.get_ydb_type(struct_type, is_optional=False)
# Keys are sorted
assert str(ydb_type) == "Struct<id:Int64,val_int:Int64>"

# Test optional
struct_type_opt = types.StructType(
{
"id": sa.Integer,
"val_int": types.Optional(sa.Integer),
}
)
ydb_type_opt = type_compiler.get_ydb_type(struct_type_opt, is_optional=False)
assert str(ydb_type_opt) == "Struct<id:Int64,val_int:Int64?>"


def test_types_compilation():
dialect = YqlDialect()

def compile_type(type_):
return dialect.type_compiler.process(type_)

assert compile_type(types.UInt64()) == "UInt64"
assert compile_type(types.UInt32()) == "UInt32"
assert compile_type(types.UInt16()) == "UInt16"
assert compile_type(types.UInt8()) == "UInt8"

assert compile_type(types.Int64()) == "Int64"
assert compile_type(types.Int32()) == "Int32"
assert compile_type(types.Int16()) == "Int32"
assert compile_type(types.Int8()) == "Int8"

assert compile_type(types.ListType(types.Int64())) == "List<Int64>"

struct = types.StructType({"a": types.Int32(), "b": types.ListType(types.Int32())})
# Ordered by key: a, b
assert compile_type(struct) == "Struct<a:Int32,b:List<Int32>>"


def test_optional_type_compilation():
dialect = YqlDialect()
type_compiler = dialect.type_compiler

def compile_type(type_):
return type_compiler.process(type_)

# Test Optional(Integer)
opt_int = types.Optional(sa.Integer)
assert compile_type(opt_int) == "Optional<Int64>"

# Test Optional(String)
opt_str = types.Optional(sa.String)
assert compile_type(opt_str) == "Optional<UTF8>"

# Test Nested Optional
opt_opt_int = types.Optional(types.Optional(sa.Integer))
assert compile_type(opt_opt_int) == "Optional<Optional<Int64>>"

# Test get_ydb_type
ydb_type = type_compiler.get_ydb_type(opt_int, is_optional=False)
import ydb

assert isinstance(ydb_type, ydb.OptionalType)
# Int64 corresponds to PrimitiveType.Int64
# Note: ydb.PrimitiveType.Int64 is an enum member, but ydb_type.item is also an instance/enum?
# get_ydb_type returns ydb.PrimitiveType.Int64 (enum) wrapped in OptionalType.
# OptionalType.item is the inner type.
assert ydb_type.item == ydb.PrimitiveType.Int64
28 changes: 26 additions & 2 deletions ydb_sqlalchemy/sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
else:
from sqlalchemy.sql.expression import ColumnElement

from sqlalchemy import ARRAY, exc, types
from sqlalchemy import ARRAY, exc, Table, types
from sqlalchemy.sql import type_api

from .datetime_types import YqlDate, YqlDateTime, YqlTimestamp, YqlDate32, YqlTimestamp64, YqlDateTime64 # noqa: F401
Expand Down Expand Up @@ -116,12 +116,36 @@ def __hash__(self):
return hash(tuple(self.items()))


class Optional(types.TypeEngine):
__visit_name__ = "optional"

def __init__(self, element_type: Union[Type[types.TypeEngine], types.TypeEngine]):
self.element_type = element_type


class StructType(types.TypeEngine[Mapping[str, Any]]):
__visit_name__ = "struct_type"

def __init__(self, fields_types: Mapping[str, Union[Type[types.TypeEngine], Type[types.TypeDecorator]]]):
def __init__(
self,
fields_types: Mapping[
str,
Union[Type[types.TypeEngine], types.TypeEngine, Optional],
],
):
self.fields_types = HashableDict(dict(sorted(fields_types.items())))

@classmethod
def from_table(cls, table: Table):
fields = {}
for col in table.columns:
t = col.type
if col.nullable:
fields[col.name] = Optional(t)
else:
fields[col.name] = t
return cls(fields)

@property
def python_type(self):
return dict
Expand Down