From c6c123fe222d67faae7c5fac40acee1c07ac3e7c Mon Sep 17 00:00:00 2001 From: Walt Askew Date: Wed, 30 Jul 2025 00:02:31 +0000 Subject: [PATCH] fix: Correctly Generate DDL for ALTER COLUMN ... SET DEFAULT Alembic expects `get_column_default_string` to be implemented in order to use it for `ALTER TABLE.. ALTER COLUMN .. SET DEFAULT` DDL. In our case, this means wrapping the default value in parentheses. We implement `get_column_default_string` and have it add parentheses for use in both `CREATE TABLE` and `ALTER TABLE` DDL. Call path for alembic relying on `get_column_default_string` is here: https://github.com/sqlalchemy/alembic/blob/cd4f404358f101b2b930013c609c074baca61468/alembic/ddl/base.py#L252 https://github.com/sqlalchemy/alembic/blob/cd4f404358f101b2b930013c609c074baca61468/alembic/ddl/base.py#L315 --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 9 +++- test/mockserver_tests/default_model.py | 30 ++++++++++++ test/mockserver_tests/test_default.py | 49 +++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 test/mockserver_tests/default_model.py create mode 100644 test/mockserver_tests/test_default.py diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index d868daf9..7c15d448 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -572,7 +572,7 @@ def get_column_specification(self, column, **kwargs): elif has_identity: colspec += " " + self.process(column.identity) elif default is not None: - colspec += " DEFAULT (" + default + ")" + colspec += f" DEFAULT {default}" elif hasattr(column, "computed") and column.computed is not None: colspec += " " + self.process(column.computed) @@ -583,6 +583,13 @@ def get_column_specification(self, column, **kwargs): return colspec + def get_column_default_string(self, column): + default = super().get_column_default_string(column) + if default is not None: + return f"({default})" + + return default + def visit_computed_column(self, generated, **kw): """Computed column operator.""" text = "AS (%s) STORED" % self.sql_compiler.process( diff --git a/test/mockserver_tests/default_model.py b/test/mockserver_tests/default_model.py new file mode 100644 index 00000000..6a363c57 --- /dev/null +++ b/test/mockserver_tests/default_model.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import func +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Singer(Base): + __tablename__ = "singers" + id: Mapped[str] = mapped_column( + server_default=func.GENERATE_UUID(), primary_key=True + ) + name: Mapped[str] diff --git a/test/mockserver_tests/test_default.py b/test/mockserver_tests/test_default.py new file mode 100644 index 00000000..9b46ede0 --- /dev/null +++ b/test/mockserver_tests/test_default.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import create_engine +from sqlalchemy.testing import eq_, is_instance_of +from google.cloud.spanner_v1 import FixedSizePool, ResultSet +from test.mockserver_tests.mock_server_test_base import MockServerTestBase, add_result +from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest + + +class TestCreateTableDefault(MockServerTestBase): + def test_create_table_with_default(self): + from test.mockserver_tests.default_model import Base + + add_result( + """SELECT true +FROM INFORMATION_SCHEMA.TABLES +WHERE TABLE_SCHEMA="" AND TABLE_NAME="singers" +LIMIT 1 +""", + ResultSet(), + ) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + Base.metadata.create_all(engine) + requests = self.database_admin_service.requests + eq_(1, len(requests)) + is_instance_of(requests[0], UpdateDatabaseDdlRequest) + eq_(1, len(requests[0].statements)) + eq_( + "CREATE TABLE singers (\n" + "\tid STRING(MAX) NOT NULL DEFAULT (GENERATE_UUID()), \n" + "\tname STRING(MAX) NOT NULL\n" + ") PRIMARY KEY (id)", + requests[0].statements[0], + )