diff --git a/google/cloud/sqlalchemy_spanner/requirements.py b/google/cloud/sqlalchemy_spanner/requirements.py index a6e6dc00..04d21dc8 100644 --- a/google/cloud/sqlalchemy_spanner/requirements.py +++ b/google/cloud/sqlalchemy_spanner/requirements.py @@ -102,7 +102,10 @@ def get_isolation_levels(self, _): Returns: dict: isolation levels description. """ - return {"default": "SERIALIZABLE", "supported": ["SERIALIZABLE", "AUTOCOMMIT"]} + return { + "default": "SERIALIZABLE", + "supported": ["SERIALIZABLE", "REPEATABLE READ", "AUTOCOMMIT"], + } @property def precision_numerics_enotation_large(self): diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 0ef9c865..f2add3a1 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -25,7 +25,7 @@ ) from google.api_core.client_options import ClientOptions from google.auth.credentials import AnonymousCredentials -from google.cloud.spanner_v1 import Client +from google.cloud.spanner_v1 import Client, TransactionOptions from sqlalchemy.exc import NoSuchTableError from sqlalchemy.sql import elements from sqlalchemy import ForeignKeyConstraint, types, TypeDecorator, PickleType @@ -218,6 +218,16 @@ def pre_exec(self): if request_tag: self.cursor.request_tag = request_tag + ignore_transaction_warnings = self.execution_options.get( + "ignore_transaction_warnings" + ) + if ignore_transaction_warnings is not None: + conn = self._dbapi_connection.connection + if conn is not None and hasattr(conn, "_connection_variables"): + conn._connection_variables[ + "ignore_transaction_warnings" + ] = ignore_transaction_warnings + def fire_sequence(self, seq, type_): """Builds a statement for fetching next value of the sequence.""" return self._execute_scalar( @@ -777,6 +787,7 @@ class SpannerDialect(DefaultDialect): encoding = "utf-8" max_identifier_length = 256 _legacy_binary_type_literal_encoding = "utf-8" + _default_isolation_level = "SERIALIZABLE" execute_sequence_format = list @@ -828,12 +839,11 @@ def default_isolation_level(self): Returns: str: default isolation level. """ - return "SERIALIZABLE" + return self._default_isolation_level @default_isolation_level.setter def default_isolation_level(self, value): - """Default isolation level should not be changed.""" - pass + self._default_isolation_level = value def _check_unicode_returns(self, connection, additional_tests=None): """Ensure requests are returning Unicode responses.""" @@ -1682,7 +1692,7 @@ def set_isolation_level(self, conn_proxy, level): spanner_dbapi.connection.Connection, ] ): - Database connection proxy object or the connection iself. + Database connection proxy object or the connection itself. level (string): Isolation level. """ if isinstance(conn_proxy, spanner_dbapi.Connection): @@ -1690,7 +1700,13 @@ def set_isolation_level(self, conn_proxy, level): else: conn = conn_proxy.connection - conn.autocommit = level == "AUTOCOMMIT" + if level == "AUTOCOMMIT": + conn.autocommit = True + else: + if isinstance(level, str): + level = self._string_to_isolation_level(level) + conn.isolation_level = level + conn.autocommit = False def get_isolation_level(self, conn_proxy): """Get the connection isolation level. @@ -1702,7 +1718,7 @@ def get_isolation_level(self, conn_proxy): spanner_dbapi.connection.Connection, ] ): - Database connection proxy object or the connection iself. + Database connection proxy object or the connection itself. Returns: str: the connection isolation level. @@ -1712,7 +1728,31 @@ def get_isolation_level(self, conn_proxy): else: conn = conn_proxy.connection - return "AUTOCOMMIT" if conn.autocommit else "SERIALIZABLE" + if conn.autocommit: + return "AUTOCOMMIT" + + level = conn.isolation_level + if level == TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED: + level = TransactionOptions.IsolationLevel.SERIALIZABLE + if isinstance(level, TransactionOptions.IsolationLevel): + level = self._isolation_level_to_string(level) + + return level + + def _string_to_isolation_level(self, name): + try: + # SQLAlchemy guarantees that the isolation level string will: + # 1. Be all upper case. + # 2. Contain spaces instead of underscores. + # We change the spaces into underscores to get the enum value. + return TransactionOptions.IsolationLevel[name.replace(" ", "_")] + except KeyError: + raise ValueError("Invalid isolation level name '%s'" % name) + + def _isolation_level_to_string(self, level): + # SQLAlchemy expects isolation level names to contain spaces, + # and not underscores, so we remove those before returning. + return level.name.replace("_", " ") def do_rollback(self, dbapi_connection): """ diff --git a/samples/isolation_level_sample.py b/samples/isolation_level_sample.py new file mode 100644 index 00000000..ceb56643 --- /dev/null +++ b/samples/isolation_level_sample.py @@ -0,0 +1,47 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from sample_helper import run_sample +from model import Singer + + +# Shows how to set the isolation level for a read/write transaction. +# Spanner supports the following isolation levels: +# - SERIALIZABLE (default) +# - REPEATABLE READ +def isolation_level_sample(): + engine = create_engine( + "spanner:///projects/sample-project/" + "instances/sample-instance/" + "databases/sample-database", + # You can set a default isolation level for an engine. + isolation_level="REPEATABLE READ", + echo=True, + ) + # You can override the default isolation level of the connection + # by setting it in the execution_options. + with Session(engine.execution_options(isolation_level="SERIALIZABLE")) as session: + singer_id = str(uuid.uuid4()) + singer = Singer(id=singer_id, first_name="John", last_name="Doe") + session.add(singer) + session.commit() + + +if __name__ == "__main__": + run_sample(isolation_level_sample) diff --git a/samples/noxfile.py b/samples/noxfile.py index 67c3fae5..82019f5b 100644 --- a/samples/noxfile.py +++ b/samples/noxfile.py @@ -62,6 +62,11 @@ def transaction(session): _sample(session) +@nox.session() +def isolation_level(session): + _sample(session) + + @nox.session() def stale_read(session): _sample(session) diff --git a/test/mockserver_tests/isolation_level_model.py b/test/mockserver_tests/isolation_level_model.py new file mode 100644 index 00000000..9965dbf0 --- /dev/null +++ b/test/mockserver_tests/isolation_level_model.py @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import String, BigInteger +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Singer(Base): + __tablename__ = "singers" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + name: Mapped[str] = mapped_column(String) diff --git a/test/mockserver_tests/test_isolation_level.py b/test/mockserver_tests/test_isolation_level.py new file mode 100644 index 00000000..f6545298 --- /dev/null +++ b/test/mockserver_tests/test_isolation_level.py @@ -0,0 +1,208 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_, is_instance_of +from google.cloud.spanner_v1 import ( + FixedSizePool, + BatchCreateSessionsRequest, + ExecuteSqlRequest, + CommitRequest, + BeginTransactionRequest, + TransactionOptions, +) + +from test.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_result, +) +import google.cloud.spanner_v1.types.type as spanner_type +import google.cloud.spanner_v1.types.result_set as result_set + +ISOLATION_LEVEL_UNSPECIFIED = ( + TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED +) + + +class TestIsolationLevel(MockServerTestBase): + def test_default_isolation_level(self): + from test.mockserver_tests.isolation_level_model import Singer + + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + with Session(engine) as session: + singer = Singer(name="Test") + session.add(singer) + session.commit() + self.verify_isolation_level( + TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED + ) + + def test_engine_isolation_level(self): + from test.mockserver_tests.isolation_level_model import Singer + + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + isolation_level="REPEATABLE READ", + ) + + with Session(engine) as session: + singer = Singer(name="Test") + session.add(singer) + session.commit() + self.verify_isolation_level(TransactionOptions.IsolationLevel.REPEATABLE_READ) + + def test_execution_options_isolation_level(self): + from test.mockserver_tests.isolation_level_model import Singer + + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + with Session( + engine.execution_options(isolation_level="repeatable read") + ) as session: + singer = Singer(name="Test") + session.add(singer) + session.commit() + self.verify_isolation_level(TransactionOptions.IsolationLevel.REPEATABLE_READ) + + def test_override_engine_isolation_level(self): + from test.mockserver_tests.isolation_level_model import Singer + + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + isolation_level="REPEATABLE READ", + ) + + with Session( + engine.execution_options(isolation_level="SERIALIZABLE") + ) as session: + singer = Singer(name="Test") + session.add(singer) + session.commit() + self.verify_isolation_level(TransactionOptions.IsolationLevel.SERIALIZABLE) + + def test_auto_commit(self): + from test.mockserver_tests.isolation_level_model import Singer + + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={ + "client": self.client, + "pool": FixedSizePool(size=10), + "ignore_transaction_warnings": True, + }, + ) + + with Session( + engine.execution_options( + isolation_level="AUTOCOMMIT", ignore_transaction_warnings=True + ) + ) as session: + singer = Singer(name="Test") + session.add(singer) + session.commit() + + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(3, len(requests)) + is_instance_of(requests[0], BatchCreateSessionsRequest) + is_instance_of(requests[1], ExecuteSqlRequest) + is_instance_of(requests[2], CommitRequest) + execute_request: ExecuteSqlRequest = requests[1] + eq_( + TransactionOptions( + dict( + isolation_level=ISOLATION_LEVEL_UNSPECIFIED, + read_write=TransactionOptions.ReadWrite(), + ) + ), + execute_request.transaction.begin, + ) + + def test_invalid_isolation_level(self): + from test.mockserver_tests.isolation_level_model import Singer + + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + with pytest.raises(ValueError): + with Session(engine.execution_options(isolation_level="foo")) as session: + singer = Singer(name="Test") + session.add(singer) + session.commit() + + def verify_isolation_level(self, level): + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(4, len(requests)) + is_instance_of(requests[0], BatchCreateSessionsRequest) + is_instance_of(requests[1], BeginTransactionRequest) + is_instance_of(requests[2], ExecuteSqlRequest) + is_instance_of(requests[3], CommitRequest) + begin_request: BeginTransactionRequest = requests[1] + eq_( + TransactionOptions( + dict( + isolation_level=level, + read_write=TransactionOptions.ReadWrite(), + ) + ), + begin_request.options, + ) + + def add_insert_result(self, sql): + result = result_set.ResultSet( + dict( + metadata=result_set.ResultSetMetadata( + dict( + row_type=spanner_type.StructType( + dict( + fields=[ + spanner_type.StructType.Field( + dict( + name="id", + type=spanner_type.Type( + dict(code=spanner_type.TypeCode.INT64) + ), + ) + ) + ] + ) + ) + ) + ), + stats=result_set.ResultSetStats( + dict( + row_count_exact=1, + ) + ), + ) + ) + result.rows.extend([("987654321",)]) + add_result(sql, result) diff --git a/test/system/test_basics.py b/test/system/test_basics.py index 3001052d..693617b1 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import Optional from sqlalchemy import ( text, @@ -165,7 +165,10 @@ class SchemaUser(Base): session.add(number) session.commit() - with Session(engine) as session: + level = "serializable" + if os.environ.get("SPANNER_EMULATOR_HOST", ""): + level = "REPEATABLE READ" + with Session(engine.execution_options(isolation_level=level)) as session: user = User(name="Test") session.add(user) session.commit()