Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,14 @@ def returning_clause(self, stmt, returning_cols, **kw):
self._label_select_column(
None, c, True, False, {"spanner_is_returning": True}
)
for c in expression._select_iterables(returning_cols)
for c in expression._select_iterables(
filter(
lambda col: not col.dialect_options.get("spanner", {}).get(
"exclude_from_returning", False
),
returning_cols,
)
)
]

return "THEN RETURN " + ", ".join(columns)
Expand Down Expand Up @@ -831,6 +838,7 @@ class SpannerDialect(DefaultDialect):
update_returning = True
delete_returning = True
supports_multivalues_insert = True
use_insertmanyvalues = True

ddl_compiler = SpannerDDLCompiler
preparer = SpannerIdentifierPreparer
Expand Down
84 changes: 84 additions & 0 deletions samples/insertmany_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from datetime import datetime
import uuid
from sqlalchemy import text, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Session
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sample_helper import run_sample


class Base(DeclarativeBase):
pass


# To use SQLAlchemy 2.0's insertmany feature, models must have a
# unique column marked as an "insert_sentinal" with client-side
# generated values passed into it. This allows SQLAlchemy to perform a
# single bulk insert, even if the table has columns with server-side
# defaults which must be retrieved from a THEN RETURN clause, for
# operations like:
#
# with Session.begin() as session:
# session.add(Singer(name="a"))
# session.add(Singer(name="b"))
#
# Read more in the SQLAlchemy documentation of this feature:
# https://docs.sqlalchemy.org/en/20/core/connections.html#configuring-sentinel-columns


class Singer(Base):
__tablename__ = "singers_with_sentinel"
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
# Supply a unique UUID client-side
default=lambda: str(uuid.uuid4()),
# The column is unique and can be used as an insert_sentinel
insert_sentinel=True,
# Set a server-side default for write outside SQLAlchemy
server_default=text("GENERATE_UUID()"),
)
name: Mapped[str]
inserted_at: Mapped[datetime] = mapped_column(
server_default=text("CURRENT_TIMESTAMP()")
)


# Shows how to insert data using SQLAlchemy, including relationships that are
# defined both as foreign keys and as interleaved tables.
def insertmany():
engine = create_engine(
"spanner:///projects/sample-project/"
"instances/sample-instance/"
"databases/sample-database",
echo=True,
)
# Create the sample table.
Base.metadata.create_all(engine)

# Insert two singers in one session. These two singers will be inserted using
# a single INSERT statement with a THEN RETURN clause to return the generated
# creation timestamp.
with Session(engine) as session:
session.add(Singer(name="John Smith"))
session.add(Singer(name="Jane Smith"))
session.commit()


if __name__ == "__main__":
run_sample(insertmany)
5 changes: 5 additions & 0 deletions samples/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def informational_fk(session):
_sample(session)


@nox.session()
def insertmany(session):
_sample(session)


@nox.session()
def _all_samples(session):
_sample(session)
Expand Down
48 changes: 48 additions & 0 deletions test/mockserver_tests/insertmany_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
import uuid
from sqlalchemy import text, String
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column


class Base(DeclarativeBase):
pass


class SingerUUID(Base):
__tablename__ = "singers_uuid"
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
server_default=text("GENERATE_UUID()"),
default=lambda: str(uuid.uuid4()),
insert_sentinel=True,
)
name: Mapped[str]
inserted_at: Mapped[datetime] = mapped_column(
server_default=text("CURRENT_TIMESTAMP()")
)


class SingerIntID(Base):
__tablename__ = "singers_int_id"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String)
inserted_at: Mapped[datetime] = mapped_column(
server_default=text("CURRENT_TIMESTAMP()")
)
191 changes: 191 additions & 0 deletions test/mockserver_tests/test_insertmany.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright 2025 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from unittest import mock

import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.testing import eq_, is_instance_of
from google.cloud.spanner_v1 import (
ExecuteSqlRequest,
CommitRequest,
RollbackRequest,
BeginTransactionRequest,
CreateSessionRequest,
)
from test.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_result,
)
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set


class TestInsertmany(MockServerTestBase):
@mock.patch.object(uuid, "uuid4", mock.MagicMock(side_effect=["a", "b"]))
def test_insertmany_with_uuid_sentinels(self):
"""Ensures one bulk insert for ORM objects distinguished by uuid."""
from test.mockserver_tests.insertmany_model import SingerUUID

self.add_uuid_insert_result(
"INSERT INTO singers_uuid (id, name) "
"VALUES (@a0, @a1), (@a2, @a3) "
"THEN RETURN inserted_at, id"
)
engine = self.create_engine()

with Session(engine) as session:
session.add(SingerUUID(name="a"))
session.add(SingerUUID(name="b"))
session.commit()

# Verify the requests that we got.
requests = self.spanner_service.requests
eq_(4, len(requests))
is_instance_of(requests[0], CreateSessionRequest)
is_instance_of(requests[1], BeginTransactionRequest)
is_instance_of(requests[2], ExecuteSqlRequest)
is_instance_of(requests[3], CommitRequest)

def test_no_insertmany_with_bit_reversed_id(self):
"""Ensures we don't try to bulk insert rows with bit-reversed PKs.

SQLAlchemy's insertmany support requires either incrementing
PKs or client-side supplied sentinel values such as UUIDs.
Spanner's bit-reversed integer PKs don't meet the ordering
requirement, so we need to make sure we don't try to bulk
insert with them.
"""
from test.mockserver_tests.insertmany_model import SingerIntID

self.add_int_id_insert_result(
"INSERT INTO singers_int_id (name) "
"VALUES (@a0) "
"THEN RETURN id, inserted_at"
)
engine = self.create_engine()

with Session(engine) as session:
session.add(SingerIntID(name="a"))
session.add(SingerIntID(name="b"))
try:
session.commit()
except sqlalchemy.exc.SAWarning:
# This will fail because we're returning the same PK
# for two rows. The mock server doesn't currently
# support associating the same query with two
# different results. For our purposes that's okay --
# we just want to ensure we generate two INSERTs, not
# one.
pass

# Verify the requests that we got.
requests = self.spanner_service.requests
eq_(5, len(requests))
is_instance_of(requests[0], CreateSessionRequest)
is_instance_of(requests[1], BeginTransactionRequest)
is_instance_of(requests[2], ExecuteSqlRequest)
is_instance_of(requests[3], ExecuteSqlRequest)
is_instance_of(requests[4], RollbackRequest)

def add_uuid_insert_result(self, sql):
result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name="inserted_at",
type=spanner_type.Type(
dict(
code=spanner_type.TypeCode.TIMESTAMP
)
),
)
),
spanner_type.StructType.Field(
dict(
name="id",
type=spanner_type.Type(
dict(code=spanner_type.TypeCode.STRING)
),
)
),
]
)
)
)
),
)
)
result.rows.extend(
[
(
"2020-06-02T23:58:40Z",
"a",
),
(
"2020-06-02T23:58:41Z",
"b",
),
]
)
add_result(sql, result)

def add_int_id_insert_result(self, sql):
result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name="id",
type=spanner_type.Type(
dict(code=spanner_type.TypeCode.INT64)
),
)
),
spanner_type.StructType.Field(
dict(
name="inserted_at",
type=spanner_type.Type(
dict(
code=spanner_type.TypeCode.TIMESTAMP
)
),
)
),
]
)
)
)
),
)
)
result.rows.extend(
[
(
"1",
"2020-06-02T23:58:40Z",
),
]
)
add_result(sql, result)
2 changes: 2 additions & 0 deletions test/system/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ class TimestampUser(Base):
updated_at: Mapped[datetime.datetime] = mapped_column(
spanner_allow_commit_timestamp=True,
default=text("PENDING_COMMIT_TIMESTAMP()"),
# Make sure that this column is never part of a THEN RETURN clause.
spanner_exclude_from_returning=True,
)

@event.listens_for(TimestampUser, "before_update")
Expand Down