From fccbba9575e00a62bddbb400bebfc1c924ebb0e9 Mon Sep 17 00:00:00 2001
From: Rahul Yadav
Date: Tue, 13 Jan 2026 12:08:38 +0530
Subject: [PATCH 1/5] chore: remove session pool and only use multiplexed sessions

---
 google/cloud/spanner_dbapi/connection.py      |    9 -
 google/cloud/spanner_v1/database.py           |   48 +-
 .../spanner_v1/database_sessions_manager.py   |   95 +-
 google/cloud/spanner_v1/instance.py           |   14 +-
 google/cloud/spanner_v1/pool.py               |  137 +-
 .../mockserver_tests/mock_server_test_base.py |    3 +-
 tests/system/conftest.py                      |    3 -
 tests/system/test_backup_api.py               |    3 -
 tests/system/test_database_api.py             |  108 +-
 tests/system/test_dbapi.py                    |    2 -
 tests/system/test_session_api.py              |   10 -
 tests/system/utils/populate_streaming.py      |    6 +-
 tests/unit/spanner_dbapi/test_connect.py      |    2 -
 tests/unit/spanner_dbapi/test_connection.py   |    5 +-
 tests/unit/test_database.py                   |  495 ++----
 tests/unit/test_database_session_manager.py   |  171 +-
 tests/unit/test_instance.py                   |   15 -
 tests/unit/test_pool.py                       | 1485 -----------------
 18 files changed, 287 insertions(+), 2324 deletions(-)
 delete mode 100644 tests/unit/test_pool.py

diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 111bc4cc1b..69a354c609 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -104,10 +104,6 @@ def __init__(self, instance, database=None, read_only=False, **kwargs):
         self.is_closed = False
         self._autocommit = False
 
-        # indicator to know if the session pool used by
-        # this connection should be cleared on the
-        # connection close
-        self._own_pool = True
         self._read_only = read_only
         self._staleness = None
         self.request_priority = None
@@ -443,9 +439,6 @@ def close(self):
         if self._spanner_transaction_started and not self._read_only:
             self._transaction.rollback()
 
-        if self._own_pool and self.database:
-            self.database._sessions_manager._pool.clear()
-
         self.is_closed = True
 
     @check_not_closed
@@ -830,7 +823,5 @@ def connect(
         database_id, pool=pool, database_role=database_role, logger=logger
     )
     conn = Connection(instance, database, **kwargs)
-    if pool is not None:
-        conn._own_pool = False
 
     return conn
 
diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index 33c442602c..aca653fadb 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -60,7 +60,6 @@
 from google.cloud.spanner_v1.batch import MutationGroups
 from google.cloud.spanner_v1.keyset import KeySet
 from google.cloud.spanner_v1.merged_result_set import MergedResultSet
-from google.cloud.spanner_v1.pool import BurstyPool
 from google.cloud.spanner_v1.session import Session
 from google.cloud.spanner_v1.database_sessions_manager import (
     DatabaseSessionsManager,
@@ -122,9 +121,11 @@ class Database(object):
 
     :type pool: concrete subclass of
         :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`.
-    :param pool: (Optional) session pool to be used by database. If not
-        passed, the database will construct an instance of
-        :class:`~google.cloud.spanner_v1.pool.BurstyPool`.
+    :param pool: (Deprecated) session pool argument, retained for backward
+        compatibility only. Session pools are deprecated: multiplexed
+        sessions are now used for all operations by default. If a pool is
+        passed, it is ignored and a :class:`DeprecationWarning` is emitted.
+        New code should not pass a pool argument.
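# ---------------------------------------------------------------------------
# Hedged usage sketch (editorial illustration, not part of the patch). It
# shows the contract documented above: omitting the pool is the supported
# path, and passing one is a no-op that emits a DeprecationWarning. The
# instance/database names are placeholders, and Client() assumes ambient
# project credentials.
import warnings

from google.cloud import spanner_v1

client = spanner_v1.Client()
instance = client.instance("my-instance")

# Supported: the database manages multiplexed sessions internally.
database = instance.database("my-database")

# Deprecated: the pool argument is ignored and a warning is raised.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    instance.database("my-database", pool=spanner_v1.BurstyPool())
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# ---------------------------------------------------------------------------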
:type logger: :class:`logging.Logger` :param logger: (Optional) a custom logger that is used if `log_commit_stats` @@ -198,16 +199,21 @@ def __init__( self._proto_descriptors = proto_descriptors self._channel_id = 0 # It'll be created when _spanner_api is created. - if pool is None: - pool = BurstyPool(database_role=database_role) + # Session pools are deprecated. Multiplexed sessions are now used for all operations. + # The pool parameter is kept for backward compatibility but is ignored. + if pool is not None: + from warnings import warn + + warn( + "The 'pool' parameter is deprecated and ignored. " + "Multiplexed sessions are now used for all operations.", + DeprecationWarning, + stacklevel=2, + ) - self._pool = pool - pool.bind(self) is_experimental_host = self._instance.experimental_host is not None - self._sessions_manager = DatabaseSessionsManager( - self, pool, is_experimental_host - ) + self._sessions_manager = DatabaseSessionsManager(self, is_experimental_host) @classmethod def from_pb(cls, database_pb, instance, pool=None): @@ -861,13 +867,9 @@ def session(self, labels=None, database_role=None): # If role is specified in param, then that role is used # instead. role = database_role or self._database_role - is_multiplexed = False - if self.sessions_manager._use_multiplexed( - transaction_type=TransactionType.READ_ONLY - ): - is_multiplexed = True + # Always use multiplexed sessions return Session( - self, labels=labels, database_role=role, is_multiplexed=is_multiplexed + self, labels=labels, database_role=role, is_multiplexed=True ) def snapshot(self, **kw): @@ -1430,12 +1432,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. - if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() self._database.sessions_manager.put_session(self._session) @@ -1471,12 +1467,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. - if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() self._database.sessions_manager.put_session(self._session) diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index bc0db1577c..c6843777db 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from os import getenv from datetime import timedelta from threading import Event, Lock, Thread from time import sleep, time @@ -41,31 +40,21 @@ class DatabaseSessionsManager(object): transaction type using :meth:`get_session`, and returned to the session manager using :meth:`put_session`. - The sessions returned by the session manager depend on the configured environment variables - and the provided session pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). + Multiplexed sessions are used for all transaction types. :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: The database to manage sessions for. 
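# ---------------------------------------------------------------------------
# Hedged usage sketch (editorial illustration, not part of the patch) of the
# session manager described above: get_session() always hands back the shared
# multiplexed session regardless of transaction type, and put_session() is a
# no-op kept for interface compatibility. `database` is assumed to be an
# existing Database object.
from google.cloud.spanner_v1.database_sessions_manager import TransactionType

def with_managed_session(database):
    manager = database.sessions_manager
    session = manager.get_session(TransactionType.READ_WRITE)
    try:
        return session.session_id  # stand-in for real transaction work
    finally:
        manager.put_session(session)  # no pool bookkeeping happens here
# ---------------------------------------------------------------------------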
- :type pool: :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` - :param pool: The pool to get non-multiplexed sessions from. + :type is_experimental_host: bool + :param is_experimental_host: Whether the database is using an experimental host. """ - # Environment variables for multiplexed sessions - _ENV_VAR_MULTIPLEXED = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" - _ENV_VAR_MULTIPLEXED_PARTITIONED = ( - "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" - ) - _ENV_VAR_MULTIPLEXED_READ_WRITE = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" - # Intervals for the maintenance thread to check and refresh the multiplexed session. _MAINTENANCE_THREAD_POLLING_INTERVAL = timedelta(minutes=10) _MAINTENANCE_THREAD_REFRESH_INTERVAL = timedelta(days=7) - def __init__(self, database, pool, is_experimental_host: bool = False): + def __init__(self, database, is_experimental_host: bool = False): self._database = database - self._pool = pool - self._is_experimental_host = is_experimental_host # Declare multiplexed session attributes. When a multiplexed session for the # database session manager is created, a maintenance thread is initialized to @@ -81,17 +70,16 @@ def __init__(self, database, pool, is_experimental_host: bool = False): self._multiplexed_session_terminate_event: Event = Event() def get_session(self, transaction_type: TransactionType) -> Session: - """Returns a session for the given transaction type from the database session manager. + """Returns a multiplexed session for the given transaction type. + + :type transaction_type: :class:`TransactionType` + :param transaction_type: The type of transaction (ignored, multiplexed + sessions support all transaction types). :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session for the given transaction type. + :returns: a multiplexed session. """ - - session = ( - self._get_multiplexed_session() - if self._use_multiplexed(transaction_type) or self._is_experimental_host - else self._pool.get() - ) + session = self._get_multiplexed_session() add_span_event( get_current_span(), @@ -104,21 +92,18 @@ def get_session(self, transaction_type: TransactionType) -> Session: def put_session(self, session: Session) -> None: """Returns the session to the database session manager. + For multiplexed sessions, this is a no-op since they can handle + multiple concurrent transactions and don't need to be returned to a pool. + :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: The session to return to the database session manager. """ - add_span_event( get_current_span(), "Returning session", {"id": session.session_id, "multiplexed": session.is_multiplexed}, ) - - # No action is needed for multiplexed sessions: the session - # pool is only used for managing non-multiplexed sessions, - # since they can only process one transaction at a time. - if not session.is_multiplexed: - self._pool.put(session) + # Multiplexed sessions don't need to be returned to a pool def _get_multiplexed_session(self) -> Session: """Returns a multiplexed session from the database session manager. @@ -226,53 +211,3 @@ def _maintain_multiplexed_session(session_manager_ref) -> None: session_created_time = time() - @classmethod - def _use_multiplexed(cls, transaction_type: TransactionType) -> bool: - """Returns whether to use multiplexed sessions for the given transaction type. - - Multiplexed sessions are enabled for read-only transactions if: - * _ENV_VAR_MULTIPLEXED != 'false'. 
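# ---------------------------------------------------------------------------
# For reference (editorial note, not part of the patch): the opt-out gating
# removed here behaved as sketched below. Any value other than the literal
# string "false" (case-insensitive, whitespace-stripped) -- including an
# unset variable -- counted as enabled. After this patch no environment
# variable is consulted; multiplexed sessions are always used.
from os import getenv

def _legacy_env_enabled(env_var_name: str) -> bool:
    return getenv(env_var_name, "true").lower().strip() != "false"

# _legacy_env_enabled("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") was True
# unless that variable was explicitly set to "false".
# ---------------------------------------------------------------------------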
- - Multiplexed sessions are enabled for partitioned transactions if: - * _ENV_VAR_MULTIPLEXED_PARTITIONED != 'false'. - - Multiplexed sessions are enabled for read/write transactions if: - * _ENV_VAR_MULTIPLEXED_READ_WRITE != 'false'. - - :type transaction_type: :class:`TransactionType` - :param transaction_type: the type of transaction - - :rtype: bool - :returns: True if multiplexed sessions should be used for the given transaction - type, False otherwise. - - :raises ValueError: if the transaction type is not supported. - """ - - if transaction_type is TransactionType.READ_ONLY: - return cls._getenv(cls._ENV_VAR_MULTIPLEXED) - - elif transaction_type is TransactionType.PARTITIONED: - return cls._getenv(cls._ENV_VAR_MULTIPLEXED_PARTITIONED) - - elif transaction_type is TransactionType.READ_WRITE: - return cls._getenv(cls._ENV_VAR_MULTIPLEXED_READ_WRITE) - - raise ValueError(f"Transaction type {transaction_type} is not supported.") - - @classmethod - def _getenv(cls, env_var_name: str) -> bool: - """Returns the value of the given environment variable as a boolean. - - True unless explicitly 'false' (case-insensitive). - All other values (including unset) are considered true. - - :type env_var_name: str - :param env_var_name: the name of the boolean environment variable - - :rtype: bool - :returns: True unless the environment variable is set to 'false', False otherwise. - """ - - env_var_value = getenv(env_var_name, "true").lower().strip() - return env_var_value != "false" diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 0d05699728..f896ae8d7c 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -450,7 +450,8 @@ def database( :type pool: concrete subclass of :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. - :param pool: (Optional) session pool to be used by database. + :param pool: (Optional) Deprecated. Session pools are no longer used. + Multiplexed sessions are now used for all operations. :type logger: :class:`logging.Logger` :param logger: (Optional) a custom logger that is used if `log_commit_stats` @@ -488,13 +489,21 @@ def database( :rtype: :class:`~google.cloud.spanner_v1.database.Database` :returns: a database owned by this instance. """ + if pool is not None: + from warnings import warn + + warn( + "The 'pool' parameter is deprecated and ignored. " + "Multiplexed sessions are now used for all operations.", + DeprecationWarning, + stacklevel=2, + ) if not enable_interceptors_in_tests: return Database( database_id, self, ddl_statements=ddl_statements, - pool=pool, logger=logger, encryption_config=encryption_config, database_dialect=database_dialect, @@ -507,7 +516,6 @@ def database( database_id, self, ddl_statements=ddl_statements, - pool=pool, logger=logger, encryption_config=encryption_config, database_dialect=database_dialect, diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index a75c13cb7a..f0304bd66c 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Pools managing shared Session objects.""" +"""Pools managing shared Session objects. + +.. deprecated:: + Session pools are deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default, eliminating + the need for session pooling. 
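# ---------------------------------------------------------------------------
# Hedged migration sketch (editorial illustration, not part of the patch):
# callers that cannot drop pool usage yet can silence exactly this warning
# by matching the message text the module now emits.
import warnings

warnings.filterwarnings(
    "ignore",
    message="Session pools are deprecated.*",
    category=DeprecationWarning,
)
# ---------------------------------------------------------------------------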
+""" import datetime import queue @@ -37,10 +43,21 @@ _NOW = datetime.datetime.utcnow # unit tests may replace +_POOL_DEPRECATION_MESSAGE = ( + "Session pools are deprecated and will be removed in a future release. " + "Multiplexed sessions are now used for all operations by default, " + "eliminating the need for session pooling. " + "To disable this warning, do not pass a pool argument when creating a Database." +) + class AbstractSessionPool(object): """Specifies required API for concrete session pool implementations. + .. deprecated:: + Session pools are deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default. + :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. @@ -155,6 +172,10 @@ def session(self, **kwargs): class FixedSizePool(AbstractSessionPool): """Concrete session pool implementation: + .. deprecated:: + FixedSizePool is deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default. + - Pre-allocates / creates a fixed number of sessions. - "Pings" existing sessions via :meth:`session.exists` before returning @@ -169,11 +190,13 @@ class FixedSizePool(AbstractSessionPool): :meth:`get` followed by :meth:`put` whenever in need of a session. :type size: int - :param size: fixed pool size + :param size: (Deprecated) fixed pool size. This parameter is deprecated + as session pools are no longer needed with multiplexed sessions. :type default_timeout: int - :param default_timeout: default timeout, in seconds, to wait for - a returned session. + :param default_timeout: (Deprecated) default timeout, in seconds, to wait for + a returned session. This parameter is deprecated as session pools are + no longer needed with multiplexed sessions. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created @@ -195,6 +218,11 @@ def __init__( database_role=None, max_age_minutes=DEFAULT_MAX_AGE_MINUTES, ): + warn( + _POOL_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) super(FixedSizePool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout @@ -368,6 +396,10 @@ def clear(self): class BurstyPool(AbstractSessionPool): """Concrete session pool implementation: + .. deprecated:: + BurstyPool is deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default. + - "Pings" existing sessions via :meth:`session.exists` before returning them. @@ -378,7 +410,8 @@ class BurstyPool(AbstractSessionPool): is called on a full pool. :type target_size: int - :param target_size: max pool size + :param target_size: (Deprecated) max pool size. This parameter is deprecated + as session pools are no longer needed with multiplexed sessions. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created @@ -388,7 +421,17 @@ class BurstyPool(AbstractSessionPool): :param database_role: (Optional) user-assigned database_role for the session. """ + # Internal flag to suppress deprecation warning when BurstyPool is used + # as a fallback/internal implementation detail. 
+ _suppress_warning = False + def __init__(self, target_size=10, labels=None, database_role=None): + if not BurstyPool._suppress_warning: + warn( + _POOL_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) super(BurstyPool, self).__init__(labels=labels, database_role=database_role) self.target_size = target_size self._database = None @@ -474,6 +517,10 @@ def clear(self): class PingingPool(AbstractSessionPool): """Concrete session pool implementation: + .. deprecated:: + PingingPool is deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default. + - Pre-allocates / creates a fixed number of sessions. - Sessions are used in "round-robin" order (LRU first). @@ -492,14 +539,18 @@ class PingingPool(AbstractSessionPool): times, e.g. from a background thread. :type size: int - :param size: fixed pool size + :param size: (Deprecated) fixed pool size. This parameter is deprecated + as session pools are no longer needed with multiplexed sessions. :type default_timeout: int - :param default_timeout: default timeout, in seconds, to wait for - a returned session. + :param default_timeout: (Deprecated) default timeout, in seconds, to wait for + a returned session. This parameter is deprecated as session pools are + no longer needed with multiplexed sessions. :type ping_interval: int - :param ping_interval: interval at which to ping sessions. + :param ping_interval: (Deprecated) interval at which to ping sessions. + This parameter is deprecated as session pools are no longer needed + with multiplexed sessions. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created @@ -509,6 +560,9 @@ class PingingPool(AbstractSessionPool): :param database_role: (Optional) user-assigned database_role for the session. """ + # Internal flag to suppress deprecation warning when called from subclass. + _suppress_warning = False + def __init__( self, size=10, @@ -517,6 +571,12 @@ def __init__( labels=None, database_role=None, ): + if not PingingPool._suppress_warning: + warn( + _POOL_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) super(PingingPool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout @@ -699,9 +759,11 @@ def ping(self): class TransactionPingingPool(PingingPool): """Concrete session pool implementation: - Deprecated: TransactionPingingPool no longer begins a transaction for each of its sessions at startup. - Hence the TransactionPingingPool is same as :class:`PingingPool` and maybe removed in the future. - + .. deprecated:: + TransactionPingingPool is deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default. + TransactionPingingPool no longer begins a transaction for each of its sessions + at startup. Hence the TransactionPingingPool is same as :class:`PingingPool`. In addition to the features of :class:`PingingPool`, this class creates and begins a transaction for each of its sessions at startup. @@ -713,14 +775,18 @@ class TransactionPingingPool(PingingPool): as appropriate via the pool's :meth:`begin_pending_transactions` method. :type size: int - :param size: fixed pool size + :param size: (Deprecated) fixed pool size. This parameter is deprecated + as session pools are no longer needed with multiplexed sessions. :type default_timeout: int - :param default_timeout: default timeout, in seconds, to wait for - a returned session. 
+ :param default_timeout: (Deprecated) default timeout, in seconds, to wait for + a returned session. This parameter is deprecated as session pools are + no longer needed with multiplexed sessions. :type ping_interval: int - :param ping_interval: interval at which to ping sessions. + :param ping_interval: (Deprecated) interval at which to ping sessions. + This parameter is deprecated as session pools are no longer needed + with multiplexed sessions. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created @@ -740,19 +806,24 @@ def __init__( ): """This throws a deprecation warning on initialization.""" warn( - f"{self.__class__.__name__} is deprecated.", + _POOL_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2, ) self._pending_sessions = queue.Queue() - super(TransactionPingingPool, self).__init__( - size, - default_timeout, - ping_interval, - labels=labels, - database_role=database_role, - ) + # Suppress warning from parent class to avoid double warning + PingingPool._suppress_warning = True + try: + super(TransactionPingingPool, self).__init__( + size, + default_timeout, + ping_interval, + labels=labels, + database_role=database_role, + ) + finally: + PingingPool._suppress_warning = False self.begin_pending_transactions() @@ -797,13 +868,16 @@ def begin_pending_transactions(self): class SessionCheckout(object): """Context manager: hold session checked out from a pool. - Deprecated. Sessions should be checked out indirectly using context - managers or :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`, - rather than checked out directly from the pool. + .. deprecated:: + SessionCheckout is deprecated and will be removed in a future release. + Multiplexed sessions are now used for all operations by default. + Sessions should be checked out indirectly using context managers or + :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`, + rather than checked out directly from the pool. :type pool: concrete subclass of :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` - :param pool: Pool from which to check out a session. + :param pool: (Deprecated) Pool from which to check out a session. :param kwargs: extra keyword arguments to be passed to :meth:`pool.get`. """ @@ -811,6 +885,13 @@ class SessionCheckout(object): _session = None def __init__(self, pool, **kwargs): + warn( + "SessionCheckout is deprecated. 
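# ---------------------------------------------------------------------------
# Hedged usage sketch (editorial illustration, not part of the patch) of what
# replaces a direct SessionCheckout: sessions are obtained indirectly through
# database-level context managers or run_in_transaction. `database` is
# assumed to be an existing Database; the `contacts` table comes from this
# repo's test fixtures.
def read_then_update(database):
    # Read path: snapshot() checks a multiplexed session in and out for you.
    with database.snapshot() as snapshot:
        rows = list(snapshot.execute_sql("SELECT contact_id FROM contacts"))

    # Write path: run_in_transaction owns the session and handles retries.
    def work(transaction):
        transaction.execute_update(
            "UPDATE contacts SET first_name = 'Anne' WHERE contact_id = 1"
        )

    database.run_in_transaction(work)
    return rows
# ---------------------------------------------------------------------------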
" + "Sessions should be managed through database context managers or " + "run_in_transaction instead of being checked out directly from the pool.", + DeprecationWarning, + stacklevel=2, + ) self._pool = pool self._kwargs = kwargs.copy() diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 117b649e1b..75455807d6 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -33,7 +33,7 @@ import google.cloud.spanner_v1.types.result_set as result_set import google.cloud.spanner_v1.types.type as spanner_type from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode -from google.cloud.spanner_v1 import Client, FixedSizePool, ResultSetMetadata, TypeCode +from google.cloud.spanner_v1 import Client, ResultSetMetadata, TypeCode from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer @@ -228,7 +228,6 @@ def database(self) -> Database: if self._database is None: self._database = self.instance.database( "test-database", - pool=FixedSizePool(size=10), enable_interceptors_in_tests=True, logger=self.logger, ) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 6b0ad6cebe..1aafdf3c93 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -223,11 +223,9 @@ def shared_database( shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file ): database_name = _helpers.unique_id("test_database") - pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: database = shared_instance.database( database_name, - pool=pool, database_dialect=database_dialect, ) operation = database.create() @@ -240,7 +238,6 @@ def shared_database( database = shared_instance.database( database_name, ddl_statements=_helpers.DDL_STATEMENTS, - pool=pool, database_dialect=database_dialect, proto_descriptors=proto_descriptor_file, ) diff --git a/tests/system/test_backup_api.py b/tests/system/test_backup_api.py index 26a2620765..7349dae0f4 100644 --- a/tests/system/test_backup_api.py +++ b/tests/system/test_backup_api.py @@ -104,11 +104,9 @@ def second_database( shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file ): database_name = _helpers.unique_id("test_database2") - pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: database = shared_instance.database( database_name, - pool=pool, database_dialect=database_dialect, ) operation = database.create() @@ -121,7 +119,6 @@ def second_database( database = shared_instance.database( database_name, ddl_statements=_helpers.DDL_STATEMENTS, - pool=pool, database_dialect=database_dialect, proto_descriptors=proto_descriptor_file, ) diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index d47826baf4..52c70517da 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -21,7 +21,6 @@ from google.api_core import exceptions from google.iam.v1 import policy_pb2 from google.cloud import spanner_v1 -from google.cloud.spanner_v1.pool import FixedSizePool, PingingPool from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud.spanner_v1 import DirectedReadOptions from google.type import expr_pb2 @@ -78,11 +77,8 @@ def 
test_list_databases(shared_instance, shared_database): def test_create_database(shared_instance, databases_to_delete, database_dialect): - pool = spanner_v1.BurstyPool(labels={"testcase": "create_database"}) temp_db_id = _helpers.unique_id("temp_db") - temp_db = shared_instance.database( - temp_db_id, pool=pool, database_dialect=database_dialect - ) + temp_db = shared_instance.database(temp_db_id, database_dialect=database_dialect) operation = temp_db.create() databases_to_delete.append(temp_db) @@ -93,90 +89,18 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) assert temp_db.name in database_ids -def test_database_binding_of_fixed_size_pool( - not_emulator, - shared_instance, - databases_to_delete, - not_postgres, - proto_descriptor_file, - not_experimental_host, -): - temp_db_id = _helpers.unique_id("fixed_size_db", separator="_") - temp_db = shared_instance.database(temp_db_id) - - create_op = temp_db.create() - databases_to_delete.append(temp_db) - create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. - - # Create role and grant select permission on table contacts for parent role. - ddl_statements = _helpers.DDL_STATEMENTS + [ - "CREATE ROLE parent", - "GRANT SELECT ON TABLE contacts TO ROLE parent", - ] - operation = temp_db.update_ddl( - ddl_statements, proto_descriptors=proto_descriptor_file - ) - operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. - - pool = FixedSizePool( - size=1, - default_timeout=500, - database_role="parent", - ) - database = shared_instance.database(temp_db_id, pool=pool) - assert database._pool.database_role == "parent" - - -def test_database_binding_of_pinging_pool( - not_emulator, - shared_instance, - databases_to_delete, - not_postgres, - proto_descriptor_file, - not_experimental_host, -): - temp_db_id = _helpers.unique_id("binding_db", separator="_") - temp_db = shared_instance.database(temp_db_id) - - create_op = temp_db.create() - databases_to_delete.append(temp_db) - create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. - - # Create role and grant select permission on table contacts for parent role. - ddl_statements = _helpers.DDL_STATEMENTS + [ - "CREATE ROLE parent", - "GRANT SELECT ON TABLE contacts TO ROLE parent", - ] - operation = temp_db.update_ddl( - ddl_statements, proto_descriptors=proto_descriptor_file - ) - operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. 
- - pool = PingingPool( - size=1, - default_timeout=500, - ping_interval=100, - database_role="parent", - ) - database = shared_instance.database(temp_db_id, pool=pool) - assert database._pool.database_role == "parent" - - def test_create_database_pitr_invalid_retention_period( not_emulator, # PITR-lite features are not supported by the emulator not_postgres, shared_instance, ): - pool = spanner_v1.BurstyPool(labels={"testcase": "create_database_pitr"}) temp_db_id = _helpers.unique_id("pitr_inv_db", separator="_") retention_period = "0d" ddl_statements = [ f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (version_retention_period = '{retention_period}')" ] - temp_db = shared_instance.database( - temp_db_id, pool=pool, ddl_statements=ddl_statements - ) + temp_db = shared_instance.database(temp_db_id, ddl_statements=ddl_statements) with pytest.raises(exceptions.InvalidArgument): temp_db.create() @@ -187,16 +111,13 @@ def test_create_database_pitr_success( shared_instance, databases_to_delete, ): - pool = spanner_v1.BurstyPool(labels={"testcase": "create_database_pitr"}) temp_db_id = _helpers.unique_id("pitr_db", separator="_") retention_period = "7d" ddl_statements = [ f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (version_retention_period = '{retention_period}')" ] - temp_db = shared_instance.database( - temp_db_id, pool=pool, ddl_statements=ddl_statements - ) + temp_db = shared_instance.database(temp_db_id, ddl_statements=ddl_statements) operation = temp_db.create() databases_to_delete.append(temp_db) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. @@ -225,17 +146,13 @@ def test_create_database_with_default_leader_success( multiregion_instance, databases_to_delete, ): - pool = spanner_v1.BurstyPool(labels={"testcase": "create_database_default_leader"}) - temp_db_id = _helpers.unique_id("dflt_ldr_db", separator="_") default_leader = "us-east4" ddl_statements = [ f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (default_leader = '{default_leader}')" ] - temp_db = multiregion_instance.database( - temp_db_id, pool=pool, ddl_statements=ddl_statements - ) + temp_db = multiregion_instance.database(temp_db_id, ddl_statements=ddl_statements) operation = temp_db.create() databases_to_delete.append(temp_db) operation.result(30) # raises on failure / timeout. @@ -262,7 +179,6 @@ def test_iam_policy( shared_instance, databases_to_delete, ): - pool = spanner_v1.BurstyPool(labels={"testcase": "iam_policy"}) temp_db_id = _helpers.unique_id("iam_db", separator="_") create_table = ( "CREATE TABLE policy (\n" @@ -275,7 +191,6 @@ def test_iam_policy( temp_db = shared_instance.database( temp_db_id, ddl_statements=[create_table, create_role], - pool=pool, ) create_op = temp_db.create() databases_to_delete.append(temp_db) @@ -333,10 +248,9 @@ def test_update_ddl_w_operation_id( # reason="'Database.update_ddl' has a flaky timeout. 
See: " # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/5629 # ) - pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl"}) temp_db_id = _helpers.unique_id("update_ddl", separator="_") temp_db = shared_instance.database( - temp_db_id, pool=pool, database_dialect=database_dialect + temp_db_id, database_dialect=database_dialect ) create_op = temp_db.create() databases_to_delete.append(temp_db) @@ -366,10 +280,9 @@ def test_update_ddl_w_pitr_invalid( databases_to_delete, proto_descriptor_file, ): - pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") retention_period = "0d" - temp_db = shared_instance.database(temp_db_id, pool=pool) + temp_db = shared_instance.database(temp_db_id) create_op = temp_db.create() databases_to_delete.append(temp_db) @@ -392,10 +305,9 @@ def test_update_ddl_w_pitr_success( databases_to_delete, proto_descriptor_file, ): - pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") retention_period = "7d" - temp_db = shared_instance.database(temp_db_id, pool=pool) + temp_db = shared_instance.database(temp_db_id) create_op = temp_db.create() databases_to_delete.append(temp_db) @@ -425,13 +337,9 @@ def test_update_ddl_w_default_leader_success( databases_to_delete, proto_descriptor_file, ): - pool = spanner_v1.BurstyPool( - labels={"testcase": "update_database_ddl_default_leader"}, - ) - temp_db_id = _helpers.unique_id("dfl_ldrr_upd_ddl", separator="_") default_leader = "us-east4" - temp_db = multiregion_instance.database(temp_db_id, pool=pool) + temp_db = multiregion_instance.database(temp_db_id) create_op = temp_db.create() databases_to_delete.append(temp_db) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 309f533170..90e383e245 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -78,11 +78,9 @@ @pytest.fixture(scope="session") def raw_database(shared_instance, database_operation_timeout, not_postgres): database_id = _helpers.unique_id("dbapi-txn") - pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) database = shared_instance.database( database_id, ddl_statements=DDL_STATEMENTS, - pool=pool, enable_interceptors_in_tests=True, ) op = database.create() diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 96f5cd76dc..be9a6b79c2 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -269,12 +269,10 @@ def sessions_database( shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file ): database_name = _helpers.unique_id("test_sessions", separator="_") - pool = spanner_v1.BurstyPool(labels={"testcase": "session_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: sessions_database = shared_instance.database( database_name, - pool=pool, database_dialect=database_dialect, ) @@ -288,7 +286,6 @@ def sessions_database( sessions_database = shared_instance.database( database_name, ddl_statements=_helpers.DDL_STATEMENTS, - pool=pool, proto_descriptors=proto_descriptor_file, ) @@ -296,10 +293,6 @@ def sessions_database( operation.result(database_operation_timeout) _helpers.retry_has_all_dll(sessions_database.reload)() - # Some tests expect there to be a session present in the pool. 
- # Experimental host connections only support multiplexed sessions - if not _helpers.USE_EXPERIMENTAL_HOST: - pool.put(pool.get()) yield sessions_database @@ -1846,12 +1839,10 @@ def test_read_w_index( # Create an alternate dataase w/ index. extra_ddl = ["CREATE INDEX contacts_by_last_name ON contacts(last_name)"] - pool = spanner_v1.BurstyPool(labels={"testcase": "read_w_index"}) if database_dialect == DatabaseDialect.POSTGRESQL: temp_db = shared_instance.database( _helpers.unique_id("test_read", separator="_"), - pool=pool, database_dialect=database_dialect, ) operation = temp_db.create() @@ -1868,7 +1859,6 @@ def test_read_w_index( ddl_statements=_helpers.DDL_STATEMENTS + extra_ddl + _helpers.PROTO_COLUMNS_DDL_STATEMENTS, - pool=pool, database_dialect=database_dialect, proto_descriptors=proto_descriptor_file, ) diff --git a/tests/system/utils/populate_streaming.py b/tests/system/utils/populate_streaming.py index a336228a15..f95d0b595a 100644 --- a/tests/system/utils/populate_streaming.py +++ b/tests/system/utils/populate_streaming.py @@ -16,7 +16,6 @@ from google.cloud.spanner_v1 import Client from google.cloud.spanner_v1.keyset import KeySet -from google.cloud.spanner_v1.pool import BurstyPool # Import relative to the script's directory from streaming_utils import FOUR_KAY @@ -68,10 +67,7 @@ def ensure_database(client): print_func("Instance exists: {}".format(INSTANCE_NAME)) instance.reload() - pool = BurstyPool() - database = instance.database( - DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool - ) + database = instance.database(DATABASE_NAME, ddl_statements=DDL_STATEMENTS) if not database.exists(): print_func("Creating database: {}".format(DATABASE_NAME)) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 5fd2b74a17..84fa2f566c 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -61,8 +61,6 @@ def test_w_implicit(self, mock_client): instance.database.assert_called_once_with( DATABASE, pool=None, database_role=None, logger=None ) - # Database constructs its own pool - self.assertIsNotNone(connection.database._pool) self.assertTrue(connection.instance._client.route_to_leader_enabled) def test_w_explicit(self, mock_client): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 6e8159425f..6fa416506b 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -272,7 +272,6 @@ def test_close(self): mock_rollback.assert_called_once_with() connection._transaction = mock.MagicMock() - connection._own_pool = False connection.close() self.assertTrue(connection.is_closed) @@ -890,20 +889,18 @@ def database( database_role=None, logger=None, ): - return _Database(database_id, pool, database_dialect, database_role, logger) + return _Database(database_id, database_dialect, database_role, logger) class _Database(object): def __init__( self, database_id="database_id", - pool=None, database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, database_role=None, logger=None, ): self.name = database_id - self.pool = pool self.database_dialect = database_dialect self.database_role = database_role self.logger = logger diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 92001fb52c..3d46d94302 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -39,7 +39,6 @@ from google.cloud.spanner_v1.session import Session from 
google.cloud.spanner_v1.database_sessions_manager import TransactionType from tests._builders import build_spanner_api -from tests._helpers import is_multiplexed_enabled DML_WO_PARAM = """ DELETE FROM citizens @@ -119,8 +118,6 @@ def _make_spanner_api(): return api def test_ctor_defaults(self): - from google.cloud.spanner_v1.pool import BurstyPool - instance = _Instance(self.INSTANCE_NAME) database = self._make_one(self.DATABASE_ID, instance) @@ -128,23 +125,28 @@ def test_ctor_defaults(self): self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIsInstance(database._pool, BurstyPool) self.assertFalse(database.log_commit_stats) self.assertIsNone(database._logger) - # BurstyPool does not create sessions during 'bind()'. - self.assertTrue(database._pool._sessions.empty()) self.assertIsNone(database.database_role) self.assertTrue(database._route_to_leader_enabled, True) + # Session pools are deprecated; multiplexed sessions are now used. + self.assertIsNotNone(database._sessions_manager) def test_ctor_w_explicit_pool(self): + import warnings + instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + mock_pool = mock.Mock() # Create a mock pool to pass + # Pool parameter is deprecated and should be ignored with a warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + database = self._make_one(self.DATABASE_ID, instance, pool=mock_pool) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("pool", str(w[0].message).lower()) self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIs(database._pool, pool) - self.assertIs(pool._bound, database) def test_ctor_w_database_role(self): instance = _Instance(self.INSTANCE_NAME) @@ -183,10 +185,8 @@ def test_ctor_w_ddl_statements_ok(self): from tests._fixtures import DDL_STATEMENTS instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() database = self._make_one( - self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool - ) + self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS ) self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) @@ -267,24 +267,28 @@ def test_from_pb_instance_mistmatch(self): klass.from_pb(database_pb, instance) def test_from_pb_success_w_explicit_pool(self): + import warnings from google.cloud.spanner_admin_database_v1 import Database client = _Client() instance = _Instance(self.INSTANCE_NAME, client) database_pb = Database(name=self.DATABASE_NAME) klass = self._get_target_class() - pool = _Pool() + mock_pool = mock.Mock() # Create a mock pool to pass - database = klass.from_pb(database_pb, instance, pool=pool) + # Pool parameter is deprecated and should be ignored with a warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + database = klass.from_pb(database_pb, instance, pool=mock_pool) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) self.assertIsInstance(database, klass) self.assertEqual(database._instance, instance) self.assertEqual(database.database_id, self.DATABASE_ID) - self.assertIs(database._pool, pool) def 
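# ---------------------------------------------------------------------------
# Editorial illustration (not part of the patch): the same assertion written
# with pytest.warns, for suites that prefer pytest idioms over
# warnings.catch_warnings; `instance` stands in for the _Instance test double
# used in this file.
import pytest

from google.cloud.spanner_v1.database import Database

def test_explicit_pool_warns(instance):
    with pytest.warns(DeprecationWarning, match="pool"):
        Database("database-id", instance, pool=object())
# ---------------------------------------------------------------------------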
test_from_pb_success_w_hyphen_w_default_pool(self): from google.cloud.spanner_admin_database_v1 import Database - from google.cloud.spanner_v1.pool import BurstyPool DATABASE_ID_HYPHEN = "database-id" DATABASE_NAME_HYPHEN = self.INSTANCE_NAME + "/databases/" + DATABASE_ID_HYPHEN @@ -298,21 +302,18 @@ def test_from_pb_success_w_hyphen_w_default_pool(self): self.assertIsInstance(database, klass) self.assertEqual(database._instance, instance) self.assertEqual(database.database_id, DATABASE_ID_HYPHEN) - self.assertIsInstance(database._pool, BurstyPool) - # BurstyPool does not create sessions during 'bind()'. - self.assertTrue(database._pool._sessions.empty()) + # Session pools are deprecated; multiplexed sessions are now used. + self.assertIsNotNone(database._sessions_manager) def test_name_property(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) expected_name = self.DATABASE_NAME self.assertEqual(database.name, expected_name) def test_create_time_property(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) expected_create_time = database._create_time = self._make_timestamp() self.assertEqual(database.create_time, expected_create_time) @@ -320,8 +321,7 @@ def test_state_property(self): from google.cloud.spanner_admin_database_v1 import Database instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) expected_state = database._state = Database.State.READY self.assertEqual(database.state, expected_state) @@ -329,8 +329,7 @@ def test_restore_info(self): from google.cloud.spanner_admin_database_v1 import RestoreInfo instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) restore_info = database._restore_info = mock.create_autospec( RestoreInfo, instance=True ) @@ -338,15 +337,13 @@ def test_restore_info(self): def test_version_retention_period(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) version_retention_period = database._version_retention_period = "1d" self.assertEqual(database.version_retention_period, version_retention_period) def test_earliest_version_time(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) earliest_version_time = database._earliest_version_time = self._make_timestamp() self.assertEqual(database.earliest_version_time, earliest_version_time) @@ -354,8 +351,7 @@ def test_logger_property_default(self): import logging instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) logger = logging.getLogger(database.name) self.assertEqual(database.logger, logger) @@ -363,8 +359,7 @@ def test_logger_property_custom(self): import logging instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = 
self._make_one(self.DATABASE_ID, instance) logger = database._logger = mock.create_autospec(logging.Logger, instance=True) self.assertEqual(database.logger, logger) @@ -372,8 +367,7 @@ def test_encryption_config(self): from google.cloud.spanner_admin_database_v1 import EncryptionConfig instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) encryption_config = database._encryption_config = mock.create_autospec( EncryptionConfig, instance=True ) @@ -383,8 +377,7 @@ def test_encryption_info(self): from google.cloud.spanner_admin_database_v1 import EncryptionInfo instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) encryption_info = database._encryption_info = [ mock.create_autospec(EncryptionInfo, instance=True) ] @@ -392,16 +385,14 @@ def test_encryption_info(self): def test_default_leader(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) default_leader = database._default_leader = "us-east4" self.assertEqual(database.default_leader, default_leader) def test_proto_descriptors(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() database = self._make_one( - self.DATABASE_ID, instance, pool=pool, proto_descriptors=b"" + self.DATABASE_ID, instance, proto_descriptors=b"" ) self.assertEqual(database.proto_descriptors, b"") @@ -411,8 +402,7 @@ def test_spanner_api_property_w_scopeless_creds(self): client_options = client._client_options = mock.Mock() credentials = client.credentials = object() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") @@ -452,8 +442,7 @@ def with_scopes(self, scopes): client_options = client._client_options = mock.Mock() credentials = client.credentials = _CredentialsWithScopes() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") @@ -476,8 +465,7 @@ def with_scopes(self, scopes): def test_spanner_api_w_emulator_host(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client, emulator_host="host") - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) patch = mock.patch("google.cloud.spanner_v1.database.SpannerClient") with patch as spanner_client: @@ -496,23 +484,20 @@ def test_spanner_api_w_emulator_host(self): def test___eq__(self): instance = _Instance(self.INSTANCE_NAME) - pool1, pool2 = _Pool(), _Pool() - database1 = self._make_one(self.DATABASE_ID, instance, pool=pool1) - database2 = self._make_one(self.DATABASE_ID, instance, pool=pool2) + database1 = self._make_one(self.DATABASE_ID, instance) + database2 = self._make_one(self.DATABASE_ID, instance) self.assertEqual(database1, database2) def test___eq__type_differ(self): instance = _Instance(self.INSTANCE_NAME) - pool = _Pool() - database1 = self._make_one(self.DATABASE_ID, instance, pool=pool) 
+ database1 = self._make_one(self.DATABASE_ID, instance) database2 = object() self.assertNotEqual(database1, database2) def test___ne__same_value(self): instance = _Instance(self.INSTANCE_NAME) - pool1, pool2 = _Pool(), _Pool() - database1 = self._make_one(self.DATABASE_ID, instance, pool=pool1) - database2 = self._make_one(self.DATABASE_ID, instance, pool=pool2) + database1 = self._make_one(self.DATABASE_ID, instance) + database2 = self._make_one(self.DATABASE_ID, instance) comparison_val = database1 != database2 self.assertFalse(comparison_val) @@ -520,9 +505,8 @@ def test___ne__(self): instance1, instance2 = _Instance(self.INSTANCE_NAME + "1"), _Instance( self.INSTANCE_NAME + "2" ) - pool1, pool2 = _Pool(), _Pool() - database1 = self._make_one("database_id1", instance1, pool=pool1) - database2 = self._make_one("database_id2", instance2, pool=pool2) + database1 = self._make_one("database_id1", instance1) + database2 = self._make_one("database_id2", instance2) self.assertNotEqual(database1, database2) def test_create_grpc_error(self): @@ -535,8 +519,7 @@ def test_create_grpc_error(self): api.create_database.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(GoogleAPICallError): database.create() @@ -568,8 +551,7 @@ def test_create_already_exists(self): api = client.database_admin_api = self._make_database_admin_api() api.create_database.side_effect = Conflict("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) + database = self._make_one(DATABASE_ID_HYPHEN, instance) with self.assertRaises(Conflict): database.create() @@ -600,8 +582,7 @@ def test_create_instance_not_found(self): api = client.database_admin_api = self._make_database_admin_api() api.create_database.side_effect = NotFound("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(NotFound): database.create() @@ -634,13 +615,11 @@ def test_create_success(self): api = client.database_admin_api = self._make_database_admin_api() api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() encryption_config = EncryptionConfig(kms_key_name="kms_key_name") database = self._make_one( self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, - pool=pool, encryption_config=encryption_config, ) @@ -676,13 +655,11 @@ def test_create_success_w_encryption_config_dict(self): api = client.database_admin_api = self._make_database_admin_api() api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() encryption_config = {"kms_key_name": "kms_key_name"} database = self._make_one( self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, - pool=pool, encryption_config=encryption_config, ) @@ -718,13 +695,11 @@ def test_create_success_w_proto_descriptors(self): api = client.database_admin_api = self._make_database_admin_api() api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() proto_descriptors = b"" database = self._make_one( self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, - pool=pool, 
proto_descriptors=proto_descriptors, ) @@ -757,8 +732,7 @@ def test_exists_grpc_error(self): api = client.database_admin_api = self._make_database_admin_api() api.get_database_ddl.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(Unknown): database.exists() @@ -781,8 +755,7 @@ def test_exists_not_found(self): api = client.database_admin_api = self._make_database_admin_api() api.get_database_ddl.side_effect = NotFound("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) self.assertFalse(database.exists()) @@ -806,8 +779,7 @@ def test_exists_success(self): api = client.database_admin_api = self._make_database_admin_api() api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) self.assertTrue(database.exists()) @@ -829,8 +801,7 @@ def test_reload_grpc_error(self): api = client.database_admin_api = self._make_database_admin_api() api.get_database_ddl.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(Unknown): database.reload() @@ -853,8 +824,7 @@ def test_reload_not_found(self): api = client.database_admin_api = self._make_database_admin_api() api.get_database_ddl.side_effect = NotFound("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(NotFound): database.reload() @@ -908,8 +878,7 @@ def test_reload_success(self): ) api.get_database.return_value = db_pb instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) database.reload() self.assertEqual(database._state, Database.State.READY) @@ -954,8 +923,7 @@ def test_update_ddl_grpc_error(self): api = client.database_admin_api = self._make_database_admin_api() api.update_database_ddl.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(Unknown): database.update_ddl(DDL_STATEMENTS) @@ -986,8 +954,7 @@ def test_update_ddl_not_found(self): api = client.database_admin_api = self._make_database_admin_api() api.update_database_ddl.side_effect = NotFound("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(NotFound): database.update_ddl(DDL_STATEMENTS) @@ -1018,8 +985,7 @@ def test_update_ddl(self): api = client.database_admin_api = self._make_database_admin_api() api.update_database_ddl.return_value = op_future instance = 
_Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) future = database.update_ddl(DDL_STATEMENTS) @@ -1051,8 +1017,7 @@ def test_update_ddl_w_operation_id(self): api = client.database_admin_api = self._make_database_admin_api() api.update_database_ddl.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) future = database.update_ddl(DDL_STATEMENTS, operation_id="someOperationId") @@ -1082,10 +1047,8 @@ def test_update_success(self): api.update_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() database = self._make_one( - self.DATABASE_ID, instance, enable_drop_protection=True, pool=pool - ) + self.DATABASE_ID, instance, enable_drop_protection=True ) future = database.update(["enable_drop_protection"]) @@ -1116,8 +1079,7 @@ def test_update_ddl_w_proto_descriptors(self): api = client.database_admin_api = self._make_database_admin_api() api.update_database_ddl.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) future = database.update_ddl(DDL_STATEMENTS, proto_descriptors=b"") @@ -1148,8 +1110,7 @@ def test_drop_grpc_error(self): api = client.database_admin_api = self._make_database_admin_api() api.drop_database.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(Unknown): database.drop() @@ -1172,8 +1133,7 @@ def test_drop_not_found(self): api = client.database_admin_api = self._make_database_admin_api() api.drop_database.side_effect = NotFound("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(NotFound): database.drop() @@ -1196,8 +1156,7 @@ def test_drop_success(self): api = client.database_admin_api = self._make_database_admin_api() api.drop_database.return_value = Empty() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) database.drop() @@ -1253,10 +1212,8 @@ def _execute_partitioned_dml_helper( client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() session = _Session() - pool.put(session) - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) multiplexed_partitioned_enabled = ( os.environ.get( @@ -1279,7 +1236,7 @@ def _execute_partitioned_dml_helper( ) expected_session = multiplexed_session else: - # When multiplexed sessions are disabled, use the regular pool session + # Multiplexed sessions are now always used expected_session = session api = database._spanner_api = self._make_spanner_api() @@ -1453,7 +1410,7 @@ def _execute_partitioned_dml_helper( database._sessions_manager.get_session.assert_called_with( TransactionType.PARTITIONED ) - # If 
multiplexed sessions are not enabled, the regular pool session should be used + # Multiplexed sessions are now always used def test_execute_partitioned_dml_wo_params(self): self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) @@ -1502,8 +1459,7 @@ def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self): def test_session_factory_defaults(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) session = database.session() @@ -1515,9 +1471,8 @@ def test_session_factory_defaults(self): def test_session_factory_w_labels(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() labels = {"foo": "bar"} - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) session = database.session(labels=labels) @@ -1532,46 +1487,22 @@ def test_snapshot_defaults(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - session = _Session() - pool.put(session) - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) # Mock the spanner_api to avoid creating a real SpannerClient database._spanner_api = instance._client._spanner_api - # Check if multiplexed sessions are enabled for read operations - multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) - - if multiplexed_enabled: - # When multiplexed sessions are enabled, configure the sessions manager - # to return a multiplexed session for read operations - multiplexed_session = _Session() - multiplexed_session.name = self.SESSION_NAME - multiplexed_session.is_multiplexed = True - # Override the side_effect to return the multiplexed session - database._sessions_manager.get_session = mock.Mock( - return_value=multiplexed_session - ) - expected_session = multiplexed_session - else: - expected_session = session - checkout = database.snapshot() self.assertIsInstance(checkout, SnapshotCheckout) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {}) with checkout as snapshot: - if not multiplexed_enabled: - self.assertIsNone(pool._session) - self.assertIsInstance(snapshot, Snapshot) - self.assertIs(snapshot._session, expected_session) + # Multiplexed sessions are always used + self.assertIsNotNone(snapshot._session) + self.assertTrue(snapshot._session.is_multiplexed) self.assertTrue(snapshot._strong) self.assertFalse(snapshot._multi_use) - if not multiplexed_enabled: - self.assertIs(pool._session, session) - def test_snapshot_w_read_timestamp_and_multi_use(self): import datetime from google.cloud._helpers import UTC @@ -1581,27 +1512,9 @@ def test_snapshot_w_read_timestamp_and_multi_use(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - session = _Session() - pool.put(session) - database = self._make_one(self.DATABASE_ID, instance, pool=pool) - - # Check if multiplexed sessions are enabled for read operations - multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) - - if multiplexed_enabled: - # When multiplexed sessions are enabled, configure the sessions manager - # to return a multiplexed session for read operations - multiplexed_session = _Session() - multiplexed_session.name = self.SESSION_NAME - 
multiplexed_session.is_multiplexed = True - # Override the side_effect to return the multiplexed session - database._sessions_manager.get_session = mock.Mock( - return_value=multiplexed_session - ) - expected_session = multiplexed_session - else: - expected_session = session + database = self._make_one(self.DATABASE_ID, instance) + # Mock the spanner_api to avoid creating a real SpannerClient + database._spanner_api = instance._client._spanner_api checkout = database.snapshot(read_timestamp=now, multi_use=True) @@ -1610,25 +1523,19 @@ def test_snapshot_w_read_timestamp_and_multi_use(self): self.assertEqual(checkout._kw, {"read_timestamp": now, "multi_use": True}) with checkout as snapshot: - if not multiplexed_enabled: - self.assertIsNone(pool._session) - self.assertIsInstance(snapshot, Snapshot) - self.assertIs(snapshot._session, expected_session) + # Multiplexed sessions are always used + self.assertIsNotNone(snapshot._session) + self.assertTrue(snapshot._session.is_multiplexed) self.assertEqual(snapshot._read_timestamp, now) self.assertTrue(snapshot._multi_use) - if not multiplexed_enabled: - self.assertIs(pool._session, session) - def test_batch(self): from google.cloud.spanner_v1.database import BatchCheckout client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() session = _Session() - pool.put(session) - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) checkout = database.batch() self.assertIsInstance(checkout, BatchCheckout) @@ -1639,10 +1546,8 @@ def test_mutation_groups(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() session = _Session() - pool.put(session) - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) checkout = database.mutation_groups() self.assertIsInstance(checkout, MutationGroupsCheckout) @@ -1652,7 +1557,7 @@ def test_batch_snapshot(self): from google.cloud.spanner_v1.database import BatchSnapshot instance = _Instance(self.INSTANCE_NAME) - database = self._make_one(self.DATABASE_ID, instance=instance, pool=_Pool()) + database = self._make_one(self.DATABASE_ID, instance=instance) batch_txn = database.batch_snapshot() self.assertIsInstance(batch_txn, BatchSnapshot) @@ -1664,7 +1569,7 @@ def test_batch_snapshot_w_read_timestamp(self): from google.cloud.spanner_v1.database import BatchSnapshot instance = _Instance(self.INSTANCE_NAME) - database = self._make_one(self.DATABASE_ID, instance=instance, pool=_Pool()) + database = self._make_one(self.DATABASE_ID, instance=instance) timestamp = self._make_timestamp() batch_txn = database.batch_snapshot(read_timestamp=timestamp) @@ -1677,7 +1582,7 @@ def test_batch_snapshot_w_exact_staleness(self): from google.cloud.spanner_v1.database import BatchSnapshot instance = _Instance(self.INSTANCE_NAME) - database = self._make_one(self.DATABASE_ID, instance=instance, pool=_Pool()) + database = self._make_one(self.DATABASE_ID, instance=instance) duration = self._make_duration() batch_txn = database.batch_snapshot(exact_staleness=duration) @@ -1692,11 +1597,9 @@ def test_run_in_transaction_wo_args(self): NOW = datetime.datetime.now() client = _Client(observability_options=dict(enable_end_to_end_tracing=True)) instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() session = _Session() - pool.put(session) session._committed = NOW - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + 
database = self._make_one(self.DATABASE_ID, instance) # Mock the spanner_api to avoid creating a real SpannerClient database._spanner_api = instance._client._spanner_api @@ -1719,11 +1622,9 @@ def test_run_in_transaction_w_args(self): NOW = datetime.datetime.now() client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() session = _Session() - pool.put(session) session._committed = NOW - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) # Mock the spanner_api to avoid creating a real SpannerClient database._spanner_api = instance._client._spanner_api @@ -1743,11 +1644,9 @@ def test_run_in_transaction_nested(self): # Perform the various setup tasks. instance = _Instance(self.INSTANCE_NAME, client=_Client()) - pool = _Pool() session = _Session(run_transaction_function=True) session._committed = datetime.now() - pool.put(session) - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) # Mock the spanner_api to avoid creating a real SpannerClient database._spanner_api = instance._client._spanner_api @@ -1778,8 +1677,7 @@ def test_restore_grpc_error(self): api = client.database_admin_api = self._make_database_admin_api() api.restore_database.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) backup = _Backup(self.BACKUP_NAME) with self.assertRaises(Unknown): @@ -1810,8 +1708,7 @@ def test_restore_not_found(self): api = client.database_admin_api = self._make_database_admin_api() api.restore_database.side_effect = NotFound("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) backup = _Backup(self.BACKUP_NAME) with self.assertRaises(NotFound): @@ -1845,13 +1742,12 @@ def test_restore_success(self): api = client.database_admin_api = self._make_database_admin_api() api.restore_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() encryption_config = RestoreDatabaseEncryptionConfig( encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, kms_key_name="kms_key_name", ) database = self._make_one( - self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config + self.DATABASE_ID, instance, encryption_config=encryption_config ) backup = _Backup(self.BACKUP_NAME) @@ -1888,13 +1784,12 @@ def test_restore_success_w_encryption_config_dict(self): api = client.database_admin_api = self._make_database_admin_api() api.restore_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() encryption_config = { "encryption_type": RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, "kms_key_name": "kms_key_name", } database = self._make_one( - self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config + self.DATABASE_ID, instance, encryption_config=encryption_config ) backup = _Backup(self.BACKUP_NAME) @@ -1931,13 +1826,12 @@ def test_restore_w_invalid_encryption_config_dict(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() encryption_config = { "encryption_type": 
RestoreDatabaseEncryptionConfig.EncryptionType.GOOGLE_DEFAULT_ENCRYPTION, "kms_key_name": "kms_key_name", } database = self._make_one( - self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config + self.DATABASE_ID, instance, encryption_config=encryption_config ) backup = _Backup(self.BACKUP_NAME) @@ -1949,8 +1843,7 @@ def test_is_ready(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) database._state = Database.State.READY self.assertTrue(database.is_ready()) database._state = Database.State.READY_OPTIMIZING @@ -1963,8 +1856,7 @@ def test_is_optimized(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) database._state = Database.State.READY self.assertTrue(database.is_optimized()) database._state = Database.State.READY_OPTIMIZING @@ -1981,8 +1873,7 @@ def test_list_database_operations_grpc_error(self): instance.list_database_operations = mock.MagicMock( side_effect=Unknown("testing") ) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(Unknown): database.list_database_operations() @@ -2000,8 +1891,7 @@ def test_list_database_operations_not_found(self): instance.list_database_operations = mock.MagicMock( side_effect=NotFound("testing") ) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(NotFound): database.list_database_operations() @@ -2016,8 +1906,7 @@ def test_list_database_operations_defaults(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) instance.list_database_operations = mock.MagicMock(return_value=[]) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) database.list_database_operations() @@ -2031,8 +1920,7 @@ def test_list_database_operations_explicit_filter(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) instance.list_database_operations = mock.MagicMock(return_value=[]) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) expected_filter_ = "({0}) AND ({1})".format( "metadata.@type:type.googleapis.com/google.spanner.admin.database.v1.RestoreDatabaseMetadata", @@ -2056,8 +1944,7 @@ def test_list_database_roles_grpc_error(self): api = client.database_admin_api = self._make_database_admin_api() api.list_database_roles.side_effect = Unknown("testing") instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) with self.assertRaises(Unknown): database.list_database_roles() @@ -2084,8 +1971,7 @@ def test_list_database_roles_defaults(self): api = client.database_admin_api = self._make_database_admin_api() instance = _Instance(self.INSTANCE_NAME, client=client) instance.list_database_roles = mock.MagicMock(return_value=[]) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = 
self._make_one(self.DATABASE_ID, instance) resp = database.list_database_roles() @@ -2110,8 +1996,7 @@ def test_table_factory_defaults(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) database._database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL my_table = database.table("my_table") self.assertIsInstance(my_table, Table) @@ -2121,8 +2006,7 @@ def test_table_factory_defaults(self): def test_list_tables(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance) tables = database.list_tables() self.assertIsNotNone(tables) @@ -2159,19 +2043,14 @@ def test_context_mgr_success(self): database = _Database(self.DATABASE_NAME) api = database.spanner_api = self._make_spanner_client() api.commit.return_value = response - pool = database._pool = _Pool() - session = _Session(database) - pool.put(session) checkout = self._make_one( database, request_options={"transaction_tag": self.TRANSACTION_TAG} ) with checkout as batch: - self.assertIsNone(pool._session) self.assertIsInstance(batch, Batch) - self.assertIs(batch._session, session) + self.assertIs(batch._session, database._default_session) - self.assertIs(pool._session, session) self.assertEqual(batch.committed, now) self.assertEqual(batch.transaction_tag, self.TRANSACTION_TAG) @@ -2212,17 +2091,13 @@ def test_context_mgr_w_commit_stats_success(self): database.log_commit_stats = True api = database.spanner_api = self._make_spanner_client() api.commit.return_value = response - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database) with checkout as batch: - self.assertIsNone(pool._session) self.assertIsInstance(batch, Batch) - self.assertIs(batch._session, session) + self.assertIs(batch._session, database._default_session) - self.assertIs(pool._session, session) self.assertEqual(batch.committed, now) expected_txn_options = TransactionOptions(read_write={}) @@ -2260,18 +2135,14 @@ def test_context_mgr_w_aborted_commit_status(self): database.log_commit_stats = True api = database.spanner_api = self._make_spanner_client() api.commit.side_effect = Aborted("aborted exception", errors=("Aborted error")) - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database, timeout_secs=0.1, default_retry_delay=0) with self.assertRaises(Aborted): with checkout as batch: - self.assertIsNone(pool._session) - self.assertIsInstance(batch, Batch) - self.assertIs(batch._session, session) + self.assertIsInstance(batch, Batch) + self.assertIs(batch._session, database._default_session) - self.assertIs(pool._session, session) expected_txn_options = TransactionOptions(read_write={}) @@ -2301,9 +2172,7 @@ def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database) class Testing(Exception): @@ -2311,12 +2180,10 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as batch: - self.assertIsNone(pool._session) self.assertIsInstance(batch, Batch) - self.assertIs(batch._session, session) + self.assertIs(batch._session, 
database._default_session) raise Testing() - self.assertIs(pool._session, session) self.assertIsNone(batch.committed) @@ -2331,21 +2198,17 @@ def test_ctor_defaults(self): database = _Database(self.DATABASE_NAME) session = _Session(database) - pool = database._pool = _Pool() - pool.put(session) - + checkout = self._make_one(database) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {}) with checkout as snapshot: - self.assertIsNone(pool._session) self.assertIsInstance(snapshot, Snapshot) - self.assertIs(snapshot._session, session) + self.assertIs(snapshot._session, database._default_session) self.assertTrue(snapshot._strong) self.assertFalse(snapshot._multi_use) - self.assertIs(pool._session, session) def test_ctor_w_read_timestamp_and_multi_use(self): import datetime @@ -2355,29 +2218,23 @@ def test_ctor_w_read_timestamp_and_multi_use(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) database = _Database(self.DATABASE_NAME) session = _Session(database) - pool = database._pool = _Pool() - pool.put(session) - + checkout = self._make_one(database, read_timestamp=now, multi_use=True) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {"read_timestamp": now, "multi_use": True}) with checkout as snapshot: - self.assertIsNone(pool._session) self.assertIsInstance(snapshot, Snapshot) - self.assertIs(snapshot._session, session) + self.assertIs(snapshot._session, database._default_session) self.assertEqual(snapshot._read_timestamp, now) self.assertTrue(snapshot._multi_use) - self.assertIs(pool._session, session) def test_context_mgr_failure(self): from google.cloud.spanner_v1.snapshot import Snapshot database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database) class Testing(Exception): @@ -2385,72 +2242,31 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as snapshot: - self.assertIsNone(pool._session) self.assertIsInstance(snapshot, Snapshot) - self.assertIs(snapshot._session, session) + self.assertIs(snapshot._session, database._default_session) raise Testing() - self.assertIs(pool._session, session) - - def test_context_mgr_session_not_found_error(self): - from google.cloud.exceptions import NotFound - - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=False) - pool = database._pool = _Pool() - new_session = _Session(database, name="session-2") - new_session.create = mock.MagicMock(return_value=[]) - pool._new_session = mock.MagicMock(return_value=new_session) - - pool.put(session) - checkout = self._make_one(database) - - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Session not found") - # Assert that session-1 was removed from pool and new session was added. 
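# The rewritten checkout tests above all reduce to one invariant: a checkout
# hands out the database's single multiplexed session and gives that same
# session back to the sessions manager on exit. A sketch of that invariant as
# a reusable helper (illustrative only; it assumes the mock-backed _Database
# double defined later in this file, where sessions_manager.put_session is a
# mock.Mock and _default_session is the multiplexed session it always vends):

def assert_checkout_uses_default_session(checkout, database):
    with checkout as obj:
        # Every checkout object is bound to the one multiplexed session.
        assert obj._session is database._default_session
    # On exit the same session goes back to the manager, never a new one.
    database.sessions_manager.put_session.assert_called_with(
        database._default_session
    )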
- self.assertEqual(pool._session, new_session) - def test_context_mgr_table_not_found_error(self): from google.cloud.exceptions import NotFound database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=True) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - - pool.put(session) checkout = self._make_one(database) - self.assertEqual(pool._session, session) + # NotFound errors are propagated (multiplexed sessions don't have pool fallback) with self.assertRaises(NotFound): with checkout as _: raise NotFound("Table not found") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() def test_context_mgr_unknown_error(self): database = _Database(self.DATABASE_NAME) - session = _Session(database) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - pool.put(session) checkout = self._make_one(database) class Testing(Exception): pass - self.assertEqual(pool._session, session) with self.assertRaises(Testing): with checkout as _: raise Testing("Unknown error.") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() class TestBatchSnapshot(_BaseTest): @@ -3328,18 +3144,14 @@ def test_ctor(self): from google.cloud.spanner_v1.batch import MutationGroups database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database) self.assertIs(checkout._database, database) with checkout as groups: - self.assertIsNone(pool._session) self.assertIsInstance(groups, MutationGroups) - self.assertIs(groups._session, session) + self.assertIs(groups._session, database._default_session) - self.assertIs(pool._session, session) def test_context_mgr_success(self): import datetime @@ -3361,9 +3173,7 @@ def test_context_mgr_success(self): database = _Database(self.DATABASE_NAME) api = database.spanner_api = self._make_spanner_client() api.batch_write.return_value = [response] - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database) request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) @@ -3385,15 +3195,13 @@ def test_context_mgr_success(self): request_options=request_options, ) with checkout as groups: - self.assertIsNone(pool._session) self.assertIsInstance(groups, MutationGroups) - self.assertIs(groups._session, session) + self.assertIs(groups._session, database._default_session) group = groups.group() group.insert("table", ["col"], [["val"]]) groups.batch_write(request_options) self.assertEqual(groups.committed, True) - self.assertIs(pool._session, session) api.batch_write.assert_called_once_with( request=request, @@ -3411,9 +3219,7 @@ def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import MutationGroups database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) checkout = self._make_one(database) class Testing(Exception): @@ -3421,72 +3227,31 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as groups: - self.assertIsNone(pool._session) self.assertIsInstance(groups, MutationGroups) - self.assertIs(groups._session, session) + self.assertIs(groups._session, database._default_session) raise Testing() - self.assertIs(pool._session, 
session) - - def test_context_mgr_session_not_found_error(self): - from google.cloud.exceptions import NotFound - - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=False) - pool = database._pool = _Pool() - new_session = _Session(database, name="session-2") - new_session.create = mock.MagicMock(return_value=[]) - pool._new_session = mock.MagicMock(return_value=new_session) - - pool.put(session) - checkout = self._make_one(database) - - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Session not found") - # Assert that session-1 was removed from pool and new session was added. - self.assertEqual(pool._session, new_session) - def test_context_mgr_table_not_found_error(self): from google.cloud.exceptions import NotFound database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=True) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - - pool.put(session) checkout = self._make_one(database) - self.assertEqual(pool._session, session) + # NotFound errors are propagated (multiplexed sessions don't have pool fallback) with self.assertRaises(NotFound): with checkout as _: raise NotFound("Table not found") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() def test_context_mgr_unknown_error(self): database = _Database(self.DATABASE_NAME) - session = _Session(database) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - pool.put(session) checkout = self._make_one(database) class Testing(Exception): pass - self.assertEqual(pool._session, session) with self.assertRaises(Testing): with checkout as _: raise Testing("Unknown error.") - # Assert that session-1 was not removed from pool. 
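# With the pool gone, the error-path tests above and below collapse to a
# single contract: whatever is raised inside the checkout propagates
# unchanged, and no replacement session is requested. A compact sketch of
# that contract (illustrative; `database` is again the mock-backed _Database
# double defined later in this file):

from google.cloud.exceptions import NotFound

def assert_error_propagates(checkout, database, error_type=NotFound):
    try:
        with checkout:
            raise error_type("raised inside the checkout")
    except error_type:
        pass  # the original error escaped unchanged
    else:
        raise AssertionError("error should have propagated")
    # Exactly one session was checked out; no fallback session was created.
    database.sessions_manager.get_session.assert_called_once()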
- self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() def _make_instance_api(): @@ -3594,17 +3359,13 @@ def __init__(self, name, instance=None): # Mock sessions manager for multiplexed sessions support self._sessions_manager = mock.Mock() - # Configure get_session to return sessions from the pool + # Create a default session for the sessions manager to return + self._default_session = _Session(self) + self._default_session.is_multiplexed = True self._sessions_manager.get_session = mock.Mock( - side_effect=lambda tx_type: self._pool.get() - if hasattr(self, "_pool") and self._pool - else None - ) - self._sessions_manager.put_session = mock.Mock( - side_effect=lambda session: self._pool.put(session) - if hasattr(self, "_pool") and self._pool - else None + return_value=self._default_session ) + self._sessions_manager.put_session = mock.Mock() @property def sessions_manager(self): @@ -3636,20 +3397,6 @@ def _channel_id(self): return 1 -class _Pool(object): - _bound = None - - def bind(self, database): - self._bound = database - - def get(self): - session, self._session = self._session, None - return session - - def put(self, session): - self._session = session - - class _Session(object): _rows = () _created = False diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index c6156b5e8c..a41b461986 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from datetime import timedelta -from mock import Mock, patch -from os import environ +from mock import patch from time import time, sleep from typing import Callable from unittest import TestCase @@ -31,27 +30,11 @@ _MAINTENANCE_THREAD_REFRESH_INTERVAL=timedelta(seconds=2), ) class TestDatabaseSessionManager(TestCase): - @classmethod - def setUpClass(cls): - # Save the original environment variables. - cls._original_env = dict(environ) - - @classmethod - def tearDownClass(cls): - # Restore environment variables. - environ.clear() - environ.update(cls._original_env) - def setUp(self): # Build session manager. database = build_database() self._manager = database._sessions_manager - # Mock the session pool. - pool = self._manager._pool - pool.get = Mock(wraps=pool.get) - pool.put = Mock(wraps=pool.put) - def tearDown(self): # If the maintenance thread is still alive, set the event and wait # for the thread to terminate. We need to do this to ensure that the @@ -63,26 +46,8 @@ def tearDown(self): manager._multiplexed_session_terminate_event.set() self._assert_true_with_timeout(lambda: not thread.is_alive()) - def test_read_only_pooled(self): - manager = self._manager - pool = manager._pool - - self._disable_multiplexed_sessions() - - # Get session from pool. - session = manager.get_session(TransactionType.READ_ONLY) - self.assertFalse(session.is_multiplexed) - pool.get.assert_called_once() - - # Return session to pool. - manager.put_session(session) - pool.put.assert_called_once_with(session) - def test_read_only_multiplexed(self): manager = self._manager - pool = manager._pool - - self._enable_multiplexed_sessions() # Session is created. session_1 = manager.get_session(TransactionType.READ_ONLY) @@ -94,34 +59,12 @@ def test_read_only_multiplexed(self): self.assertEqual(session_1, session_2) manager.put_session(session_2) - # Verify that pool was not used. 
- pool.get.assert_not_called() - pool.put.assert_not_called() - # Verify logger calls. info = manager._database.logger.info info.assert_called_once_with("Created multiplexed session.") - def test_partitioned_pooled(self): - manager = self._manager - pool = manager._pool - - self._disable_multiplexed_sessions() - - # Get session from pool. - session = manager.get_session(TransactionType.PARTITIONED) - self.assertFalse(session.is_multiplexed) - pool.get.assert_called_once() - - # Return session to pool. - manager.put_session(session) - pool.put.assert_called_once_with(session) - def test_partitioned_multiplexed(self): manager = self._manager - pool = manager._pool - - self._enable_multiplexed_sessions() # Session is created. session_1 = manager.get_session(TransactionType.PARTITIONED) @@ -133,34 +76,12 @@ def test_partitioned_multiplexed(self): self.assertEqual(session_1, session_2) manager.put_session(session_2) - # Verify that pool was not used. - pool.get.assert_not_called() - pool.put.assert_not_called() - # Verify logger calls. info = manager._database.logger.info info.assert_called_once_with("Created multiplexed session.") - def test_read_write_pooled(self): - manager = self._manager - pool = manager._pool - - self._disable_multiplexed_sessions() - - # Get session from pool. - session = manager.get_session(TransactionType.READ_WRITE) - self.assertFalse(session.is_multiplexed) - pool.get.assert_called_once() - - # Return session to pool. - manager.put_session(session) - pool.put.assert_called_once_with(session) - def test_read_write_multiplexed(self): manager = self._manager - pool = manager._pool - - self._enable_multiplexed_sessions() # Session is created. session_1 = manager.get_session(TransactionType.READ_WRITE) @@ -172,17 +93,12 @@ def test_read_write_multiplexed(self): self.assertEqual(session_1, session_2) manager.put_session(session_2) - # Verify that pool was not used. - pool.get.assert_not_called() - pool.put.assert_not_called() - # Verify logger calls. info = manager._database.logger.info info.assert_called_once_with("Created multiplexed session.") def test_multiplexed_maintenance(self): manager = self._manager - self._enable_multiplexed_sessions() # Maintenance thread is started. 
session_1 = manager.get_session(TransactionType.READ_ONLY) @@ -219,75 +135,6 @@ def test_exception_failed_precondition(self): with self.assertRaises(FailedPrecondition): manager.get_session(TransactionType.READ_ONLY) - def test__use_multiplexed_read_only(self): - transaction_type = TransactionType.READ_ONLY - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false" - self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" - self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - def test__use_multiplexed_partitioned(self): - transaction_type = TransactionType.PARTITIONED - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "false" - self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "true" - self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - # Test default behavior (should be enabled) - del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] - self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - def test__use_multiplexed_read_write(self): - transaction_type = TransactionType.READ_WRITE - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "false" - self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "true" - self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - # Test default behavior (should be enabled) - del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] - self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) - - def test__use_multiplexed_unsupported_transaction_type(self): - unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE" - - with self.assertRaises(ValueError): - DatabaseSessionsManager._use_multiplexed(unsupported_type) - - def test__getenv(self): - true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] - for value in true_values: - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value - self.assertTrue( - DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) - ) - - false_values = ["false", "False", "FALSE", " false "] - for value in false_values: - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value - self.assertFalse( - DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) - ) - - # Test that empty string and "0" are now treated as true (default enabled) - default_true_values = ["", "0", "anything", "random"] - for value in default_true_values: - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value - self.assertTrue( - DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) - ) - - del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] - self.assertTrue( - DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) - ) - def _assert_true_with_timeout(self, condition: Callable) -> None: """Asserts that the given condition is met within a timeout period. 
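With the environment-variable toggles deleted, the remaining tests pin down a single lifecycle: the manager lazily creates one multiplexed session on first use, hands the same instance to every subsequent caller, and keeps it fresh from the background maintenance thread. A minimal sketch of that contract, mirroring the multiplexed tests above (illustrative; `build_database` is the same test builder these tests use in `setUp`, and `TransactionType` is assumed importable from `database_sessions_manager` as in the production code):

from tests._builders import build_database
from google.cloud.spanner_v1.database_sessions_manager import TransactionType

database = build_database()
manager = database._sessions_manager

# First use lazily creates the multiplexed session (and starts the
# maintenance thread that keeps it refreshed in the background)...
session_1 = manager.get_session(TransactionType.READ_ONLY)
assert session_1.is_multiplexed

# ...and later calls reuse that session instead of drawing from a pool.
session_2 = manager.get_session(TransactionType.READ_ONLY)
assert session_1 == session_2

manager.put_session(session_1)
manager.put_session(session_2)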
@@ -303,19 +150,3 @@ def _assert_true_with_timeout(self, condition: Callable) -> None: sleep(sleep_seconds) self.assertTrue(condition()) - - @staticmethod - def _disable_multiplexed_sessions() -> None: - """Sets environment variables to disable multiplexed sessions for all transactions types.""" - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false" - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "false" - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "false" - - @staticmethod - def _enable_multiplexed_sessions() -> None: - """Sets environment variables to enable multiplexed sessions for all transaction types.""" - - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "true" - environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "true" diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index f3bf6726c0..2a75179dce 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -529,7 +529,6 @@ def test_delete_success(self): def test_database_factory_defaults(self): from google.cloud.spanner_v1.database import Database - from google.cloud.spanner_v1.pool import BurstyPool client = _Client(self.PROJECT) instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) @@ -541,10 +540,7 @@ def test_database_factory_defaults(self): self.assertEqual(database.database_id, DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIsInstance(database._pool, BurstyPool) self.assertIsNone(database._logger) - pool = database._pool - self.assertIs(pool._database, database) self.assertIsNone(database.database_role) def test_database_factory_explicit(self): @@ -556,7 +552,6 @@ def test_database_factory_explicit(self): instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) DATABASE_ID = "database-id" DATABASE_ROLE = "dummy-role" - pool = _Pool() logger = mock.create_autospec(Logger, instance=True) encryption_config = {"kms_key_name": "kms_key_name"} proto_descriptors = b"" @@ -564,7 +559,6 @@ def test_database_factory_explicit(self): database = instance.database( DATABASE_ID, ddl_statements=DDL_STATEMENTS, - pool=pool, logger=logger, encryption_config=encryption_config, database_role=DATABASE_ROLE, @@ -575,9 +569,7 @@ def test_database_factory_explicit(self): self.assertEqual(database.database_id, DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) - self.assertIs(database._pool, pool) self.assertIs(database._logger, logger) - self.assertIs(pool._bound, database) self.assertIs(database._encryption_config, encryption_config) self.assertIs(database.database_role, DATABASE_ROLE) self.assertIs(database._proto_descriptors, proto_descriptors) @@ -1088,10 +1080,3 @@ def delete_instance(self, name, metadata=None): class _FauxOperationFuture(object): pass - - -class _Pool(object): - _bound = None - - def bind(self, database): - self._bound = database diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py deleted file mode 100644 index ec03e4350b..0000000000 --- a/tests/unit/test_pool.py +++ /dev/null @@ -1,1485 +0,0 @@ -# Copyright 2016 Google LLC All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from functools import total_ordering -import time -import unittest -from datetime import datetime, timedelta - -import mock -from google.cloud.spanner_v1._helpers import ( - _metadata_with_request_id, - AtomicCounter, -) -from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID - -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call -from tests._builders import build_database -from tests._helpers import ( - OpenTelemetryBase, - LIB_VERSION, - StatusCode, - enrich_with_otel_scope, - HAS_OPENTELEMETRY_INSTALLED, -) - - -def _make_database(name="name"): - from google.cloud.spanner_v1.database import Database - - return mock.create_autospec(Database, instance=True) - - -def _make_session(): - from google.cloud.spanner_v1.database import Session - - return mock.create_autospec(Session, instance=True) - - -class TestAbstractSessionPool(unittest.TestCase): - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import AbstractSessionPool - - return AbstractSessionPool - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_defaults(self): - pool = self._make_one() - self.assertIsNone(pool._database) - self.assertEqual(pool.labels, {}) - self.assertIsNone(pool.database_role) - - def test_ctor_explicit(self): - labels = {"foo": "bar"} - database_role = "dummy-role" - pool = self._make_one(labels=labels, database_role=database_role) - self.assertIsNone(pool._database) - self.assertEqual(pool.labels, labels) - self.assertEqual(pool.database_role, database_role) - - def test_bind_abstract(self): - pool = self._make_one() - database = _make_database("name") - with self.assertRaises(NotImplementedError): - pool.bind(database) - - def test_get_abstract(self): - pool = self._make_one() - with self.assertRaises(NotImplementedError): - pool.get() - - def test_put_abstract(self): - pool = self._make_one() - session = object() - with self.assertRaises(NotImplementedError): - pool.put(session) - - def test_clear_abstract(self): - pool = self._make_one() - with self.assertRaises(NotImplementedError): - pool.clear() - - def test__new_session_wo_labels(self): - pool = self._make_one() - database = pool._database = build_database() - - new_session = pool._new_session() - - self.assertEqual(new_session._database, database) - self.assertEqual(new_session.labels, {}) - self.assertIsNone(new_session.database_role) - - def test__new_session_w_labels(self): - labels = {"foo": "bar"} - pool = self._make_one(labels=labels) - database = pool._database = build_database() - - new_session = pool._new_session() - - self.assertEqual(new_session._database, database) - self.assertEqual(new_session.labels, labels) - self.assertIsNone(new_session.database_role) - - def test__new_session_w_database_role(self): - database_role = "dummy-role" - pool = self._make_one(database_role=database_role) - database = pool._database = build_database() - - new_session = pool._new_session() - - self.assertEqual(new_session._database, database) - self.assertEqual(new_session.labels, {}) - self.assertEqual(new_session.database_role, 
database_role) - - def test_session_wo_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - pool = self._make_one() - checkout = pool.session() - self.assertIsInstance(checkout, SessionCheckout) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {}) - - def test_session_w_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - pool = self._make_one() - checkout = pool.session(foo="bar") - self.assertIsInstance(checkout, SessionCheckout) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {"foo": "bar"}) - - -class TestFixedSizePool(OpenTelemetryBase): - BASE_ATTRIBUTES = { - "db.type": "spanner", - "db.url": "spanner.googleapis.com", - "db.instance": "name", - "net.host.name": "spanner.googleapis.com", - "gcp.client.service": "spanner", - "gcp.client.version": LIB_VERSION, - "gcp.client.repo": "googleapis/python-spanner", - "cloud.region": "global", - } - enrich_with_otel_scope(BASE_ATTRIBUTES) - - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import FixedSizePool - - return FixedSizePool - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_defaults(self): - pool = self._make_one() - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertTrue(pool._sessions.empty()) - self.assertEqual(pool.labels, {}) - self.assertIsNone(pool.database_role) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ctor_explicit(self, mock_region): - labels = {"foo": "bar"} - database_role = "dummy-role" - pool = self._make_one( - size=4, default_timeout=30, labels=labels, database_role=database_role - ) - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 4) - self.assertEqual(pool.default_timeout, 30) - self.assertTrue(pool._sessions.empty()) - self.assertEqual(pool.labels, labels) - self.assertEqual(pool.database_role, database_role) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_bind(self, mock_region): - database_role = "dummy-role" - pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._database_role = database_role - database._sessions.extend(SESSIONS) - - pool.bind(database) - - self.assertIs(pool._database, database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.database_role, database_role) - self.assertEqual(pool.default_timeout, 10) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_active(self, mock_region): - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - - # check if sessions returned in LIFO order - for i in (3, 2, 1, 0): - session = pool.get() - self.assertIs(session, SESSIONS[i]) - self.assertFalse(session._exists_checked) - self.assertFalse(pool._sessions.full()) - - @mock.patch( - 
"google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_non_expired(self, mock_region): - pool = self._make_one(size=4) - database = _Database("name") - last_use_time = datetime.utcnow() - timedelta(minutes=56) - SESSIONS = sorted( - [_Session(database, last_use_time=last_use_time) for i in range(0, 4)] - ) - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - - # check if sessions returned in LIFO order - for i in (3, 2, 1, 0): - session = pool.get() - self.assertIs(session, SESSIONS[i]) - self.assertTrue(session._exists_checked) - self.assertFalse(pool._sessions.full()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_bind_get(self, mock_region): - if not HAS_OPENTELEMETRY_INSTALLED: - return - - # This tests retrieving 1 out of 4 sessions from the session pool. - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - database._sessions.extend(SESSIONS) - pool.bind(database) - - with trace_call("pool.Get", SESSIONS[0]): - pool.get() - - span_list = self.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = ["CloudSpanner.FixedPool.BatchCreateSessions", "pool.Get"] - assert got_span_names == want_span_names - - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id - 1}.{database._channel_id}.{_Database.NTH_REQUEST.value}.1" - attrs = dict( - TestFixedSizePool.BASE_ATTRIBUTES.copy(), x_goog_spanner_request_id=req_id - ) - - # Check for the overall spans. - self.assertSpanAttributes( - "CloudSpanner.FixedPool.BatchCreateSessions", - status=StatusCode.OK, - attributes=attrs, - span=span_list[0], - ) - - self.assertSpanAttributes( - "pool.Get", - status=StatusCode.OK, - attributes=TestFixedSizePool.BASE_ATTRIBUTES, - span=span_list[-1], - ) - wantEventNames = [ - "Acquiring session", - "Waiting for a session to become available", - "Acquired session", - ] - self.assertSpanEvents("pool.Get", wantEventNames, span_list[-1]) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_bind_get_empty_pool(self, mock_region): - if not HAS_OPENTELEMETRY_INSTALLED: - return - - # Tests trying to invoke pool.get() from an empty pool. - pool = self._make_one(size=0, default_timeout=0.1) - database = _Database("name") - session1 = _Session(database) - with trace_call("pool.Get", session1): - try: - pool.bind(database) - database._sessions = database._sessions[:0] - pool.get() - except Exception: - pass - - wantEventNames = [ - "Invalid session pool size(0) <= 0", - "Acquiring session", - "Waiting for a session to become available", - "No sessions available in the pool", - ] - self.assertSpanEvents("pool.Get", wantEventNames) - - # Check for the overall spans too. 
- self.assertSpanNames(["pool.Get"]) - self.assertSpanAttributes( - "pool.Get", - attributes=TestFixedSizePool.BASE_ATTRIBUTES, - ) - - span_list = self.get_finished_spans() - got_all_events = [] - for span in span_list: - for event in span.events: - got_all_events.append((event.name, event.attributes)) - want_all_events = [ - ("Invalid session pool size(0) <= 0", {"kind": "FixedSizePool"}), - ("Acquiring session", {"kind": "FixedSizePool"}), - ("Waiting for a session to become available", {"kind": "FixedSizePool"}), - ("No sessions available in the pool", {"kind": "FixedSizePool"}), - ] - assert got_all_events == want_all_events - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_pool_bind(self, mock_region): - if not HAS_OPENTELEMETRY_INSTALLED: - return - - # Tests the exception generated from invoking pool.bind when - # you have an empty pool. - pool = self._make_one(size=1) - database = _Database("name") - pool._new_session = mock.Mock(side_effect=Exception("test")) - fauxSession = mock.Mock() - setattr(fauxSession, "_database", database) - try: - with trace_call("testBind", fauxSession): - pool.bind(database) - except Exception: - pass - - span_list = self.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = ["testBind", "CloudSpanner.FixedPool.BatchCreateSessions"] - assert got_span_names == want_span_names - - wantEventNames = [ - "Requesting 1 sessions", - "exception", - ] - self.assertSpanEvents("testBind", wantEventNames, span_list[0]) - - self.assertSpanAttributes( - "testBind", - status=StatusCode.ERROR, - attributes=TestFixedSizePool.BASE_ATTRIBUTES, - span=span_list[0], - ) - - got_all_events = [] - - # Some event attributes are noisy/highly ephemeral - # and can't be directly compared against. 
- imprecise_event_attributes = ["exception.stacktrace", "delay_seconds", "cause"] - for span in span_list: - for event in span.events: - evt_attributes = event.attributes.copy() - for attr_name in imprecise_event_attributes: - if attr_name in evt_attributes: - evt_attributes[attr_name] = "EPHEMERAL" - - got_all_events.append((event.name, evt_attributes)) - - want_all_events = [ - ("Requesting 1 sessions", {"kind": "FixedSizePool"}), - ( - "exception", - { - "exception.type": "Exception", - "exception.message": "test", - "exception.stacktrace": "EPHEMERAL", - "exception.escaped": "False", - }, - ), - ("Creating 1 sessions", {"kind": "FixedSizePool"}), - ("Created sessions", {"count": 1}), - ( - "exception", - { - "exception.type": "Exception", - "exception.message": "test", - "exception.stacktrace": "EPHEMERAL", - "exception.escaped": "False", - }, - ), - ] - assert got_all_events == want_all_events - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_expired(self, mock_region): - pool = self._make_one(size=4) - database = _Database("name") - last_use_time = datetime.utcnow() - timedelta(minutes=65) - SESSIONS = [_Session(database, last_use_time=last_use_time)] * 5 - SESSIONS[0]._exists = False - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - - session = pool.get() - - self.assertIs(session, SESSIONS[4]) - session.create.assert_called() - self.assertTrue(SESSIONS[0]._exists_checked) - self.assertFalse(pool._sessions.full()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_empty_default_timeout(self, mock_region): - import queue - - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - - with self.assertRaises(queue.Empty): - pool.get() - - self.assertEqual(session_queue._got, {"block": True, "timeout": 10}) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_empty_explicit_timeout(self, mock_region): - import queue - - pool = self._make_one(size=1, default_timeout=0.1) - session_queue = pool._sessions = _Queue() - - with self.assertRaises(queue.Empty): - pool.get(timeout=1) - - self.assertEqual(session_queue._got, {"block": True, "timeout": 1}) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_full(self, mock_region): - import queue - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) - pool.bind(database) - self.reset() - - with self.assertRaises(queue.Full): - pool.put(_Session(database)) - - self.assertTrue(pool._sessions.full()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_non_full(self, mock_region): - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) - pool.bind(database) - pool._sessions.get() - - pool.put(_Session(database)) - - self.assertTrue(pool._sessions.full()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_clear(self, mock_region): - pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - pool._new_session = 
mock.Mock(side_effect=SESSIONS) - pool.bind(database) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - - pool.clear() - - for session in SESSIONS: - self.assertTrue(session._deleted) - - -class TestBurstyPool(OpenTelemetryBase): - BASE_ATTRIBUTES = { - "db.type": "spanner", - "db.url": "spanner.googleapis.com", - "db.instance": "name", - "net.host.name": "spanner.googleapis.com", - "gcp.client.service": "spanner", - "gcp.client.version": LIB_VERSION, - "gcp.client.repo": "googleapis/python-spanner", - "cloud.region": "global", - } - enrich_with_otel_scope(BASE_ATTRIBUTES) - - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import BurstyPool - - return BurstyPool - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_defaults(self): - pool = self._make_one() - self.assertIsNone(pool._database) - self.assertEqual(pool.target_size, 10) - self.assertTrue(pool._sessions.empty()) - self.assertEqual(pool.labels, {}) - self.assertIsNone(pool.database_role) - - def test_ctor_explicit(self): - labels = {"foo": "bar"} - database_role = "dummy-role" - pool = self._make_one(target_size=4, labels=labels, database_role=database_role) - self.assertIsNone(pool._database) - self.assertEqual(pool.target_size, 4) - self.assertTrue(pool._sessions.empty()) - self.assertEqual(pool.labels, labels) - self.assertEqual(pool.database_role, database_role) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ctor_explicit_w_database_role_in_db(self, mock_region): - database_role = "dummy-role" - pool = self._make_one() - database = pool._database = _Database("name") - database._database_role = database_role - pool.bind(database) - self.assertEqual(pool.database_role, database_role) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_empty(self, mock_region): - pool = self._make_one() - database = _Database("name") - pool._new_session = mock.Mock(return_value=_Session(database)) - pool.bind(database) - - session = pool.get() - - self.assertIsInstance(session, _Session) - self.assertIs(session._database, database) - session.create.assert_called() - self.assertTrue(pool._sessions.empty()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_get_empty_pool(self, mock_region): - if not HAS_OPENTELEMETRY_INSTALLED: - return - - # This scenario tests a pool that hasn't been filled up - # and pool.get() acquires from a pool, waiting for a session - # to become available. 
- pool = self._make_one() - database = _Database("name") - session1 = _Session(database) - pool._new_session = mock.Mock(return_value=session1) - pool.bind(database) - - with trace_call("pool.Get", session1): - session = pool.get() - self.assertIsInstance(session, _Session) - self.assertIs(session._database, database) - session.create.assert_called() - self.assertTrue(pool._sessions.empty()) - - span_list = self.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = ["pool.Get"] - assert got_span_names == want_span_names - - create_span = span_list[-1] - self.assertSpanAttributes( - "pool.Get", - attributes=TestBurstyPool.BASE_ATTRIBUTES, - span=create_span, - ) - wantEventNames = [ - "Acquiring session", - "Waiting for a session to become available", - "No sessions available in pool. Creating session", - ] - self.assertSpanEvents("pool.Get", wantEventNames, span=create_span) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_non_empty_session_exists(self, mock_region): - pool = self._make_one() - database = _Database("name") - previous = _Session(database) - pool.bind(database) - pool.put(previous) - - session = pool.get() - - self.assertIs(session, previous) - session.create.assert_not_called() - self.assertTrue(session._exists_checked) - self.assertTrue(pool._sessions.empty()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_get_non_empty_session_exists(self, mock_region): - # Tests the spans produces when you invoke pool.bind - # and then insert a session into the pool. - pool = self._make_one() - database = _Database("name") - previous = _Session(database) - pool.bind(database) - with trace_call("pool.Get", previous): - pool.put(previous) - session = pool.get() - self.assertIs(session, previous) - session.create.assert_not_called() - self.assertTrue(session._exists_checked) - self.assertTrue(pool._sessions.empty()) - - self.assertSpanAttributes( - "pool.Get", - attributes=TestBurstyPool.BASE_ATTRIBUTES, - ) - self.assertSpanEvents( - "pool.Get", - ["Acquiring session", "Waiting for a session to become available"], - ) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_non_empty_session_expired(self, mock_region): - pool = self._make_one() - database = _Database("name") - previous = _Session(database, exists=False) - newborn = _Session(database) - pool._new_session = mock.Mock(return_value=newborn) - pool.bind(database) - pool.put(previous) - - session = pool.get() - - self.assertTrue(previous._exists_checked) - self.assertIs(session, newborn) - session.create.assert_called() - self.assertFalse(session._exists_checked) - self.assertTrue(pool._sessions.empty()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_empty(self, mock_region): - pool = self._make_one() - database = _Database("name") - pool.bind(database) - session = _Session(database) - - pool.put(session) - - self.assertFalse(pool._sessions.empty()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_put_empty(self, mock_region): - # Tests the spans produced when you put sessions into an empty pool. 
- pool = self._make_one() - database = _Database("name") - pool.bind(database) - session = _Session(database) - - with trace_call("pool.put", session): - pool.put(session) - self.assertFalse(pool._sessions.empty()) - - self.assertSpanAttributes( - "pool.put", - attributes=TestBurstyPool.BASE_ATTRIBUTES, - ) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_full(self, mock_region): - pool = self._make_one(target_size=1) - database = _Database("name") - pool.bind(database) - older = _Session(database) - pool.put(older) - self.assertFalse(pool._sessions.empty()) - - younger = _Session(database) - pool.put(younger) # discarded silently - - self.assertTrue(younger._deleted) - self.assertIs(pool.get(), older) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_put_full(self, mock_region): - # This scenario tests the spans produced from putting an older - # session into a pool that is already full. - pool = self._make_one(target_size=1) - database = _Database("name") - pool.bind(database) - older = _Session(database) - with trace_call("pool.put", older): - pool.put(older) - self.assertFalse(pool._sessions.empty()) - - younger = _Session(database) - pool.put(younger) # discarded silently - - self.assertTrue(younger._deleted) - self.assertIs(pool.get(), older) - - self.assertSpanAttributes( - "pool.put", - attributes=TestBurstyPool.BASE_ATTRIBUTES, - ) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_full_expired(self, mock_region): - pool = self._make_one(target_size=1) - database = _Database("name") - pool.bind(database) - older = _Session(database) - pool.put(older) - self.assertFalse(pool._sessions.empty()) - - younger = _Session(database, exists=False) - pool.put(younger) # discarded silently - - self.assertTrue(younger._deleted) - self.assertIs(pool.get(), older) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_clear(self, mock_region): - pool = self._make_one() - database = _Database("name") - pool.bind(database) - previous = _Session(database) - pool.put(previous) - - pool.clear() - - self.assertTrue(previous._deleted) - self.assertNoSpans() - - -class TestPingingPool(OpenTelemetryBase): - BASE_ATTRIBUTES = { - "db.type": "spanner", - "db.url": "spanner.googleapis.com", - "db.instance": "name", - "net.host.name": "spanner.googleapis.com", - "gcp.client.service": "spanner", - "gcp.client.version": LIB_VERSION, - "gcp.client.repo": "googleapis/python-spanner", - "cloud.region": "global", - } - enrich_with_otel_scope(BASE_ATTRIBUTES) - - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import PingingPool - - return PingingPool - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_defaults(self): - pool = self._make_one() - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.empty()) - self.assertEqual(pool.labels, {}) - self.assertIsNone(pool.database_role) - - def test_ctor_explicit(self): - labels = {"foo": "bar"} - database_role = "dummy-role" - pool = self._make_one( - size=4, - default_timeout=30, - ping_interval=1800, - labels=labels, - database_role=database_role, - 
) - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 4) - self.assertEqual(pool.default_timeout, 30) - self.assertEqual(pool._delta.seconds, 1800) - self.assertTrue(pool._sessions.empty()) - self.assertEqual(pool.labels, labels) - self.assertEqual(pool.database_role, database_role) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ctor_explicit_w_database_role_in_db(self, mock_region): - database_role = "dummy-role" - pool = self._make_one() - database = pool._database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) - database._database_role = database_role - pool.bind(database) - self.assertEqual(pool.database_role, database_role) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_bind(self, mock_region): - pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) - pool.bind(database) - - self.assertIs(pool._database, database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_hit_no_ping(self, mock_region): - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - self.reset() - - session = pool.get() - - self.assertIs(session, SESSIONS[0]) - self.assertFalse(session._exists_checked) - self.assertFalse(pool._sessions.full()) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_hit_w_ping(self, mock_region): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - pool._new_session = mock.Mock(side_effect=SESSIONS) - - sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) - - with _Monkey(MUT, _NOW=lambda: sessions_created): - pool.bind(database) - - self.reset() - - session = pool.get() - - self.assertIs(session, SESSIONS[0]) - self.assertTrue(session._exists_checked) - self.assertFalse(pool._sessions.full()) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_hit_w_ping_expired(self, mock_region): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 5 - SESSIONS[0]._exists = False - pool._new_session = mock.Mock(side_effect=SESSIONS) - - sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) - - with _Monkey(MUT, _NOW=lambda: sessions_created): - pool.bind(database) - self.reset() - - session = pool.get() - - self.assertIs(session, SESSIONS[4]) - session.create.assert_called() - self.assertTrue(SESSIONS[0]._exists_checked) 
- self.assertFalse(pool._sessions.full()) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_empty_default_timeout(self, mock_region): - import queue - - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - - with self.assertRaises(queue.Empty): - pool.get() - - self.assertEqual(session_queue._got, {"block": True, "timeout": 10}) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_get_empty_explicit_timeout(self, mock_region): - import queue - - pool = self._make_one(size=1, default_timeout=0.1) - session_queue = pool._sessions = _Queue() - - with self.assertRaises(queue.Empty): - pool.get(timeout=1) - - self.assertEqual(session_queue._got, {"block": True, "timeout": 1}) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_full(self, mock_region): - import queue - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) - pool.bind(database) - - with self.assertRaises(queue.Full): - pool.put(_Session(database)) - - self.assertTrue(pool._sessions.full()) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_put_full(self, mock_region): - if not HAS_OPENTELEMETRY_INSTALLED: - return - - import queue - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) - pool.bind(database) - - with self.assertRaises(queue.Full): - pool.put(_Session(database)) - - self.assertTrue(pool._sessions.full()) - - span_list = self.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = ["CloudSpanner.PingingPool.BatchCreateSessions"] - assert got_span_names == want_span_names - - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id - 1}.{database._channel_id}.{_Database.NTH_REQUEST.value}.1" - attrs = dict( - TestPingingPool.BASE_ATTRIBUTES.copy(), x_goog_spanner_request_id=req_id - ) - self.assertSpanAttributes( - "CloudSpanner.PingingPool.BatchCreateSessions", - attributes=attrs, - span=span_list[-1], - ) - wantEventNames = [ - "Created 2 sessions", - "Created 2 sessions", - "Requested for 4 sessions, returned 4", - ] - self.assertSpanEvents( - "CloudSpanner.PingingPool.BatchCreateSessions", wantEventNames - ) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_put_non_full(self, mock_region): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - - now = datetime.datetime.utcnow() - database = _Database("name") - session = _Session(database) - - with _Monkey(MUT, _NOW=lambda: now): - pool.put(session) - - self.assertEqual(len(session_queue._items), 1) - ping_after, queued = session_queue._items[0] - self.assertEqual(ping_after, now + datetime.timedelta(seconds=3000)) - self.assertIs(queued, session) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_clear(self, mock_region): - pool = self._make_one() - 
database = _Database("name") - SESSIONS = [_Session(database)] * 10 - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - self.reset() - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - - pool.clear() - - for session in SESSIONS: - self.assertTrue(session._deleted) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ping_empty(self, mock_region): - pool = self._make_one(size=1) - pool.ping() # Does not raise 'Empty' - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ping_oldest_fresh(self, mock_region): - pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 1 - database._sessions.extend(SESSIONS) - pool.bind(database) - self.reset() - - pool.ping() - - self.assertFalse(SESSIONS[0]._pinged) - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ping_oldest_stale_but_exists(self, mock_region): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 1 - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - - later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) - with _Monkey(MUT, _NOW=lambda: later): - pool.ping() - - self.assertTrue(SESSIONS[0]._pinged) - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_ping_oldest_stale_and_not_exists(self, mock_region): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 2 - SESSIONS[0]._exists = False - pool._new_session = mock.Mock(side_effect=SESSIONS) - pool.bind(database) - self.reset() - - later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) - with _Monkey(MUT, _NOW=lambda: later): - pool.ping() - - self.assertTrue(SESSIONS[0]._pinged) - SESSIONS[1].create.assert_called() - self.assertNoSpans() - - @mock.patch( - "google.cloud.spanner_v1._opentelemetry_tracing._get_cloud_region", - return_value="global", - ) - def test_spans_get_and_leave_empty_pool(self, mock_region): - if not HAS_OPENTELEMETRY_INSTALLED: - return - - # This scenario tests the spans generated from pulling a span - # out the pool and leaving it empty. 
- pool = self._make_one() - database = _Database("name") - session1 = _Session(database) - pool._new_session = mock.Mock(side_effect=[session1, Exception]) - try: - pool.bind(database) - except Exception: - pass - - with trace_call("pool.Get", session1): - session = pool.get() - self.assertIsInstance(session, _Session) - self.assertIs(session._database, database) - # session.create.assert_called() - self.assertTrue(pool._sessions.empty()) - - span_list = self.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = ["CloudSpanner.PingingPool.BatchCreateSessions", "pool.Get"] - assert got_span_names == want_span_names - - self.assertSpanAttributes( - "pool.Get", - attributes=TestPingingPool.BASE_ATTRIBUTES, - span=span_list[-1], - ) - wantEventNames = [ - "Waiting for a session to become available", - "Acquired session", - ] - self.assertSpanEvents("pool.Get", wantEventNames, span_list[-1]) - - -class TestSessionCheckout(unittest.TestCase): - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - return SessionCheckout - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_wo_kwargs(self): - pool = _Pool() - checkout = self._make_one(pool) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {}) - - def test_ctor_w_kwargs(self): - pool = _Pool() - checkout = self._make_one(pool, foo="bar", database_role="dummy-role") - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual( - checkout._kwargs, {"foo": "bar", "database_role": "dummy-role"} - ) - - def test_context_manager_wo_kwargs(self): - session = object() - pool = _Pool(session) - checkout = self._make_one(pool) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - - with checkout as borrowed: - self.assertIs(borrowed, session) - self.assertEqual(len(pool._items), 0) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - self.assertEqual(pool._got, {}) - - def test_context_manager_w_kwargs(self): - session = object() - pool = _Pool(session) - checkout = self._make_one(pool, foo="bar") - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - - with checkout as borrowed: - self.assertIs(borrowed, session) - self.assertEqual(len(pool._items), 0) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - self.assertEqual(pool._got, {"foo": "bar"}) - - -def _make_transaction(*args, **kw): - from google.cloud.spanner_v1.transaction import Transaction - - txn = mock.create_autospec(Transaction)(*args, **kw) - txn.committed = None - txn.rolled_back = False - return txn - - -@total_ordering -class _Session(object): - _transaction = None - - def __init__( - self, database, exists=True, transaction=None, last_use_time=datetime.utcnow() - ): - self._database = database - self._exists = exists - self._exists_checked = False - self._pinged = False - self.create = mock.Mock() - self._deleted = False - self._transaction = transaction - self._last_use_time = last_use_time - # Generate a faux id. 
- self._session_id = f"{time.time()}" - - def __lt__(self, other): - return id(self) < id(other) - - @property - def last_use_time(self): - return self._last_use_time - - def exists(self): - self._exists_checked = True - return self._exists - - def ping(self): - from google.cloud.exceptions import NotFound - - self._pinged = True - if not self._exists: - raise NotFound("expired session") - - def delete(self): - from google.cloud.exceptions import NotFound - - self._deleted = True - if not self._exists: - raise NotFound("unknown session") - - def transaction(self): - txn = self._transaction = _make_transaction(self) - return txn - - @property - def session_id(self): - return self._session_id - - -class _Database(object): - NTH_REQUEST = AtomicCounter() - NTH_CLIENT_ID = AtomicCounter() - - def __init__(self, name): - self.name = name - self._sessions = [] - self._database_role = None - self.database_id = name - self._route_to_leader_enabled = True - - def mock_batch_create_sessions( - request=None, - timeout=10, - metadata=[], - labels={}, - ): - from google.cloud.spanner_v1 import BatchCreateSessionsResponse - from google.cloud.spanner_v1 import Session - - database_role = request.session_template.creator_role if request else None - if request.session_count < 2: - response = BatchCreateSessionsResponse( - session=[Session(creator_role=database_role, labels=labels)] - ) - else: - response = BatchCreateSessionsResponse( - session=[ - Session(creator_role=database_role, labels=labels), - Session(creator_role=database_role, labels=labels), - ] - ) - return response - - from google.cloud.spanner_v1 import SpannerClient - - self.spanner_api = mock.create_autospec(SpannerClient, instance=True) - self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions - - @property - def database_role(self): - """Database role used in sessions to connect to this database. - - :rtype: str - :returns: an str with the name of the database role. 
- """ - return self._database_role - - def session(self, **kwargs): - # always return first session in the list - # to avoid reversing the order of putting - # sessions into pool (important for order tests) - return self._sessions.pop(0) - - @property - def observability_options(self): - return dict(db_name=self.name) - - @property - def _next_nth_request(self): - return self.NTH_REQUEST.increment() - - @property - def _nth_client_id(self): - return self.NTH_CLIENT_ID.increment() - - def metadata_with_request_id( - self, nth_request, nth_attempt, prior_metadata=[], span=None - ): - return _metadata_with_request_id( - self._nth_client_id, - self._channel_id, - nth_request, - nth_attempt, - prior_metadata, - span, - ) - - @property - def _channel_id(self): - return 1 - - -class _Queue(object): - _size = 1 - - def __init__(self, *items): - self._items = list(items) - - def empty(self): - return len(self._items) == 0 - - def full(self): - return len(self._items) >= self._size - - def get(self, **kwargs): - import queue - - self._got = kwargs - try: - return self._items.pop() - except IndexError: - raise queue.Empty() - - def put(self, item, **kwargs): - self._put = kwargs - self._items.append(item) - - def put_nowait(self, item, **kwargs): - self._put_nowait = kwargs - self._items.append(item) - - -class _Pool(_Queue): - _database = None From 4a2fc45487d9c6d1642d331bf1abd425b5af3a78 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Tue, 13 Jan 2026 12:13:36 +0530 Subject: [PATCH 2/5] remove workflow for regular session pool --- ...against-emulator-with-regular-session.yaml | 35 ------------------- .../integration-regular-sessions-enabled.cfg | 22 ------------ 2 files changed, 57 deletions(-) delete mode 100644 .github/workflows/integration-tests-against-emulator-with-regular-session.yaml delete mode 100644 .kokoro/presubmit/integration-regular-sessions-enabled.cfg diff --git a/.github/workflows/integration-tests-against-emulator-with-regular-session.yaml b/.github/workflows/integration-tests-against-emulator-with-regular-session.yaml deleted file mode 100644 index 3f2d3b7ba2..0000000000 --- a/.github/workflows/integration-tests-against-emulator-with-regular-session.yaml +++ /dev/null @@ -1,35 +0,0 @@ -on: - push: - branches: - - main - pull_request: -name: Run Spanner integration tests against emulator with regular sessions -jobs: - system-tests: - runs-on: ubuntu-latest - - services: - emulator: - image: gcr.io/cloud-spanner-emulator/emulator:latest - ports: - - 9010:9010 - - 9020:9020 - - steps: - - name: Checkout code - uses: actions/checkout@v5 - - name: Setup Python - uses: actions/setup-python@v6 - with: - python-version: 3.14 - - name: Install nox - run: python -m pip install nox - - name: Run system tests - run: nox -s system - env: - SPANNER_EMULATOR_HOST: localhost:9010 - GOOGLE_CLOUD_PROJECT: emulator-test-project - GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE: true - GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS: false - GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS: false - GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW: false diff --git a/.kokoro/presubmit/integration-regular-sessions-enabled.cfg b/.kokoro/presubmit/integration-regular-sessions-enabled.cfg deleted file mode 100644 index 439abd4ba5..0000000000 --- a/.kokoro/presubmit/integration-regular-sessions-enabled.cfg +++ /dev/null @@ -1,22 +0,0 @@ -# Format: //devtools/kokoro/config/proto/build.proto - -# Only run a subset of all nox sessions -env_vars: { - key: "NOX_SESSION" - value: "unit-3.9 unit-3.14 system-3.14" 
-} - -env_vars: { - key: "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" - value: "false" -} - -env_vars: { - key: "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" - value: "false" -} - -env_vars: { - key: "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" - value: "false" -} \ No newline at end of file From 941fe1ae2a9f2c7f827d8aa395ecac510d6da668 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Tue, 13 Jan 2026 12:21:53 +0530 Subject: [PATCH 3/5] fix lint --- google/cloud/spanner_v1/database.py | 4 +- .../spanner_v1/database_sessions_manager.py | 1 - tests/_helpers.py | 20 ++--- .../mockserver_tests/mock_server_test_base.py | 58 ++++--------- .../test_request_id_header.py | 72 +++++----------- tests/mockserver_tests/test_tags.py | 16 ++-- tests/system/test_database_api.py | 4 +- tests/unit/test_database.py | 83 +++++-------------- 8 files changed, 75 insertions(+), 183 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index aca653fadb..a8e1ad2a00 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -868,9 +868,7 @@ def session(self, labels=None, database_role=None): # instead. role = database_role or self._database_role # Always use multiplexed sessions - return Session( - self, labels=labels, database_role=role, is_multiplexed=True - ) + return Session(self, labels=labels, database_role=role, is_multiplexed=True) def snapshot(self, **kw): """Return an object which wraps a snapshot. diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index c6843777db..47d4446772 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -210,4 +210,3 @@ def _maintain_multiplexed_session(session_manager_ref) -> None: manager._multiplexed_session = manager._build_multiplexed_session() session_created_time = time() - diff --git a/tests/_helpers.py b/tests/_helpers.py index c7502816da..0d742ce3a2 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,5 +1,4 @@ import unittest -from os import getenv import mock @@ -36,21 +35,12 @@ def is_multiplexed_enabled(transaction_type: TransactionType) -> bool: - """Returns whether multiplexed sessions are enabled for the given transaction type.""" + """Returns whether multiplexed sessions are enabled for the given transaction type. - env_var = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" - env_var_partitioned = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" - env_var_read_write = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" - - def _getenv(val: str) -> bool: - return getenv(val, "true").lower().strip() != "false" - - if transaction_type is TransactionType.READ_ONLY: - return _getenv(env_var) - elif transaction_type is TransactionType.PARTITIONED: - return _getenv(env_var) and _getenv(env_var_partitioned) - else: - return _getenv(env_var) and _getenv(env_var_read_write) + Multiplexed sessions are now always enabled for all transaction types. + This function is kept for backward compatibility with existing tests. 
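+
+    Illustrative doctest (assumes the ``TransactionType`` enum this module
+    already imports; every transaction type now reports True):
+
+        >>> is_multiplexed_enabled(TransactionType.READ_WRITE)
+        True
+        >>> is_multiplexed_enabled(TransactionType.PARTITIONED)
+        True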
+ """ + return True def get_test_ot_exporter(): diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 75455807d6..a8a8c16916 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -41,7 +41,6 @@ SpannerServicer, start_mock_server, ) -from tests._helpers import is_multiplexed_enabled # Creates an aborted status with the smallest possible retry delay. @@ -240,52 +239,30 @@ def assert_requests_sequence( transaction_type, allow_multiple_batch_create=True, ): - """Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries. + """Assert that the requests sequence matches the expected types, accounting for multiplexed sessions. Args: requests: List of requests from spanner_service.requests expected_types: List of expected request types (excluding session creation requests) - transaction_type: TransactionType enum value to check multiplexed session status - allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest + transaction_type: TransactionType enum value (unused, kept for backward compatibility) + allow_multiple_batch_create: If True, skip leading CreateSessionRequest (kept for backward compatibility) """ - from google.cloud.spanner_v1 import ( - BatchCreateSessionsRequest, - CreateSessionRequest, - ) + from google.cloud.spanner_v1 import CreateSessionRequest - mux_enabled = is_multiplexed_enabled(transaction_type) idx = 0 - # Skip all leading BatchCreateSessionsRequest (for retries) + # Skip CreateSessionRequest for multiplexed session if allow_multiple_batch_create: while idx < len(requests) and isinstance( - requests[idx], BatchCreateSessionsRequest - ): - idx += 1 - # For multiplexed, optionally skip a CreateSessionRequest - if ( - mux_enabled - and idx < len(requests) - and isinstance(requests[idx], CreateSessionRequest) + requests[idx], CreateSessionRequest ): idx += 1 else: - if mux_enabled: - self.assertTrue( - isinstance(requests[idx], BatchCreateSessionsRequest), - f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}", - ) - idx += 1 - self.assertTrue( - isinstance(requests[idx], CreateSessionRequest), - f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}", - ) - idx += 1 - else: - self.assertTrue( - isinstance(requests[idx], BatchCreateSessionsRequest), - f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}", - ) - idx += 1 + # Expect exactly one CreateSessionRequest for multiplexed session + self.assertTrue( + isinstance(requests[idx], CreateSessionRequest), + f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}", + ) + idx += 1 # Check the rest of the expected request types for expected_type in expected_types: self.assertTrue( @@ -303,13 +280,12 @@ def adjust_request_id_sequence(self, expected_segments, requests, transaction_ty Args: expected_segments: List of expected (method, (sequence_numbers)) tuples requests: List of actual requests from spanner_service.requests - transaction_type: TransactionType enum value to check multiplexed session status + transaction_type: TransactionType enum value (unused, kept for backward compatibility) Returns: List of adjusted expected segments with corrected sequence numbers """ from google.cloud.spanner_v1 import ( - BatchCreateSessionsRequest, CreateSessionRequest, ExecuteSqlRequest, BeginTransactionRequest, @@ 
-318,15 +294,13 @@ def adjust_request_id_sequence(self, expected_segments, requests, transaction_ty # Count session creation requests that come before the first non-session request session_requests_before = 0 for req in requests: - if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)): + if isinstance(req, CreateSessionRequest): session_requests_before += 1 elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)): break - # For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession) - # For non-multiplexed, we expect 1 session request (BatchCreateSessions) - mux_enabled = is_multiplexed_enabled(transaction_type) - expected_session_requests = 2 if mux_enabled else 1 + # With multiplexed sessions, we expect 1 session request (CreateSession) + expected_session_requests = 1 extra_session_requests = session_requests_before - expected_session_requests # Adjust sequence numbers based on extra session requests diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 055d9d97b5..4cb054277b 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -16,7 +16,6 @@ import threading from google.cloud.spanner_v1 import ( - BatchCreateSessionsRequest, CreateSessionRequest, ExecuteSqlRequest, BeginTransactionRequest, @@ -58,20 +57,17 @@ def test_snapshot_execute_sql(self): NTH_CLIENT = self.database._nth_client_id CHANNEL_ID = self.database._channel_id got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() - # Filter out CreateSessionRequest unary segments for comparison - filtered_unary_segments = [ - seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession") - ] + # With multiplexed sessions, we expect one CreateSession request want_unary_segments = [ ( - "/google.spanner.v1.Spanner/BatchCreateSessions", + "/google.spanner.v1.Spanner/CreateSession", (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1), ) ] # Dynamically determine the expected sequence number for ExecuteStreamingSql session_requests_before = 0 for req in requests: - if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)): + if isinstance(req, CreateSessionRequest): session_requests_before += 1 elif isinstance(req, ExecuteSqlRequest): break @@ -88,7 +84,7 @@ def test_snapshot_execute_sql(self): ), ) ] - assert filtered_unary_segments == want_unary_segments + assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments def test_snapshot_read_concurrent(self): @@ -118,45 +114,32 @@ def select1(): for thread in threads: thread.join() requests = self.spanner_service.requests - # Allow for an extra request due to multiplexed session creation - expected_min = 2 + n - expected_max = expected_min + 1 + # With multiplexed sessions: 1 CreateSession + (n + 1) ExecuteSql + expected_min = 1 + n + 1 + expected_max = expected_min assert ( expected_min <= len(requests) <= expected_max - ), f"Expected {expected_min} or {expected_max} requests, got {len(requests)}: {requests}" + ), f"Expected {expected_min} requests, got {len(requests)}: {requests}" client_id = db._nth_client_id channel_id = db._channel_id got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() want_unary_segments = [ ( - "/google.spanner.v1.Spanner/BatchCreateSessions", + "/google.spanner.v1.Spanner/CreateSession", (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), ), ] assert any(seg 
== want_unary_segments[0] for seg in got_unary_segments) - # Dynamically determine the expected sequence numbers for ExecuteStreamingSql - session_requests_before = 0 - for req in requests: - if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)): - session_requests_before += 1 - elif isinstance(req, ExecuteSqlRequest): - break - want_stream_segments = [ - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - ( - 1, - REQ_RAND_PROCESS_ID, - client_id, - channel_id, - session_requests_before + i, - 1, - ), - ) - for i in range(1, n + 2) - ] - assert got_stream_segments == want_stream_segments + # Verify we have the expected number of ExecuteStreamingSql segments + # (n + 1 = 11 for initial + 10 concurrent) + assert len(got_stream_segments) == n + 1 + # Verify all segments are for ExecuteStreamingSql + for seg in got_stream_segments: + assert seg[0] == "/google.spanner.v1.Spanner/ExecuteStreamingSql" + # Verify the segment has correct client_id and channel_id + assert seg[1][2] == client_id + assert seg[1][3] == channel_id def test_database_run_in_transaction_retries_on_abort(self): counters = dict(aborted=0) @@ -192,19 +175,15 @@ def test_database_execute_partitioned_dml_request_id(self): got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() NTH_CLIENT = self.database._nth_client_id CHANNEL_ID = self.database._channel_id - # Allow for extra unary segments due to session creation - filtered_unary_segments = [ - seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession") - ] # Find the actual sequence number for BeginTransaction begin_txn_seq = None - for seg in filtered_unary_segments: + for seg in got_unary_segments: if seg[0].endswith("/BeginTransaction"): begin_txn_seq = seg[1][4] break want_unary_segments = [ ( - "/google.spanner.v1.Spanner/BatchCreateSessions", + "/google.spanner.v1.Spanner/CreateSession", (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1), ), ( @@ -212,13 +191,6 @@ def test_database_execute_partitioned_dml_request_id(self): (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, begin_txn_seq, 1), ), ] - # Dynamically determine the expected sequence number for ExecuteStreamingSql - session_requests_before = 0 - for req in requests: - if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)): - session_requests_before += 1 - elif isinstance(req, ExecuteSqlRequest): - break # Find the actual sequence number for ExecuteStreamingSql exec_sql_seq = got_stream_segments[0][1][4] if got_stream_segments else None want_stream_segments = [ @@ -227,12 +199,12 @@ def test_database_execute_partitioned_dml_request_id(self): (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, exec_sql_seq, 1), ) ] - assert all(seg in filtered_unary_segments for seg in want_unary_segments) + assert all(seg in got_unary_segments for seg in want_unary_segments) assert got_stream_segments == want_stream_segments def test_unary_retryable_error(self): add_select1_result() - add_error(SpannerServicer.BatchCreateSessions.__name__, unavailable_status()) + add_error(SpannerServicer.CreateSession.__name__, unavailable_status()) if not getattr(self.database, "_interceptors", None): self.database._interceptors = MockServerTestBase._interceptors diff --git a/tests/mockserver_tests/test_tags.py b/tests/mockserver_tests/test_tags.py index 9e35517797..200b8d1d0a 100644 --- a/tests/mockserver_tests/test_tags.py +++ b/tests/mockserver_tests/test_tags.py @@ -23,7 +23,6 @@ MockServerTestBase, add_single_result, ) -from tests._helpers import 
is_multiplexed_enabled from google.cloud.spanner_v1.database_sessions_manager import TransactionType @@ -100,8 +99,9 @@ def test_select_read_only_transaction_with_transaction_tag(self): TransactionType.READ_ONLY, ) # Transaction tags are not supported for read-only transactions. - mux_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) - tag_idx = 3 if mux_enabled else 2 + # With multiplexed sessions: CreateSession, BeginTransaction, ExecuteSql, ExecuteSql + # ExecuteSql requests start at index 2 + tag_idx = 2 self.assertEqual("", requests[tag_idx].request_options.transaction_tag) self.assertEqual("", requests[tag_idx + 1].request_options.transaction_tag) @@ -155,8 +155,9 @@ def test_select_read_write_transaction_with_transaction_tag(self): ], TransactionType.READ_WRITE, ) - mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) - tag_idx = 3 if mux_enabled else 2 + # With multiplexed sessions: CreateSession, BeginTransaction, ExecuteSql, ExecuteSql, Commit + # ExecuteSql requests start at index 2, Commit at index 4 + tag_idx = 2 self.assertEqual( "my_transaction_tag", requests[tag_idx].request_options.transaction_tag ) @@ -187,8 +188,9 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self): ], TransactionType.READ_WRITE, ) - mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) - tag_idx = 3 if mux_enabled else 2 + # With multiplexed sessions: CreateSession, BeginTransaction, ExecuteSql, ExecuteSql, Commit + # ExecuteSql requests start at index 2, Commit at index 4 + tag_idx = 2 self.assertEqual( "my_transaction_tag", requests[tag_idx].request_options.transaction_tag ) diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 52c70517da..2c6302d2f4 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -249,9 +249,7 @@ def test_update_ddl_w_operation_id( # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/5629 # ) temp_db_id = _helpers.unique_id("update_ddl", separator="_") - temp_db = shared_instance.database( - temp_db_id, database_dialect=database_dialect - ) + temp_db = shared_instance.database(temp_db_id, database_dialect=database_dialect) create_op = temp_db.create() databases_to_delete.append(temp_db) create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. 
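Note on the request layout the updated tag tests assume: with a single multiplexed session per database, a read-write transaction issues CreateSession (once), BeginTransaction, the ExecuteSql calls, then Commit. Below is a minimal runnable sketch of that ordering and the tag_idx arithmetic used above; the string names are illustrative stand-ins for the RPC types, not part of the patch:

    # Expected RPC order for one read-write transaction on a multiplexed
    # session, per the comments in test_tags.py above.
    EXPECTED_RW_SEQUENCE = [
        "CreateSession",     # index 0: created once per database
        "BeginTransaction",  # index 1
        "ExecuteSql",        # index 2: first statement (tag_idx)
        "ExecuteSql",        # index 3: second statement
        "Commit",            # index 4
    ]

    def first_execute_sql_index(names):
        # tag_idx in the tests above: position of the first ExecuteSql.
        return names.index("ExecuteSql")

    assert first_execute_sql_index(EXPECTED_RW_SEQUENCE) == 2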
diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 3d46d94302..f9bb2f2a0f 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -186,7 +186,8 @@ def test_ctor_w_ddl_statements_ok(self): instance = _Instance(self.INSTANCE_NAME) database = self._make_one( - self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS ) + self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS + ) self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) @@ -391,9 +392,7 @@ def test_default_leader(self): def test_proto_descriptors(self): instance = _Instance(self.INSTANCE_NAME) - database = self._make_one( - self.DATABASE_ID, instance, proto_descriptors=b"" - ) + database = self._make_one(self.DATABASE_ID, instance, proto_descriptors=b"") self.assertEqual(database.proto_descriptors, b"") def test_spanner_api_property_w_scopeless_creds(self): @@ -1048,7 +1047,8 @@ def test_update_success(self): instance = _Instance(self.INSTANCE_NAME, client=client) database = self._make_one( - self.DATABASE_ID, instance, enable_drop_protection=True ) + self.DATABASE_ID, instance, enable_drop_protection=True + ) future = database.update(["enable_drop_protection"]) @@ -1181,7 +1181,6 @@ def _execute_partitioned_dml_helper( retried=False, exclude_txn_from_change_streams=False, ): - import os from google.api_core.exceptions import Aborted from google.api_core.retry import Retry from google.protobuf.struct_pb2 import Struct @@ -1212,32 +1211,17 @@ def _execute_partitioned_dml_helper( client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - session = _Session() database = self._make_one(self.DATABASE_ID, instance) - multiplexed_partitioned_enabled = ( - os.environ.get( - "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS", "true" - ).lower() - != "false" - ) - - if multiplexed_partitioned_enabled: - # When multiplexed sessions are enabled, create a mock multiplexed session - # that the sessions manager will return - multiplexed_session = _Session() - multiplexed_session.name = ( - self.SESSION_NAME - ) # Use the expected session name - multiplexed_session.is_multiplexed = True - # Configure the sessions manager to return the multiplexed session - database._sessions_manager.get_session = mock.Mock( - return_value=multiplexed_session - ) - expected_session = multiplexed_session - else: - # Multiplexed sessions are now always used - expected_session = session + # Create a mock multiplexed session that the sessions manager will return + multiplexed_session = _Session() + multiplexed_session.name = self.SESSION_NAME # Use the expected session name + multiplexed_session.is_multiplexed = True + # Configure the sessions manager to return the multiplexed session + database._sessions_manager.get_session = mock.Mock( + return_value=multiplexed_session + ) + expected_session = multiplexed_session api = database._spanner_api = self._make_spanner_api() api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())} @@ -1404,13 +1388,10 @@ def _execute_partitioned_dml_helper( ) self.assertEqual(api.execute_streaming_sql.call_count, 1) - # Verify that the correct session type was used based on environment - if multiplexed_partitioned_enabled: - # Verify that sessions_manager.get_session was called with PARTITIONED transaction type - database._sessions_manager.get_session.assert_called_with( - TransactionType.PARTITIONED - ) - # Multiplexed sessions are now 
always used + # Verify that sessions_manager.get_session was called with PARTITIONED transaction type + database._sessions_manager.get_session.assert_called_with( + TransactionType.PARTITIONED + ) def test_execute_partitioned_dml_wo_params(self): self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) @@ -1483,7 +1464,6 @@ def test_session_factory_w_labels(self): def test_snapshot_defaults(self): from google.cloud.spanner_v1.database import SnapshotCheckout - from google.cloud.spanner_v1.snapshot import Snapshot client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) @@ -1507,7 +1487,6 @@ def test_snapshot_w_read_timestamp_and_multi_use(self): import datetime from google.cloud._helpers import UTC from google.cloud.spanner_v1.database import SnapshotCheckout - from google.cloud.spanner_v1.snapshot import Snapshot now = datetime.datetime.utcnow().replace(tzinfo=UTC) client = _Client() @@ -1534,7 +1513,6 @@ def test_batch(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - session = _Session() database = self._make_one(self.DATABASE_ID, instance) checkout = database.batch() @@ -1546,7 +1524,6 @@ def test_mutation_groups(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - session = _Session() database = self._make_one(self.DATABASE_ID, instance) checkout = database.mutation_groups() @@ -1597,8 +1574,6 @@ def test_run_in_transaction_wo_args(self): NOW = datetime.datetime.now() client = _Client(observability_options=dict(enable_end_to_end_tracing=True)) instance = _Instance(self.INSTANCE_NAME, client=client) - session = _Session() - session._committed = NOW database = self._make_one(self.DATABASE_ID, instance) # Mock the spanner_api to avoid creating a real SpannerClient database._spanner_api = instance._client._spanner_api @@ -1622,8 +1597,6 @@ def test_run_in_transaction_w_args(self): NOW = datetime.datetime.now() client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) - session = _Session() - session._committed = NOW database = self._make_one(self.DATABASE_ID, instance) # Mock the spanner_api to avoid creating a real SpannerClient database._spanner_api = instance._client._spanner_api @@ -2091,7 +2064,6 @@ def test_context_mgr_w_commit_stats_success(self): database.log_commit_stats = True api = database.spanner_api = self._make_spanner_client() api.commit.return_value = response - session = _Session(database) checkout = self._make_one(database) with checkout as batch: @@ -2135,15 +2107,13 @@ def test_context_mgr_w_aborted_commit_status(self): database.log_commit_stats = True api = database.spanner_api = self._make_spanner_client() api.commit.side_effect = Aborted("aborted exception", errors=("Aborted error")) - session = _Session(database) checkout = self._make_one(database, timeout_secs=0.1, default_retry_delay=0) with self.assertRaises(Aborted): with checkout as batch: - self.assertIsInstance(batch, Batch) + self.assertIsInstance(batch, Batch) self.assertIs(batch._session, database._default_session) - expected_txn_options = TransactionOptions(read_write={}) request = CommitRequest( @@ -2172,7 +2142,6 @@ def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch database = _Database(self.DATABASE_NAME) - session = _Session(database) checkout = self._make_one(database) class Testing(Exception): @@ -2197,8 +2166,7 @@ def test_ctor_defaults(self): from google.cloud.spanner_v1.snapshot import Snapshot database = _Database(self.DATABASE_NAME) - session = _Session(database) 
- + checkout = self._make_one(database) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {}) @@ -2209,7 +2177,6 @@ def test_ctor_defaults(self): self.assertTrue(snapshot._strong) self.assertFalse(snapshot._multi_use) - def test_ctor_w_read_timestamp_and_multi_use(self): import datetime from google.cloud._helpers import UTC @@ -2217,8 +2184,7 @@ def test_ctor_w_read_timestamp_and_multi_use(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) database = _Database(self.DATABASE_NAME) - session = _Session(database) - + checkout = self._make_one(database, read_timestamp=now, multi_use=True) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {"read_timestamp": now, "multi_use": True}) @@ -2229,12 +2195,10 @@ def test_ctor_w_read_timestamp_and_multi_use(self): self.assertEqual(snapshot._read_timestamp, now) self.assertTrue(snapshot._multi_use) - def test_context_mgr_failure(self): from google.cloud.spanner_v1.snapshot import Snapshot database = _Database(self.DATABASE_NAME) - session = _Session(database) checkout = self._make_one(database) class Testing(Exception): @@ -3144,7 +3108,6 @@ def test_ctor(self): from google.cloud.spanner_v1.batch import MutationGroups database = _Database(self.DATABASE_NAME) - session = _Session(database) checkout = self._make_one(database) self.assertIs(checkout._database, database) @@ -3152,7 +3115,6 @@ def test_ctor(self): self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, database._default_session) - def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1._helpers import _make_list_value_pbs @@ -3173,7 +3135,6 @@ def test_context_mgr_success(self): database = _Database(self.DATABASE_NAME) api = database.spanner_api = self._make_spanner_client() api.batch_write.return_value = [response] - session = _Session(database) checkout = self._make_one(database) request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) @@ -3202,7 +3163,6 @@ def test_context_mgr_success(self): groups.batch_write(request_options) self.assertEqual(groups.committed, True) - api.batch_write.assert_called_once_with( request=request, metadata=[ @@ -3219,7 +3179,6 @@ def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import MutationGroups database = _Database(self.DATABASE_NAME) - session = _Session(database) checkout = self._make_one(database) class Testing(Exception): From 62b26c8bbf686a7bd6cc05dcc2ad073ef73d3d14 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Tue, 13 Jan 2026 12:27:36 +0530 Subject: [PATCH 4/5] fix flake8 --- tests/system/test_backup_api.py | 1 - tests/system/test_dbapi.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/system/test_backup_api.py b/tests/system/test_backup_api.py index 7349dae0f4..e3210d10cf 100644 --- a/tests/system/test_backup_api.py +++ b/tests/system/test_backup_api.py @@ -19,7 +19,6 @@ import pytest from google.api_core import exceptions -from google.cloud import spanner_v1 from . 
import _helpers
 
 skip_env_reason = f"""\
diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py
index 90e383e245..f33ef0a24e 100644
--- a/tests/system/test_dbapi.py
+++ b/tests/system/test_dbapi.py
@@ -19,7 +19,6 @@
 import time
 import decimal
 
-from google.cloud import spanner_v1
 from google.cloud._helpers import UTC
 from google.cloud.spanner_dbapi.connection import Connection, connect
 

From 69368fbf1dd8d83c74c54d4f4bc32be399053eb6 Mon Sep 17 00:00:00 2001
From: Rahul Yadav
Date: Tue, 13 Jan 2026 13:37:08 +0530
Subject: [PATCH 5/5] make regular pool methods no-op

---
 google/cloud/spanner_v1/pool.py | 745 ++------------------------------
 1 file changed, 41 insertions(+), 704 deletions(-)

diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index f0304bd66c..e437896f3b 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -20,29 +20,8 @@
 the need for session pooling.
 """
 
-import datetime
-import queue
-import time
-
-from google.cloud.exceptions import NotFound
-from google.cloud.spanner_v1 import BatchCreateSessionsRequest
-from google.cloud.spanner_v1 import Session as SessionProto
-from google.cloud.spanner_v1.session import Session
-from google.cloud.spanner_v1._helpers import (
-    _metadata_with_prefix,
-    _metadata_with_leader_aware_routing,
-)
-from google.cloud.spanner_v1._opentelemetry_tracing import (
-    add_span_event,
-    get_current_span,
-    trace_call,
-)
 from warnings import warn
 
-from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
-
-_NOW = datetime.datetime.utcnow  # unit tests may replace
-
 _POOL_DEPRECATION_MESSAGE = (
     "Session pools are deprecated and will be removed in a future release. "
     "Multiplexed sessions are now used for all operations by default, "
@@ -51,7 +30,7 @@
 )
 
 
-class AbstractSessionPool(object):
+class AbstractSessionPool:
     """Specifies required API for concrete session pool implementations.
 
     .. deprecated::
@@ -93,122 +72,43 @@ def database_role(self):
         return self._database_role
 
     def bind(self, database):
-        """Associate the pool with a database.
+        """Associate the pool with a database. Deprecated: only records the reference.
 
         :type database: :class:`~google.cloud.spanner_v1.database.Database`
-        :param database: database used by the pool to create sessions
-            when needed.
-
-        Concrete implementations of this method may pre-fill the pool
-        using the database.
-
-        :raises NotImplementedError: abstract method
+        :param database: database to associate; stored for backward compatibility.
         """
-        raise NotImplementedError()
+        self._database = database
 
     def get(self):
-        """Check a session out from the pool.
-
-        Concrete implementations of this method are allowed to raise an
-        error to signal that the pool is exhausted, or to block until a
-        session is available.
+        """Check a session out from the pool. Not supported for deprecated pools.
 
-        :raises NotImplementedError: abstract method
+        :raises NotImplementedError: pools are deprecated
         """
-        raise NotImplementedError()
+        raise NotImplementedError("Session pools are deprecated")
 
     def put(self, session):
-        """Return a session to the pool.
+        """Return a session to the pool. No-op for deprecated pools.
 
         :type session: :class:`~google.cloud.spanner_v1.session.Session`
-        :param session: the session being returned.
-
-        Concrete implementations of this method are allowed to raise an
-        error to signal that the pool is full, or to block until it is
-        not full.
-
-        :raises NotImplementedError: abstract method
+        :param session: the session being returned (ignored).
""" - raise NotImplementedError() + pass def clear(self): - """Delete all sessions in the pool. - - Concrete implementations of this method are allowed to raise an - error to signal that the pool is full, or to block until it is - not full. - - :raises NotImplementedError: abstract method - """ - raise NotImplementedError() - - def _new_session(self): - """Helper for concrete methods creating session instances. - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: new session instance. - """ - - role = self.database_role or self._database.database_role - return Session(database=self._database, labels=self.labels, database_role=role) - - def session(self, **kwargs): - """Check out a session from the pool. - - Deprecated. Sessions should be checked out indirectly using context - managers or :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`, - rather than checked out directly from the pool. - - :param kwargs: (optional) keyword arguments, passed through to - the returned checkout. - - :rtype: :class:`~google.cloud.spanner_v1.session.SessionCheckout` - :returns: a checkout instance, to be used as a context manager for - accessing the session and returning it to the pool. - """ - return SessionCheckout(self, **kwargs) + """Delete all sessions in the pool. No-op for deprecated pools.""" + pass class FixedSizePool(AbstractSessionPool): - """Concrete session pool implementation: + """Concrete session pool implementation. .. deprecated:: FixedSizePool is deprecated and will be removed in a future release. Multiplexed sessions are now used for all operations by default. - - - Pre-allocates / creates a fixed number of sessions. - - - "Pings" existing sessions via :meth:`session.exists` before returning - sessions that have not been used for more than 55 minutes and replaces - expired sessions. - - - Blocks, with a timeout, when :meth:`get` is called on an empty pool. - Raises after timing out. - - - Raises when :meth:`put` is called on a full pool. That error is - never expected in normal practice, as users should be calling - :meth:`get` followed by :meth:`put` whenever in need of a session. - - :type size: int - :param size: (Deprecated) fixed pool size. This parameter is deprecated - as session pools are no longer needed with multiplexed sessions. - - :type default_timeout: int - :param default_timeout: (Deprecated) default timeout, in seconds, to wait for - a returned session. This parameter is deprecated as session pools are - no longer needed with multiplexed sessions. - - :type labels: dict (str -> str) or None - :param labels: (Optional) user-assigned labels for sessions created - by the pool. - - :type database_role: str - :param database_role: (Optional) user-assigned database_role for the session. """ DEFAULT_SIZE = 10 DEFAULT_TIMEOUT = 10 - DEFAULT_MAX_AGE_MINUTES = 55 def __init__( self, @@ -216,353 +116,50 @@ def __init__( default_timeout=DEFAULT_TIMEOUT, labels=None, database_role=None, - max_age_minutes=DEFAULT_MAX_AGE_MINUTES, + max_age_minutes=55, ): - warn( - _POOL_DEPRECATION_MESSAGE, - DeprecationWarning, - stacklevel=2, - ) + warn(_POOL_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2) super(FixedSizePool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout - self._sessions = queue.LifoQueue(size) - self._max_age = datetime.timedelta(minutes=max_age_minutes) - - def bind(self, database): - """Associate the pool with a database. 
- - :type database: :class:`~google.cloud.spanner_v1.database.Database` - :param database: database used by the pool to used to create sessions - when needed. - """ - self._database = database - requested_session_count = self.size - self._sessions.qsize() - span = get_current_span() - span_event_attributes = {"kind": type(self).__name__} - - if requested_session_count <= 0: - add_span_event( - span, - f"Invalid session pool size({requested_session_count}) <= 0", - span_event_attributes, - ) - return - - api = database.spanner_api - metadata = _metadata_with_prefix(database.name) - if database._route_to_leader_enabled: - metadata.append( - _metadata_with_leader_aware_routing(database._route_to_leader_enabled) - ) - self._database_role = self._database_role or self._database.database_role - if requested_session_count > 0: - add_span_event( - span, - f"Requesting {requested_session_count} sessions", - span_event_attributes, - ) - - if self._sessions.full(): - add_span_event(span, "Session pool is already full", span_event_attributes) - return - - request = BatchCreateSessionsRequest( - database=database.name, - session_count=requested_session_count, - session_template=SessionProto(creator_role=self.database_role), - ) - - observability_options = getattr(self._database, "observability_options", None) - with trace_call( - "CloudSpanner.FixedPool.BatchCreateSessions", - observability_options=observability_options, - metadata=metadata, - ) as span, MetricsCapture(): - returned_session_count = 0 - while not self._sessions.full(): - request.session_count = requested_session_count - self._sessions.qsize() - add_span_event( - span, - f"Creating {request.session_count} sessions", - span_event_attributes, - ) - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), - ) - - add_span_event( - span, - "Created sessions", - dict(count=len(resp.session)), - ) - - for session_pb in resp.session: - session = self._new_session() - session._session_id = session_pb.name.split("/")[-1] - self._sessions.put(session) - returned_session_count += 1 - - add_span_event( - span, - f"Requested for {requested_session_count} sessions, returned {returned_session_count}", - span_event_attributes, - ) def get(self, timeout=None): - """Check a session out from the pool. - - :type timeout: int - :param timeout: seconds to block waiting for an available session - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: an existing session from the pool, or a newly-created - session. - :raises: :exc:`queue.Empty` if the queue is empty. - """ - if timeout is None: - timeout = self.default_timeout - - start_time = time.time() - current_span = get_current_span() - span_event_attributes = {"kind": type(self).__name__} - add_span_event(current_span, "Acquiring session", span_event_attributes) - - session = None - try: - add_span_event( - current_span, - "Waiting for a session to become available", - span_event_attributes, - ) - - session = self._sessions.get(block=True, timeout=timeout) - age = _NOW() - session.last_use_time - - if age >= self._max_age and not session.exists(): - if not session.exists(): - add_span_event( - current_span, - "Session is not valid, recreating it", - span_event_attributes, - ) - session = self._new_session() - session.create() - # Replacing with the updated session.id. 
-                span_event_attributes["session.id"] = session._session_id
-
-        span_event_attributes["session.id"] = session._session_id
-        span_event_attributes["time.elapsed"] = time.time() - start_time
-        add_span_event(current_span, "Acquired session", span_event_attributes)
-
-        except queue.Empty as e:
-            add_span_event(
-                current_span, "No sessions available in the pool", span_event_attributes
-            )
-            raise e
-
-        return session
+        """Check a session out from the pool. Deprecated: always raises.

-    def put(self, session):
-        """Return a session to the pool.
-
-        Never blocks: if the pool is full, raises.
-
-        :type session: :class:`~google.cloud.spanner_v1.session.Session`
-        :param session: the session being returned.
-
-        :raises: :exc:`queue.Full` if the queue is full.
+        :raises NotImplementedError: session pools are deprecated
         """
-        self._sessions.put_nowait(session)
-
-    def clear(self):
-        """Delete all sessions in the pool."""
-
-        while True:
-            try:
-                session = self._sessions.get(block=False)
-            except queue.Empty:
-                break
-            else:
-                session.delete()
+        raise NotImplementedError("Session pools are deprecated")


 class BurstyPool(AbstractSessionPool):
-    """Concrete session pool implementation:
+    """Concrete session pool implementation.

     .. deprecated::
         BurstyPool is deprecated and will be removed in a future release.
         Multiplexed sessions are now used for all operations by default.
-
-    - "Pings" existing sessions via :meth:`session.exists` before returning
-      them.
-
-    - Creates a new session, rather than blocking, when :meth:`get` is called
-      on an empty pool.
-
-    - Discards the returned session, rather than blocking, when :meth:`put`
-      is called on a full pool.
-
-    :type target_size: int
-    :param target_size: (Deprecated) max pool size. This parameter is deprecated
-        as session pools are no longer needed with multiplexed sessions.
-
-    :type labels: dict (str -> str) or None
-    :param labels: (Optional) user-assigned labels for sessions created
-        by the pool.
-
-    :type database_role: str
-    :param database_role: (Optional) user-assigned database_role for the session.
     """

-    # Internal flag to suppress deprecation warning when BurstyPool is used
-    # as a fallback/internal implementation detail.
-    _suppress_warning = False
-
     def __init__(self, target_size=10, labels=None, database_role=None):
-        if not BurstyPool._suppress_warning:
-            warn(
-                _POOL_DEPRECATION_MESSAGE,
-                DeprecationWarning,
-                stacklevel=2,
-            )
+        warn(_POOL_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
         super(BurstyPool, self).__init__(labels=labels, database_role=database_role)
         self.target_size = target_size
-        self._database = None
-        self._sessions = queue.LifoQueue(target_size)
-
-    def bind(self, database):
-        """Associate the pool with a database.
-
-        :type database: :class:`~google.cloud.spanner_v1.database.Database`
-        :param database: database used by the pool to create sessions
-            when needed.
-        """
-        self._database = database
-        self._database_role = self._database_role or self._database.database_role

     def get(self):
-        """Check a session out from the pool.
-
-        :rtype: :class:`~google.cloud.spanner_v1.session.Session`
-        :returns: an existing session from the pool, or a newly-created
-            session. 
-        """
-        current_span = get_current_span()
-        span_event_attributes = {"kind": type(self).__name__}
-        add_span_event(current_span, "Acquiring session", span_event_attributes)
-
-        try:
-            add_span_event(
-                current_span,
-                "Waiting for a session to become available",
-                span_event_attributes,
-            )
-            session = self._sessions.get_nowait()
-        except queue.Empty:
-            add_span_event(
-                current_span,
-                "No sessions available in pool. Creating session",
-                span_event_attributes,
-            )
-            session = self._new_session()
-            session.create()
-        else:
-            if not session.exists():
-                add_span_event(
-                    current_span,
-                    "Session is not valid, recreating it",
-                    span_event_attributes,
-                )
-                session = self._new_session()
-                session.create()
-        return session
-
-    def put(self, session):
-        """Return a session to the pool.
+        """Check a session out from the pool. Deprecated: always raises.

-        Never blocks: if the pool is full, the returned session is
-        discarded.
-
-        :type session: :class:`~google.cloud.spanner_v1.session.Session`
-        :param session: the session being returned.
+        :raises NotImplementedError: session pools are deprecated
         """
-        try:
-            self._sessions.put_nowait(session)
-        except queue.Full:
-            try:
-                # Sessions from pools are never multiplexed, so we can always delete them
-                session.delete()
-            except NotFound:
-                pass
-
-    def clear(self):
-        """Delete all sessions in the pool."""
-
-        while True:
-            try:
-                session = self._sessions.get(block=False)
-            except queue.Empty:
-                break
-            else:
-                session.delete()
+        raise NotImplementedError("Session pools are deprecated")


 class PingingPool(AbstractSessionPool):
-    """Concrete session pool implementation:
+    """Concrete session pool implementation.

     .. deprecated::
         PingingPool is deprecated and will be removed in a future release.
         Multiplexed sessions are now used for all operations by default.
-
-    - Pre-allocates / creates a fixed number of sessions.
-
-    - Sessions are used in "round-robin" order (LRU first).
-
-    - "Pings" existing sessions in the background after a specified interval
-      via an API call (``session.ping()``).
-
-    - Blocks, with a timeout, when :meth:`get` is called on an empty pool.
-      Raises after timing out.
-
-    - Raises when :meth:`put` is called on a full pool. That error is
-      never expected in normal practice, as users should be calling
-      :meth:`get` followed by :meth:`put` whenever in need of a session.
-
-    The application is responsible for calling :meth:`ping` at appropriate
-    times, e.g. from a background thread.
-
-    :type size: int
-    :param size: (Deprecated) fixed pool size. This parameter is deprecated
-        as session pools are no longer needed with multiplexed sessions.
-
-    :type default_timeout: int
-    :param default_timeout: (Deprecated) default timeout, in seconds, to wait for
-        a returned session. This parameter is deprecated as session pools are
-        no longer needed with multiplexed sessions.
-
-    :type ping_interval: int
-    :param ping_interval: (Deprecated) interval at which to ping sessions.
-        This parameter is deprecated as session pools are no longer needed
-        with multiplexed sessions.
-
-    :type labels: dict (str -> str) or None
-    :param labels: (Optional) user-assigned labels for sessions created
-        by the pool.
-
-    :type database_role: str
-    :param database_role: (Optional) user-assigned database_role for the session.
     """

-    # Internal flag to suppress deprecation warning when called from subclass. 
- _suppress_warning = False - def __init__( self, size=10, @@ -571,229 +168,29 @@ def __init__( labels=None, database_role=None, ): - if not PingingPool._suppress_warning: - warn( - _POOL_DEPRECATION_MESSAGE, - DeprecationWarning, - stacklevel=2, - ) + warn(_POOL_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2) super(PingingPool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout - self._delta = datetime.timedelta(seconds=ping_interval) - self._sessions = queue.PriorityQueue(size) - - def bind(self, database): - """Associate the pool with a database. - - :type database: :class:`~google.cloud.spanner_v1.database.Database` - :param database: database used by the pool to create sessions - when needed. - """ - self._database = database - api = database.spanner_api - metadata = _metadata_with_prefix(database.name) - if database._route_to_leader_enabled: - metadata.append( - _metadata_with_leader_aware_routing(database._route_to_leader_enabled) - ) - self._database_role = self._database_role or self._database.database_role - - request = BatchCreateSessionsRequest( - database=database.name, - session_count=self.size, - session_template=SessionProto(creator_role=self.database_role), - ) - - span_event_attributes = {"kind": type(self).__name__} - current_span = get_current_span() - requested_session_count = request.session_count - if requested_session_count <= 0: - add_span_event( - current_span, - f"Invalid session pool size({requested_session_count}) <= 0", - span_event_attributes, - ) - return - - add_span_event( - current_span, - f"Requesting {requested_session_count} sessions", - span_event_attributes, - ) - - observability_options = getattr(self._database, "observability_options", None) - with trace_call( - "CloudSpanner.PingingPool.BatchCreateSessions", - observability_options=observability_options, - metadata=metadata, - ) as span, MetricsCapture(): - returned_session_count = 0 - while returned_session_count < self.size: - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), - ) - - add_span_event( - span, - f"Created {len(resp.session)} sessions", - ) - - for session_pb in resp.session: - session = self._new_session() - returned_session_count += 1 - session._session_id = session_pb.name.split("/")[-1] - self.put(session) - - add_span_event( - span, - f"Requested for {requested_session_count} sessions, returned {returned_session_count}", - span_event_attributes, - ) def get(self, timeout=None): - """Check a session out from the pool. - - :type timeout: int - :param timeout: seconds to block waiting for an available session - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: an existing session from the pool, or a newly-created - session. - :raises: :exc:`queue.Empty` if the queue is empty. 
-        """
-        if timeout is None:
-            timeout = self.default_timeout
-
-        start_time = time.time()
-        span_event_attributes = {"kind": type(self).__name__}
-        current_span = get_current_span()
-        add_span_event(
-            current_span,
-            "Waiting for a session to become available",
-            span_event_attributes,
-        )
-
-        ping_after = None
-        session = None
-        try:
-            ping_after, session = self._sessions.get(block=True, timeout=timeout)
-        except queue.Empty as e:
-            add_span_event(
-                current_span,
-                "No sessions available in the pool within the specified timeout",
-                span_event_attributes,
-            )
-            raise e
-
-        if _NOW() > ping_after:
-            # Using session.exists() guarantees the returned session exists.
-            # session.ping() uses a cached result in the backend which could
-            # result in a recently deleted session being returned.
-            if not session.exists():
-                session = self._new_session()
-                session.create()
-
-        span_event_attributes.update(
-            {
-                "time.elapsed": time.time() - start_time,
-                "session.id": session._session_id,
-                "kind": "pinging_pool",
-            }
-        )
-        add_span_event(current_span, "Acquired session", span_event_attributes)
-        return session
-
-    def put(self, session):
-        """Return a session to the pool.
+        """Check a session out from the pool. Deprecated: always raises.

-        Never blocks: if the pool is full, raises.
-
-        :type session: :class:`~google.cloud.spanner_v1.session.Session`
-        :param session: the session being returned.
-
-        :raises: :exc:`queue.Full` if the queue is full.
+        :raises NotImplementedError: session pools are deprecated
         """
-        self._sessions.put_nowait((_NOW() + self._delta, session))
-
-    def clear(self):
-        """Delete all sessions in the pool."""
-        while True:
-            try:
-                _, session = self._sessions.get(block=False)
-            except queue.Empty:
-                break
-            else:
-                session.delete()
+        raise NotImplementedError("Session pools are deprecated")

     def ping(self):
-        """Refresh maybe-expired sessions in the pool.
-
-        This method is designed to be called from a background thread,
-        or during the "idle" phase of an event loop.
-        """
-        while True:
-            try:
-                ping_after, session = self._sessions.get(block=False)
-            except queue.Empty:  # all sessions in use
-                break
-            if ping_after > _NOW():  # oldest session is fresh
-                # Re-add to queue with existing expiration
-                self._sessions.put((ping_after, session))
-                break
-            try:
-                session.ping()
-            except NotFound:
-                session = self._new_session()
-                session.create()
-            # Re-add to queue with new expiration
-            self.put(session)
+        """Refresh maybe-expired sessions in the pool. No-op; session pools are deprecated."""
+        pass


 class TransactionPingingPool(PingingPool):
-    """Concrete session pool implementation:
+    """Concrete session pool implementation.

     .. deprecated::
         TransactionPingingPool is deprecated and will be removed in a future release.
         Multiplexed sessions are now used for all operations by default.

-    TransactionPingingPool no longer begins a transaction for each of its sessions
-    at startup. Hence the TransactionPingingPool is same as :class:`PingingPool`.
-
-    In addition to the features of :class:`PingingPool`, this class
-    creates and begins a transaction for each of its sessions at startup.
-
-    When a session is returned to the pool, if its transaction has been
-    committed or rolled back, the pool creates a new transaction for the
-    session and pushes the transaction onto a separate queue of "transactions
-    to begin." The application is responsible for flushing this queue
-    as appropriate via the pool's :meth:`begin_pending_transactions` method.
-
-    :type size: int
-    :param size: (Deprecated) fixed pool size. 
This parameter is deprecated
-        as session pools are no longer needed with multiplexed sessions.
-
-    :type default_timeout: int
-    :param default_timeout: (Deprecated) default timeout, in seconds, to wait for
-        a returned session. This parameter is deprecated as session pools are
-        no longer needed with multiplexed sessions.
-
-    :type ping_interval: int
-    :param ping_interval: (Deprecated) interval at which to ping sessions.
-        This parameter is deprecated as session pools are no longer needed
-        with multiplexed sessions.
-
-    :type labels: dict (str -> str) or None
-    :param labels: (Optional) user-assigned labels for sessions created
-        by the pool.
-
-    :type database_role: str
-    :param database_role: (Optional) user-assigned database_role for the session.
     """

     def __init__(
@@ -804,82 +201,23 @@ def __init__(
         labels=None,
         database_role=None,
     ):
-        """This throws a deprecation warning on initialization."""
-        warn(
-            _POOL_DEPRECATION_MESSAGE,
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self._pending_sessions = queue.Queue()
-
-        # Suppress warning from parent class to avoid double warning
-        PingingPool._suppress_warning = True
-        try:
-            super(TransactionPingingPool, self).__init__(
-                size,
-                default_timeout,
-                ping_interval,
-                labels=labels,
-                database_role=database_role,
-            )
-        finally:
-            PingingPool._suppress_warning = False
-
-        self.begin_pending_transactions()
-
-    def bind(self, database):
-        """Associate the pool with a database.
-
-        :type database: :class:`~google.cloud.spanner_v1.database.Database`
-        :param database: database used by the pool to create sessions
-            when needed.
-        """
-        super(TransactionPingingPool, self).bind(database)
-        self._database_role = self._database_role or self._database.database_role
-        self.begin_pending_transactions()
-
-    def put(self, session):
-        """Return a session to the pool.
-
-        Never blocks: if the pool is full, raises.
-
-        :type session: :class:`~google.cloud.spanner_v1.session.Session`
-        :param session: the session being returned.
-
-        :raises: :exc:`queue.Full` if the queue is full.
-        """
-        if self._sessions.full():
-            raise queue.Full
-
-        txn = session._transaction
-        if txn is None or txn.committed or txn.rolled_back:
-            session.transaction()
-            self._pending_sessions.put(session)
-        else:
-            super(TransactionPingingPool, self).put(session)
+        # Call grandparent's __init__ directly to avoid a double deprecation warning,
+        # since PingingPool.__init__ now warns unconditionally.
+        AbstractSessionPool.__init__(self, labels=labels, database_role=database_role)
+        warn(_POOL_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
+        self.size = size
+        self.default_timeout = default_timeout

     def begin_pending_transactions(self):
-        """Begin all transactions for sessions added to the pool."""
-        while not self._pending_sessions.empty():
-            session = self._pending_sessions.get()
-            super(TransactionPingingPool, self).put(session)
+        """Begin all transactions for sessions added to the pool. No-op; session pools are deprecated."""
+        pass


-class SessionCheckout(object):
+class SessionCheckout:
    """Context manager: hold session checked out from a pool.

     .. deprecated::
         SessionCheckout is deprecated and will be removed in a future release.
         Multiplexed sessions are now used for all operations by default.

-    Sessions should be checked out indirectly using context managers or
-    :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`,
-    rather than checked out directly from the pool.
-
-    :type pool: concrete subclass of
-    :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`
-    :param pool: (Deprecated) Pool from which to check out a session. 
- - :param kwargs: extra keyword arguments to be passed to :meth:`pool.get`. """ _session = None @@ -896,8 +234,7 @@ def __init__(self, pool, **kwargs): self._kwargs = kwargs.copy() def __enter__(self): - self._session = self._pool.get(**self._kwargs) - return self._session + raise NotImplementedError("Session pools are deprecated") def __exit__(self, *ignored): - self._pool.put(self._session) + pass
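
With the pool classes reduced to warn-on-construct stubs whose checkout methods raise,
the caller-side migration is simply to stop passing a pool. A minimal usage sketch
("my-instance" and "my-database" are hypothetical IDs; assumes default application
credentials) of how a caller picks up multiplexed sessions after this change:

    from google.cloud import spanner

    client = spanner.Client()
    instance = client.instance("my-instance")  # hypothetical instance ID

    # Previously: instance.database("my-database", pool=spanner.BurstyPool())
    # Now: omit the pool; the database manages multiplexed sessions internally.
    database = instance.database("my-database")  # hypothetical database ID

    with database.snapshot() as snapshot:
        for row in snapshot.execute_sql("SELECT 1"):
            print(row)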