diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index ffa393fd..6f42bace 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -20,6 +20,7 @@ ColumnType, alter_column, alter_table, + format_server_default, format_type, ) from google.api_core.client_options import ClientOptions @@ -572,7 +573,7 @@ def get_column_specification(self, column, **kwargs): elif has_identity: colspec += " " + self.process(column.identity) elif default is not None: - colspec += " DEFAULT (" + default + ")" + colspec += f" DEFAULT {default}" elif hasattr(column, "computed") and column.computed is not None: colspec += " " + self.process(column.computed) @@ -583,6 +584,13 @@ def get_column_specification(self, column, **kwargs): return colspec + def get_column_default_string(self, column): + default = super().get_column_default_string(column) + if default is not None: + return f"({default})" + + return default + def visit_computed_column(self, generated, **kw): """Computed column operator.""" text = "AS (%s) STORED" % self.sql_compiler.process( @@ -1852,11 +1860,14 @@ def do_execute_no_params(self, cursor, statement, context=None): def visit_column_nullable( element: "ColumnNullable", compiler: "SpannerDDLCompiler", **kw ) -> str: - return "%s %s %s %s" % ( - alter_table(compiler, element.table_name, element.schema), - alter_column(compiler, element.column_name), - format_type(compiler, element.existing_type), - "" if element.nullable else "NOT NULL", + return _format_alter_column( + compiler, + element.table_name, + element.schema, + element.column_name, + element.existing_type, + element.nullable, + element.existing_server_default, ) @@ -1865,9 +1876,34 @@ def visit_column_nullable( def visit_column_type( element: "ColumnType", compiler: "SpannerDDLCompiler", **kw ) -> str: - return "%s %s %s %s" % ( - alter_table(compiler, element.table_name, element.schema), - alter_column(compiler, element.column_name), - "%s" % format_type(compiler, element.type_), - "" if element.existing_nullable else "NOT NULL", + return _format_alter_column( + compiler, + element.table_name, + element.schema, + element.column_name, + element.type_, + element.existing_nullable, + element.existing_server_default, + ) + + +def _format_alter_column( + compiler, table_name, schema, column_name, type_, nullable, server_default +): + # Older versions of SQLAlchemy pass in a boolean to indicate whether there + # is an existing DEFAULT constraint, instead of the actual DEFAULT constraint + # expression. In those cases, we do not want to explicitly include the DEFAULT + # constraint in the expression that is generated here. + if isinstance(server_default, bool): + server_default = None + return "%s %s %s%s%s" % ( + alter_table(compiler, table_name, schema), + alter_column(compiler, column_name), + format_type(compiler, type_), + "" if nullable else " NOT NULL", + ( + "" + if server_default is None + else f" DEFAULT {format_server_default(compiler, server_default)}" + ), ) diff --git a/test/mockserver_tests/default_model.py b/test/mockserver_tests/default_model.py new file mode 100644 index 00000000..6a363c57 --- /dev/null +++ b/test/mockserver_tests/default_model.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import func +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Singer(Base): + __tablename__ = "singers" + id: Mapped[str] = mapped_column( + server_default=func.GENERATE_UUID(), primary_key=True + ) + name: Mapped[str] diff --git a/test/mockserver_tests/test_default.py b/test/mockserver_tests/test_default.py new file mode 100644 index 00000000..9b46ede0 --- /dev/null +++ b/test/mockserver_tests/test_default.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import create_engine +from sqlalchemy.testing import eq_, is_instance_of +from google.cloud.spanner_v1 import FixedSizePool, ResultSet +from test.mockserver_tests.mock_server_test_base import MockServerTestBase, add_result +from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest + + +class TestCreateTableDefault(MockServerTestBase): + def test_create_table_with_default(self): + from test.mockserver_tests.default_model import Base + + add_result( + """SELECT true +FROM INFORMATION_SCHEMA.TABLES +WHERE TABLE_SCHEMA="" AND TABLE_NAME="singers" +LIMIT 1 +""", + ResultSet(), + ) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + Base.metadata.create_all(engine) + requests = self.database_admin_service.requests + eq_(1, len(requests)) + is_instance_of(requests[0], UpdateDatabaseDdlRequest) + eq_(1, len(requests[0].statements)) + eq_( + "CREATE TABLE singers (\n" + "\tid STRING(MAX) NOT NULL DEFAULT (GENERATE_UUID()), \n" + "\tname STRING(MAX) NOT NULL\n" + ") PRIMARY KEY (id)", + requests[0].statements[0], + ) diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/test_alembic.py b/test/unit/test_alembic.py new file mode 100644 index 00000000..75e39561 --- /dev/null +++ b/test/unit/test_alembic.py @@ -0,0 +1,97 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from alembic.ddl import base as ddl_base +from google.cloud.sqlalchemy_spanner import sqlalchemy_spanner +from sqlalchemy import String, TextClause +from sqlalchemy.testing import eq_ +from sqlalchemy.testing.plugin.plugin_base import fixtures + + +class TestAlembicTest(fixtures.TestBase): + def test_visit_column_nullable_with_not_null_column(self): + ddl = sqlalchemy_spanner.visit_column_nullable( + ddl_base.ColumnNullable( + name="tbl", column_name="col", nullable=False, existing_type=String(256) + ), + sqlalchemy_spanner.SpannerDDLCompiler( + sqlalchemy_spanner.SpannerDialect(), None + ), + ) + eq_(ddl, "ALTER TABLE tbl ALTER COLUMN col STRING(256) NOT NULL") + + def test_visit_column_nullable_with_nullable_column(self): + ddl = sqlalchemy_spanner.visit_column_nullable( + ddl_base.ColumnNullable( + name="tbl", column_name="col", nullable=True, existing_type=String(256) + ), + sqlalchemy_spanner.SpannerDDLCompiler( + sqlalchemy_spanner.SpannerDialect(), None + ), + ) + eq_(ddl, "ALTER TABLE tbl ALTER COLUMN col STRING(256)") + + def test_visit_column_nullable_with_default(self): + ddl = sqlalchemy_spanner.visit_column_nullable( + ddl_base.ColumnNullable( + name="tbl", + column_name="col", + nullable=False, + existing_type=String(256), + existing_server_default=TextClause("GENERATE_UUID()"), + ), + sqlalchemy_spanner.SpannerDDLCompiler( + sqlalchemy_spanner.SpannerDialect(), None + ), + ) + eq_( + ddl, + "ALTER TABLE tbl " + "ALTER COLUMN col " + "STRING(256) NOT NULL DEFAULT (GENERATE_UUID())", + ) + + def test_visit_column_type(self): + ddl = sqlalchemy_spanner.visit_column_type( + ddl_base.ColumnType( + name="tbl", + column_name="col", + type_=String(256), + existing_nullable=True, + ), + sqlalchemy_spanner.SpannerDDLCompiler( + sqlalchemy_spanner.SpannerDialect(), None + ), + ) + eq_(ddl, "ALTER TABLE tbl ALTER COLUMN col STRING(256)") + + def test_visit_column_type_with_default(self): + ddl = sqlalchemy_spanner.visit_column_type( + ddl_base.ColumnType( + name="tbl", + column_name="col", + type_=String(256), + existing_nullable=False, + existing_server_default=TextClause("GENERATE_UUID()"), + ), + sqlalchemy_spanner.SpannerDDLCompiler( + sqlalchemy_spanner.SpannerDialect(), None + ), + ) + eq_( + ddl, + "ALTER TABLE tbl " + "ALTER COLUMN col " + "STRING(256) NOT NULL DEFAULT (GENERATE_UUID())", + )