From 8494bac50266fbad4fa06a27774f99aaab4ac27c Mon Sep 17 00:00:00 2001 From: Nikolay Makhalin Date: Fri, 19 Dec 2025 23:23:04 +0100 Subject: [PATCH 1/2] Support nullable StructType fields via Optional wrapper --- test/test_core.py | 111 +++++++++++++++++++ ydb_sqlalchemy/sqlalchemy/compiler/base.py | 27 ++++- ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py | 80 +++++++++++++ ydb_sqlalchemy/sqlalchemy/types.py | 28 ++++- 4 files changed, 241 insertions(+), 5 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 3f7d808..9a5bdff 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1112,3 +1112,114 @@ def test_reflection(self): metadata.reflect(reflection_engine) assert "nested_dir/table" in metadata.tables + + +class TestAsTable(TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test_as_table", + metadata, + Column("id", Integer, primary_key=True), + Column("val_int", Integer, nullable=True), + Column("val_str", String, nullable=True), + ) + + def test_upsert_as_table(self, connection): + table = self.tables.test_as_table + + input_data = [ + {"id": 1, "val_int": 10, "val_str": "a"}, + {"id": 2, "val_int": None, "val_str": "b"}, + {"id": 3, "val_int": 30, "val_str": None}, + ] + + struct_type = types.StructType( + { + "id": Integer, + "val_int": types.Optional(Integer), + "val_str": types.Optional(String), + } + ) + list_type = types.ListType(struct_type) + + bind_param = sa.bindparam("data", type_=list_type) + + upsert_stm = ydb_sa.upsert(table).from_select( + ["id", "val_int", "val_str"], + sa.select( + sa.column("id", type_=Integer), sa.column("val_int", type_=Integer), sa.column("val_str", type_=String) + ).select_from(sa.func.AS_TABLE(bind_param)), + ) + + connection.execute(upsert_stm, {"data": input_data}) + + rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall() + assert rows == [ + (1, 10, "a"), + (2, None, "b"), + (3, 30, None), + ] + + def test_insert_as_table(self, connection): + table = self.tables.test_as_table + + input_data = [ + {"id": 4, "val_int": 40, "val_str": "d"}, + {"id": 5, "val_int": None, "val_str": "e"}, + ] + + struct_type = types.StructType( + { + "id": Integer, + "val_int": types.Optional(Integer), + "val_str": types.Optional(String), + } + ) + list_type = types.ListType(struct_type) + + bind_param = sa.bindparam("data", type_=list_type) + + insert_stm = sa.insert(table).from_select( + ["id", "val_int", "val_str"], + sa.select( + sa.column("id", type_=Integer), sa.column("val_int", type_=Integer), sa.column("val_str", type_=String) + ).select_from(sa.func.AS_TABLE(bind_param)), + ) + + connection.execute(insert_stm, {"data": input_data}) + + rows = connection.execute(sa.select(table).where(table.c.id >= 4).order_by(table.c.id)).fetchall() + assert rows == [ + (4, 40, "d"), + (5, None, "e"), + ] + + def test_upsert_from_table_reflection(self, connection): + table = self.tables.test_as_table + + input_data = [ + {"id": 1, "val_int": 10, "val_str": "a"}, + {"id": 2, "val_int": None, "val_str": "b"}, + ] + + struct_type = types.StructType.from_table(table) + list_type = types.ListType(struct_type) + + bind_param = sa.bindparam("data", type_=list_type) + + cols = [sa.column(c.name, type_=c.type) for c in table.columns] + upsert_stm = ydb_sa.upsert(table).from_select( + [c.name for c in table.columns], + sa.select(*cols).select_from(sa.func.AS_TABLE(bind_param)), + ) + + connection.execute(upsert_stm, {"data": input_data}) + + rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall() + assert rows == [ + (1, 10, "a"), + (2, None, "b"), + ] diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/base.py b/ydb_sqlalchemy/sqlalchemy/compiler/base.py index 9833139..bd4c8a8 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/base.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/base.py @@ -152,10 +152,20 @@ def visit_ARRAY(self, type_: sa.ARRAY, **kw): inner = self.process(type_.item_type, **kw) return f"List<{inner}>" + def visit_optional(self, type_: types.Optional, **kw): + el = type_.element_type + if isinstance(el, type): + el = el() + inner = self.process(el, **kw) + return f"Optional<{inner}>" + def visit_struct_type(self, type_: types.StructType, **kw): text = "Struct<" - for field, field_type in type_.fields_types: - text += f"{field}:{self.process(field_type, **kw)}" + for field, field_type in type_.fields_types.items(): + type_str = self.process(field_type, **kw) + text += f"{field}:{type_str}," + if text.endswith(","): + text = text[:-1] return text + ">" def get_ydb_type( @@ -167,6 +177,13 @@ def get_ydb_type( if isinstance(type_, (sa.Text, sa.String)): ydb_type = ydb.PrimitiveType.Utf8 + elif isinstance(type_, types.Optional): + if isinstance(type_.element_type, type): + inner = type_.element_type() + else: + inner = type_.element_type + return self.get_ydb_type(inner, is_optional=True) + # Integers elif isinstance(type_, types.UInt64): ydb_type = ydb.PrimitiveType.Uint64 @@ -235,7 +252,11 @@ def get_ydb_type( elif isinstance(type_, types.StructType): ydb_type = ydb.StructType() for field, field_type in type_.fields_types.items(): - ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False)) + if isinstance(field_type, type): + inner_type = field_type() + else: + inner_type = field_type + ydb_type.add_member(field, self.get_ydb_type(inner_type, is_optional=False)) else: raise NotSupportedError(f"{type_} bind variables not supported") diff --git a/ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py b/ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py index 6908193..67a9530 100644 --- a/ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py +++ b/ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py @@ -35,3 +35,83 @@ def test_ydb_types(): compiled = query.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) assert str(compiled) == "Date('1996-11-19')" + + +def test_struct_type_generation(): + dialect = YqlDialect() + type_compiler = dialect.type_compiler + + # Test default (non-optional) + struct_type = types.StructType( + { + "id": sa.Integer, + "val_int": sa.Integer, + } + ) + ydb_type = type_compiler.get_ydb_type(struct_type, is_optional=False) + # Keys are sorted + assert str(ydb_type) == "Struct" + + # Test optional + struct_type_opt = types.StructType( + { + "id": sa.Integer, + "val_int": types.Optional(sa.Integer), + } + ) + ydb_type_opt = type_compiler.get_ydb_type(struct_type_opt, is_optional=False) + assert str(ydb_type_opt) == "Struct" + + +def test_types_compilation(): + dialect = YqlDialect() + + def compile_type(type_): + return dialect.type_compiler.process(type_) + + assert compile_type(types.UInt64()) == "UInt64" + assert compile_type(types.UInt32()) == "UInt32" + assert compile_type(types.UInt16()) == "UInt16" + assert compile_type(types.UInt8()) == "UInt8" + + assert compile_type(types.Int64()) == "Int64" + assert compile_type(types.Int32()) == "Int32" + assert compile_type(types.Int16()) == "Int32" + assert compile_type(types.Int8()) == "Int8" + + assert compile_type(types.ListType(types.Int64())) == "List" + + struct = types.StructType({"a": types.Int32(), "b": types.ListType(types.Int32())}) + # Ordered by key: a, b + assert compile_type(struct) == "Struct>" + + +def test_optional_type_compilation(): + dialect = YqlDialect() + type_compiler = dialect.type_compiler + + def compile_type(type_): + return type_compiler.process(type_) + + # Test Optional(Integer) + opt_int = types.Optional(sa.Integer) + assert compile_type(opt_int) == "Optional" + + # Test Optional(String) + opt_str = types.Optional(sa.String) + assert compile_type(opt_str) == "Optional" + + # Test Nested Optional + opt_opt_int = types.Optional(types.Optional(sa.Integer)) + assert compile_type(opt_opt_int) == "Optional>" + + # Test get_ydb_type + ydb_type = type_compiler.get_ydb_type(opt_int, is_optional=False) + import ydb + + assert isinstance(ydb_type, ydb.OptionalType) + # Int64 corresponds to PrimitiveType.Int64 + # Note: ydb.PrimitiveType.Int64 is an enum member, but ydb_type.item is also an instance/enum? + # get_ydb_type returns ydb.PrimitiveType.Int64 (enum) wrapped in OptionalType. + # OptionalType.item is the inner type. + assert ydb_type.item == ydb.PrimitiveType.Int64 diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 89c43d0..397f3fa 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -8,7 +8,7 @@ else: from sqlalchemy.sql.expression import ColumnElement -from sqlalchemy import ARRAY, exc, types +from sqlalchemy import ARRAY, exc, Table, types from sqlalchemy.sql import type_api from .datetime_types import YqlDate, YqlDateTime, YqlTimestamp, YqlDate32, YqlTimestamp64, YqlDateTime64 # noqa: F401 @@ -116,12 +116,36 @@ def __hash__(self): return hash(tuple(self.items())) +class Optional(types.TypeEngine): + __visit_name__ = "optional" + + def __init__(self, element_type: Union[Type[types.TypeEngine], types.TypeEngine]): + self.element_type = element_type + + class StructType(types.TypeEngine[Mapping[str, Any]]): __visit_name__ = "struct_type" - def __init__(self, fields_types: Mapping[str, Union[Type[types.TypeEngine], Type[types.TypeDecorator]]]): + def __init__( + self, + fields_types: Mapping[ + str, + Union[Type[types.TypeEngine], types.TypeEngine, Optional], + ], + ): self.fields_types = HashableDict(dict(sorted(fields_types.items()))) + @classmethod + def from_table(cls, table: Table): + fields = {} + for col in table.columns: + t = col.type + if col.nullable: + fields[col.name] = Optional(t) + else: + fields[col.name] = t + return cls(fields) + @property def python_type(self): return dict From daa4380d780c693757d39936bb7dec94ce3a61df Mon Sep 17 00:00:00 2001 From: Nikolay Makhalin Date: Wed, 24 Dec 2025 09:16:31 +0100 Subject: [PATCH 2/2] review fixes 1 --- ydb_sqlalchemy/sqlalchemy/compiler/base.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/base.py b/ydb_sqlalchemy/sqlalchemy/compiler/base.py index bd4c8a8..2d54fd8 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/base.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/base.py @@ -12,6 +12,7 @@ StrSQLTypeCompiler, selectable, ) +from sqlalchemy.sql.type_api import to_instance from typing import ( Any, Dict, @@ -153,20 +154,16 @@ def visit_ARRAY(self, type_: sa.ARRAY, **kw): return f"List<{inner}>" def visit_optional(self, type_: types.Optional, **kw): - el = type_.element_type - if isinstance(el, type): - el = el() + el = to_instance(type_.element_type) inner = self.process(el, **kw) return f"Optional<{inner}>" def visit_struct_type(self, type_: types.StructType, **kw): - text = "Struct<" + rendered_types = [] for field, field_type in type_.fields_types.items(): type_str = self.process(field_type, **kw) - text += f"{field}:{type_str}," - if text.endswith(","): - text = text[:-1] - return text + ">" + rendered_types.append(f"{field}:{type_str}") + return f"Struct<{','.join(rendered_types)}>" def get_ydb_type( self, type_: sa.types.TypeEngine, is_optional: bool @@ -178,10 +175,7 @@ def get_ydb_type( ydb_type = ydb.PrimitiveType.Utf8 elif isinstance(type_, types.Optional): - if isinstance(type_.element_type, type): - inner = type_.element_type() - else: - inner = type_.element_type + inner = to_instance(type_.element_type) return self.get_ydb_type(inner, is_optional=True) # Integers