Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64

import re

import sqlalchemy
from alembic.ddl.base import (
ColumnNullable,
ColumnType,
Expand All @@ -25,14 +24,16 @@
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import Client, TransactionOptions
from sqlalchemy.exc import NoSuchTableError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you remove the unrelated changes from this pull request (and potentially open a separate pull request for formatting issues, if any)

from sqlalchemy.sql import elements
from google.cloud.spanner_v1.data_types import JsonObject
from sqlalchemy import ForeignKeyConstraint, types, TypeDecorator, PickleType
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
from sqlalchemy.event import listens_for
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.pool import Pool
from sqlalchemy.sql import elements
from sqlalchemy.sql import expression
from sqlalchemy.sql.compiler import (
selectable,
DDLCompiler,
Expand All @@ -44,13 +45,10 @@
)
from sqlalchemy.sql.default_comparator import operator_lookup
from sqlalchemy.sql.operators import json_getitem_op
from sqlalchemy.sql import expression

from google.cloud.spanner_v1.data_types import JsonObject
from google.cloud import spanner_dbapi
from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call
from google.cloud.sqlalchemy_spanner import version as sqlalchemy_spanner_version
import sqlalchemy
from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call

USING_SQLACLCHEMY_20 = False
if sqlalchemy.__version__.split(".")[0] == "2":
Expand All @@ -63,12 +61,18 @@
@listens_for(Pool, "reset")
def reset_connection(dbapi_conn, connection_record, reset_state=None):
"""An event of returning a connection back to a pool."""
if hasattr(dbapi_conn, "driver_connection"):
dbapi_conn = dbapi_conn.driver_connection
if hasattr(dbapi_conn, "connection"):
dbapi_conn = dbapi_conn.connection
if isinstance(dbapi_conn, spanner_dbapi.Connection):
if dbapi_conn.inside_transaction:
transaction_started = getattr(
dbapi_conn,
"spanner_transaction_started",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: The name of the property is _spanner_transaction_started.

But I don't think it is wise to make this change at all, or at least not in the way that it is done here. _spanner_transaction_started is a private property, meaning that it could silently be removed or changed at any time. It would be better to continue to rely only on the inside_transaction property, even though it is deprecated, as it is part of the public API.

getattr(dbapi_conn, "inside_transaction", False),
)
if transaction_started:
dbapi_conn.rollback()

dbapi_conn.staleness = None
dbapi_conn.read_only = False

Expand Down Expand Up @@ -1709,7 +1713,7 @@ def set_isolation_level(self, conn_proxy, level):
conn_proxy (
Union[
sqlalchemy.pool._ConnectionFairy,
spanner_dbapi.connection.Connection,
spanner_dbapi.driver_connection.Connection,
]
):
Database connection proxy object or the connection itself.
Expand All @@ -1718,7 +1722,7 @@ def set_isolation_level(self, conn_proxy, level):
if isinstance(conn_proxy, spanner_dbapi.Connection):
conn = conn_proxy
else:
conn = conn_proxy.connection
conn = conn_proxy.driver_connection

if level == "AUTOCOMMIT":
conn.autocommit = True
Expand All @@ -1735,7 +1739,7 @@ def get_isolation_level(self, conn_proxy):
conn_proxy (
Union[
sqlalchemy.pool._ConnectionFairy,
spanner_dbapi.connection.Connection,
spanner_dbapi.driver_connection.Connection,
]
):
Database connection proxy object or the connection itself.
Expand All @@ -1746,7 +1750,7 @@ def get_isolation_level(self, conn_proxy):
if isinstance(conn_proxy, spanner_dbapi.Connection):
conn = conn_proxy
else:
conn = conn_proxy.connection
conn = conn_proxy.driver_connection

if conn.autocommit:
return "AUTOCOMMIT"
Expand Down
Loading