diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index ffa393fd..e3fb563c 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -425,7 +425,14 @@ def returning_clause(self, stmt, returning_cols, **kw): self._label_select_column( None, c, True, False, {"spanner_is_returning": True} ) - for c in expression._select_iterables(returning_cols) + for c in expression._select_iterables( + filter( + lambda col: not col.dialect_options.get("spanner", {}).get( + "exclude_from_returning", False + ), + returning_cols, + ) + ) ] return "THEN RETURN " + ", ".join(columns) @@ -831,6 +838,7 @@ class SpannerDialect(DefaultDialect): update_returning = True delete_returning = True supports_multivalues_insert = True + use_insertmanyvalues = True ddl_compiler = SpannerDDLCompiler preparer = SpannerIdentifierPreparer diff --git a/samples/insertmany_sample.py b/samples/insertmany_sample.py new file mode 100644 index 00000000..859bc158 --- /dev/null +++ b/samples/insertmany_sample.py @@ -0,0 +1,84 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from datetime import datetime +import uuid +from sqlalchemy import text, String, create_engine +from sqlalchemy.orm import DeclarativeBase, Session +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sample_helper import run_sample + + +class Base(DeclarativeBase): + pass + + +# To use SQLAlchemy 2.0's insertmanyvalues feature, models must have a +# unique column marked as an "insert_sentinel" with client-side +# generated values passed into it. This allows SQLAlchemy to perform a +# single bulk insert, even if the table has columns with server-side +# defaults which must be retrieved from a THEN RETURN clause, for +# operations like: +# +# with Session.begin() as session: +# session.add(Singer(name="a")) +# session.add(Singer(name="b")) +# +# Read more in the SQLAlchemy documentation of this feature: +# https://docs.sqlalchemy.org/en/20/core/connections.html#configuring-sentinel-columns + + +class Singer(Base): + __tablename__ = "singers_with_sentinel" + id: Mapped[str] = mapped_column( + String(36), + primary_key=True, + # Supply a unique UUID client-side + default=lambda: str(uuid.uuid4()), + # The column is unique and can be used as an insert_sentinel + insert_sentinel=True, + # Set a server-side default for writes made outside SQLAlchemy + server_default=text("GENERATE_UUID()"), + ) + name: Mapped[str] + inserted_at: Mapped[datetime] = mapped_column( + server_default=text("CURRENT_TIMESTAMP()") + ) + + +# Shows how to bulk insert multiple rows with SQLAlchemy using a single INSERT +# statement, relying on a client-side generated sentinel column. +def insertmany(): + engine = create_engine( + "spanner:///projects/sample-project/" + "instances/sample-instance/" + "databases/sample-database", + echo=True, + ) + # Create the sample table. + Base.metadata.create_all(engine) + + # Insert two singers in one session. 
These two singers will be inserted using + # a single INSERT statement with a THEN RETURN clause to return the generated + # creation timestamp. + with Session(engine) as session: + session.add(Singer(name="John Smith")) + session.add(Singer(name="Jane Smith")) + session.commit() + + +if __name__ == "__main__": + run_sample(insertmany) diff --git a/samples/noxfile.py b/samples/noxfile.py index 880bca48..cd28a3f0 100644 --- a/samples/noxfile.py +++ b/samples/noxfile.py @@ -92,6 +92,11 @@ def informational_fk(session): _sample(session) +@nox.session() +def insertmany(session): + _sample(session) + + @nox.session() def _all_samples(session): _sample(session) diff --git a/test/mockserver_tests/insertmany_model.py b/test/mockserver_tests/insertmany_model.py new file mode 100644 index 00000000..a196e142 --- /dev/null +++ b/test/mockserver_tests/insertmany_model.py @@ -0,0 +1,48 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import datetime +import uuid +from sqlalchemy import text, String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class SingerUUID(Base): + __tablename__ = "singers_uuid" + id: Mapped[str] = mapped_column( + String(36), + primary_key=True, + server_default=text("GENERATE_UUID()"), + default=lambda: str(uuid.uuid4()), + insert_sentinel=True, + ) + name: Mapped[str] + inserted_at: Mapped[datetime] = mapped_column( + server_default=text("CURRENT_TIMESTAMP()") + ) + + +class SingerIntID(Base): + __tablename__ = "singers_int_id" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String) + inserted_at: Mapped[datetime] = mapped_column( + server_default=text("CURRENT_TIMESTAMP()") + ) diff --git a/test/mockserver_tests/test_insertmany.py b/test/mockserver_tests/test_insertmany.py new file mode 100644 index 00000000..f5b9f882 --- /dev/null +++ b/test/mockserver_tests/test_insertmany.py @@ -0,0 +1,191 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import uuid +from unittest import mock + +import sqlalchemy +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_, is_instance_of +from google.cloud.spanner_v1 import ( + ExecuteSqlRequest, + CommitRequest, + RollbackRequest, + BeginTransactionRequest, + CreateSessionRequest, +) +from test.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_result, +) +import google.cloud.spanner_v1.types.type as spanner_type +import google.cloud.spanner_v1.types.result_set as result_set + + +class TestInsertmany(MockServerTestBase): + @mock.patch.object(uuid, "uuid4", mock.MagicMock(side_effect=["a", "b"])) + def test_insertmany_with_uuid_sentinels(self): + """Ensures one bulk insert for ORM objects distinguished by uuid.""" + from test.mockserver_tests.insertmany_model import SingerUUID + + self.add_uuid_insert_result( + "INSERT INTO singers_uuid (id, name) " + "VALUES (@a0, @a1), (@a2, @a3) " + "THEN RETURN inserted_at, id" + ) + engine = self.create_engine() + + with Session(engine) as session: + session.add(SingerUUID(name="a")) + session.add(SingerUUID(name="b")) + session.commit() + + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(4, len(requests)) + is_instance_of(requests[0], CreateSessionRequest) + is_instance_of(requests[1], BeginTransactionRequest) + is_instance_of(requests[2], ExecuteSqlRequest) + is_instance_of(requests[3], CommitRequest) + + def test_no_insertmany_with_bit_reversed_id(self): + """Ensures we don't try to bulk insert rows with bit-reversed PKs. + + SQLAlchemy's insertmany support requires either incrementing + PKs or client-side supplied sentinel values such as UUIDs. + Spanner's bit-reversed integer PKs don't meet the ordering + requirement, so we need to make sure we don't try to bulk + insert with them. 
+ """ + from test.mockserver_tests.insertmany_model import SingerIntID + + self.add_int_id_insert_result( + "INSERT INTO singers_int_id (name) " + "VALUES (@a0) " + "THEN RETURN id, inserted_at" + ) + engine = self.create_engine() + + with Session(engine) as session: + session.add(SingerIntID(name="a")) + session.add(SingerIntID(name="b")) + try: + session.commit() + except sqlalchemy.exc.SAWarning: + # This will fail because we're returning the same PK + # for two rows. The mock server doesn't currently + # support associating the same query with two + # different results. For our purposes that's okay -- + # we just want to ensure we generate two INSERTs, not + # one. + pass + + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(5, len(requests)) + is_instance_of(requests[0], CreateSessionRequest) + is_instance_of(requests[1], BeginTransactionRequest) + is_instance_of(requests[2], ExecuteSqlRequest) + is_instance_of(requests[3], ExecuteSqlRequest) + is_instance_of(requests[4], RollbackRequest) + + def add_uuid_insert_result(self, sql): + result = result_set.ResultSet( + dict( + metadata=result_set.ResultSetMetadata( + dict( + row_type=spanner_type.StructType( + dict( + fields=[ + spanner_type.StructType.Field( + dict( + name="inserted_at", + type=spanner_type.Type( + dict( + code=spanner_type.TypeCode.TIMESTAMP + ) + ), + ) + ), + spanner_type.StructType.Field( + dict( + name="id", + type=spanner_type.Type( + dict(code=spanner_type.TypeCode.STRING) + ), + ) + ), + ] + ) + ) + ) + ), + ) + ) + result.rows.extend( + [ + ( + "2020-06-02T23:58:40Z", + "a", + ), + ( + "2020-06-02T23:58:41Z", + "b", + ), + ] + ) + add_result(sql, result) + + def add_int_id_insert_result(self, sql): + result = result_set.ResultSet( + dict( + metadata=result_set.ResultSetMetadata( + dict( + row_type=spanner_type.StructType( + dict( + fields=[ + spanner_type.StructType.Field( + dict( + name="id", + type=spanner_type.Type( + 
dict(code=spanner_type.TypeCode.INT64) + ), + ) + ), + spanner_type.StructType.Field( + dict( + name="inserted_at", + type=spanner_type.Type( + dict( + code=spanner_type.TypeCode.TIMESTAMP + ) + ), + ) + ), + ] + ) + ) + ) + ), + ) + ) + result.rows.extend( + [ + ( + "1", + "2020-06-02T23:58:40Z", + ), + ] + ) + add_result(sql, result) diff --git a/test/system/test_basics.py b/test/system/test_basics.py index 7ea6fa2b..30a61865 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -316,6 +316,8 @@ class TimestampUser(Base): updated_at: Mapped[datetime.datetime] = mapped_column( spanner_allow_commit_timestamp=True, default=text("PENDING_COMMIT_TIMESTAMP()"), + # Make sure that this column is never part of a THEN RETURN clause. + spanner_exclude_from_returning=True, ) @event.listens_for(TimestampUser, "before_update")