From 33bfffa8af382d1f8f6ef33eb7a3d2043ede7ca7 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Mon, 23 Feb 2026 15:27:24 -0800 Subject: [PATCH 1/5] bring back packages --- src/snowflake/snowpark/session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ad583da1a7..0a131a3929 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -601,6 +601,8 @@ def __init__( self._conn = conn self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} + # unused, needed for test infra? + self._packages: Dict[str, str] = {} # map of artifact repository name -> packages that should be added to functions under that repository self._artifact_repository_packages: DefaultDict[ str, Dict[str, str] From c59a539aabdc33462b887f703a2d6fc8ce632415 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Mon, 23 Feb 2026 17:10:48 -0800 Subject: [PATCH 2/5] keep the PRIVATE api consistent from before --- src/snowflake/snowpark/_internal/udf_utils.py | 20 ++++--- src/snowflake/snowpark/session.py | 34 +++++++++--- tests/integ/test_packaging.py | 54 ++++--------------- tests/unit/test_session.py | 24 ++++----- tests/unit/test_stored_procedure.py | 2 + tests/unit/test_udaf.py | 1 + tests/unit/test_udf.py | 1 + tests/unit/test_udf_utils.py | 6 +-- tests/unit/test_udtf.py | 1 + 9 files changed, 66 insertions(+), 77 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 7c0f9e01f4..f33de2311f 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1142,9 +1142,9 @@ def add_snowpark_package_to_sproc_packages( packages = [this_package] else: with session._package_lock: - existing_packages = session._artifact_repository_packages[ + existing_packages = session._get_packages_by_artifact_repository( artifact_repository - ] + ) if package_name not in existing_packages: packages = list(existing_packages.values()) + [this_package] return packages @@ -1239,7 +1239,9 @@ def resolve_imports_and_packages( ) existing_packages_dict = ( - session._artifact_repository_packages[artifact_repository] if session else {} + session._get_packages_by_artifact_repository(artifact_repository) + if session + else {} ) if artifact_repository != _ANACONDA_SHARED_REPOSITORY: @@ -1248,7 +1250,9 @@ def resolve_imports_and_packages( if not packages and session: resolved_packages = list( session._resolve_packages( - [], artifact_repository, existing_packages_dict + [], + artifact_repository=artifact_repository, + existing_packages_dict=existing_packages_dict, ) ) elif packages: @@ -1286,8 +1290,8 @@ def resolve_imports_and_packages( resolved_packages = ( session._resolve_packages( packages, - artifact_repository, - {}, # ignore session packages if passed in explicitly + artifact_repository=artifact_repository, + existing_packages_dict={}, # ignore session packages if passed in explicitly include_pandas=is_pandas_udf, statement_params=statement_params, _suppress_local_package_warnings=_suppress_local_package_warnings, @@ -1295,8 +1299,8 @@ def resolve_imports_and_packages( if packages is not None else session._resolve_packages( [], - artifact_repository, - existing_packages_dict, + artifact_repository=artifact_repository, + existing_packages_dict=existing_packages_dict, validate_package=False, include_pandas=is_pandas_udf, statement_params=statement_params, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 0a131a3929..5352c27c32 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -601,7 +601,8 @@ def __init__( self._conn = conn self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} - # unused, needed for test infra? + # packages under the DEFAULT_ARTIFACT_REPOSITORY + # due to server side accessing private session members, this cannot be merged with _artifact_repository_packages self._packages: Dict[str, str] = {} # map of artifact repository name -> packages that should be added to functions under that repository self._artifact_repository_packages: DefaultDict[ @@ -1601,6 +1602,14 @@ def _list_files_in_stage( prefix_length = get_stage_file_prefix_length(stage_location) return {str(row[0])[prefix_length:] for row in file_list} + def _get_packages_by_artifact_repository( + self, artifact_repository: str + ) -> Dict[str, str]: + if artifact_repository == _DEFAULT_ARTIFACT_REPOSITORY: + return self._packages + else: + return self._artifact_repository_packages[artifact_repository] + def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, str]: """ Returns a ``dict`` of packages added for user-defined functions (UDFs). @@ -1615,7 +1624,7 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s artifact_repository = self._get_default_artifact_repository() with self._package_lock: - return self._artifact_repository_packages[artifact_repository].copy() + return self._get_packages_by_artifact_repository(artifact_repository).copy() def add_packages( self, @@ -1688,8 +1697,10 @@ def add_packages( self._resolve_packages( parse_positional_args_to_list(*packages), - artifact_repository, - self._artifact_repository_packages[artifact_repository], + artifact_repository=artifact_repository, + existing_packages_dict=self._get_packages_by_artifact_repository( + artifact_repository + ), ) def remove_package( @@ -1726,7 +1737,7 @@ def remove_package( artifact_repository = self._get_default_artifact_repository() with self._package_lock: - packages = self._artifact_repository_packages[artifact_repository] + packages = self._get_packages_by_artifact_repository(artifact_repository) if package_name in packages: packages.pop(package_name) else: @@ -1744,7 +1755,7 @@ def clear_packages( artifact_repository = self._get_default_artifact_repository() with self._package_lock: - self._artifact_repository_packages[artifact_repository].clear() + self._get_packages_by_artifact_repository(artifact_repository).clear() def add_requirements( self, @@ -2112,11 +2123,11 @@ def _get_req_identifiers_list( def _resolve_packages( self, packages: List[Union[str, ModuleType]], - artifact_repository: str, - existing_packages_dict: Dict[str, str], + existing_packages_dict: Dict[str, str] = None, validate_package: bool = True, include_pandas: bool = False, statement_params: Optional[Dict[str, str]] = None, + artifact_repository: str = None, **kwargs, ) -> List[str]: """ @@ -2134,6 +2145,13 @@ def _resolve_packages( Returns: List[str]: List of package specifiers """ + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + if existing_packages_dict is None: + existing_packages_dict = self._get_packages_by_artifact_repository( + artifact_repository + ) + # Always include cloudpickle extra_modules = [cloudpickle] if include_pandas: diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 02998f4a3a..f73bcb09fa 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -271,8 +271,8 @@ def extract_major_minor_patch(version_string): resolved_packages = session._resolve_packages( [numpy, pandas, dateutil], - _ANACONDA_SHARED_REPOSITORY, - {}, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + existing_packages_dict={}, validate_package=False, ) # resolved_packages is a list of strings like @@ -1204,17 +1204,10 @@ def test_replicate_local_environment(session): "force_push": True, } - assert not any( - [ - package.startswith("cloudpickle") - for package in session._artifact_repository_packages[ - _ANACONDA_SHARED_REPOSITORY - ] - ] - ) + assert not any([package.startswith("cloudpickle") for package in session._packages]) def naive_add_packages(self, packages): - self._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = packages + self._packages = packages with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True): with patch.object(Session, "add_packages", new=naive_add_packages): @@ -1228,22 +1221,10 @@ def naive_add_packages(self, packages): }, ) - assert any( - [ - package.startswith("cloudpickle==") - for package in session._artifact_repository_packages[ - _ANACONDA_SHARED_REPOSITORY - ] - ] - ) + assert any([package.startswith("cloudpickle==") for package in session._packages]) for default_package in DEFAULT_PACKAGES: assert not any( - [ - package.startswith(default_package) - for package in session._artifact_repository_packages[ - _ANACONDA_SHARED_REPOSITORY - ] - ] + [package.startswith(default_package) for package in session._packages] ) session.clear_packages() @@ -1262,29 +1243,12 @@ def naive_add_packages(self, packages): ignore_packages=ignored_packages, relax=True ) - assert any( - [ - package == "cloudpickle" - for package in session._artifact_repository_packages[ - _ANACONDA_SHARED_REPOSITORY - ] - ] - ) + assert any([package == "cloudpickle" for package in session._packages]) for default_package in DEFAULT_PACKAGES: assert not any( - [ - package.startswith(default_package) - for package in session._artifact_repository_packages[ - _ANACONDA_SHARED_REPOSITORY - ] - ] + [package.startswith(default_package) for package in session._packages] ) for ignored_package in ignored_packages: assert not any( - [ - package.startswith(ignored_package) - for package in session._artifact_repository_packages[ - _ANACONDA_SHARED_REPOSITORY - ] - ] + [package.startswith(ignored_package) for package in session._packages] ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 7a64f052d4..ed7cb44427 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -213,8 +213,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True session._resolve_packages( ["random_package_name"], - _ANACONDA_SHARED_REPOSITORY, - {}, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + existing_packages_dict={}, validate_package=True, include_pandas=False, ) @@ -248,8 +248,8 @@ def run_query(sql: str): ): session._resolve_packages( ["random_package_name"], - _ANACONDA_SHARED_REPOSITORY, - {}, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + existing_packages_dict={}, validate_package=True, include_pandas=False, ) @@ -273,8 +273,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True resolved_packages = session._resolve_packages( ["random_package_name"], - _ANACONDA_SHARED_REPOSITORY, - existing_packages, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + existing_packages_dict=existing_packages, validate_package=True, include_pandas=False, ) @@ -305,8 +305,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True ): session._resolve_packages( ["snowflake-snowpark-python"], - _ANACONDA_SHARED_REPOSITORY, - {}, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + existing_packages_dict={}, validate_package=True, include_pandas=False, _suppress_local_package_warnings=True, @@ -333,16 +333,16 @@ def assert_packages(packages): packages = session._resolve_packages( ["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"], - "snowflake.snowpark.pypi_shared_repository", - existing_packages, + artifact_repository="snowflake.snowpark.pypi_shared_repository", + existing_packages_dict=existing_packages, ) assert_packages(packages) packages = session._resolve_packages( [], - "snowflake.snowpark.pypi_shared_repository", - existing_packages, + artifact_repository="snowflake.snowpark.pypi_shared_repository", + existing_packages_dict=existing_packages, ) assert_packages(packages) diff --git a/tests/unit/test_stored_procedure.py b/tests/unit/test_stored_procedure.py index 2e6aff999a..eb9917f682 100644 --- a/tests/unit/test_stored_procedure.py +++ b/tests/unit/test_stored_procedure.py @@ -44,6 +44,7 @@ def test_stored_procedure_execute_as(execute_as): fake_session._analyzer = Analyzer(fake_session) fake_session._runtime_version_from_requirement = None fake_session._artifact_repository_packages = defaultdict(dict) + fake_session._packages = {} def return1(_): return 1 @@ -92,6 +93,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.sproc = StoredProcedureRegistration(fake_session) fake_session._artifact_repository_packages = defaultdict(dict) + fake_session._packages = {} with pytest.raises(SnowparkSQLException) as ex_info: sproc(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udaf.py b/tests/unit/test_udaf.py index 803f8dc886..464be0c856 100644 --- a/tests/unit/test_udaf.py +++ b/tests/unit/test_udaf.py @@ -59,6 +59,7 @@ def test_do_register_udaf_negative(cleanup_registration_patch): fake_session._runtime_version_from_requirement = None fake_session.udaf = UDAFRegistration(fake_session) fake_session._artifact_repository_packages = defaultdict(dict) + fake_session._packages = {} with pytest.raises(SnowparkSQLException) as ex_info: @udaf(session=fake_session) diff --git a/tests/unit/test_udf.py b/tests/unit/test_udf.py index d20c53d048..b3c70bfc70 100644 --- a/tests/unit/test_udf.py +++ b/tests/unit/test_udf.py @@ -34,6 +34,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.udf = UDFRegistration(fake_session) fake_session._artifact_repository_packages = defaultdict(dict) + fake_session._packages = {} with pytest.raises(SnowparkSQLException) as ex_info: udf(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index a301d8769e..0d319d6b1c 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -from collections import defaultdict import logging import os import pickle @@ -254,8 +253,7 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): def test_add_snowpark_package_to_sproc_packages_to_session(): fake_session = mock.create_autospec(Session) - fake_session._artifact_repository_packages = defaultdict(dict) - fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = { + fake_session._packages = { "random_package_one": "random_package_one", "random_package_two": "random_package_two", } @@ -272,7 +270,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): assert len(result) == 3 assert final_name in result - fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY][ + fake_session._packages[ "snowflake-snowpark-python" ] = "snowflake-snowpark-python==1.12.0" result = add_snowpark_package_to_sproc_packages( diff --git a/tests/unit/test_udtf.py b/tests/unit/test_udtf.py index ba986364ed..e6f3585641 100644 --- a/tests/unit/test_udtf.py +++ b/tests/unit/test_udtf.py @@ -42,6 +42,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): fake_session._runtime_version_from_requirement = None fake_session.udtf = UDTFRegistration(fake_session) fake_session._artifact_repository_packages = defaultdict(dict) + fake_session._packages = {} with pytest.raises(SnowparkSQLException) as ex_info: @udtf(output_schema=["num"], session=fake_session) From 51ed81777d9e542da35d2f6984567d355bfd0ff1 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 24 Feb 2026 10:11:04 -0800 Subject: [PATCH 3/5] fix ut --- tests/unit/test_udf_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 0d319d6b1c..32a1626eb9 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -258,6 +258,9 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): "random_package_two": "random_package_two", } fake_session._package_lock = threading.RLock() + fake_session._get_packages_by_artifact_repository.side_effect = ( + lambda a: Session._get_packages_by_artifact_repository(fake_session, a) + ) result = add_snowpark_package_to_sproc_packages( session=fake_session, packages=None, From c8061b46d5654fd31f875af81c27cc0107366a8a Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 24 Feb 2026 11:08:42 -0800 Subject: [PATCH 4/5] entity selector args --- src/snowflake/snowpark/session.py | 14 ++++++++++++-- tests/integ/test_udf.py | 30 ++++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5352c27c32..80f7a292d4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2424,7 +2424,10 @@ def _get_default_artifact_repository(self) -> str: if isinstance(self._conn, MockServerConnection): return _DEFAULT_ARTIFACT_REPOSITORY - cache_key = (self.get_current_database(), self.get_current_schema()) + account = self.get_current_account() + database = self.get_current_database() + schema = self.get_current_schema() + cache_key = (database, schema) if ( self._default_artifact_repository_cache is not None @@ -2434,8 +2437,15 @@ def _get_default_artifact_repository(self) -> str: try: python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + entity_selector_args = ( + f"'schema', '{schema}'" + if schema + else f"'database', '{database}'" + if database + else f"'account', '{account}'" + ) result = self._run_query( - f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" + f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}', {entity_selector_args})" ) value = result[0][0] if result else None resolved = value or _DEFAULT_ARTIFACT_REPOSITORY diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 0499a01a3d..d9029efd23 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3108,18 +3108,17 @@ def test_urllib() -> str: ) def test_use_default_artifact_repository(db_parameters): with Session.builder.configs(db_parameters).create() as session: + temp_database = Utils.random_temp_database() temp_schema = Utils.random_temp_schema() - session.sql(f"create schema {temp_schema}").collect() - session.sql(f"use schema {temp_schema}").collect() + session.sql(f"create database {temp_database}").collect() + session.sql(f"use database {temp_database}").collect() session.sql( "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" ).collect() session.sql( - "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY" + "ALTER database set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.anaconda_shared_repository" ).collect() - session.add_packages("art", "cloudpickle") - def test_art() -> str: import art # art is not available in the conda channel, but is in pypi @@ -3128,6 +3127,25 @@ def test_art() -> str: temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION) + # should not work in the database where the default is anaconda + with pytest.raises( + Exception, + match="Cannot add package art because it is not available in Snowflake", + ): + udf( + session=session, + func=test_art, + name=temp_func_name, + packages=["art", "cloudpickle"], + ) + + session.sql(f"create schema {temp_schema}").collect() + session.use_schema(temp_schema) + session.sql( + "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY" + ).collect() + session.add_packages("art", "cloudpickle") + try: # Test function registration udf( @@ -3142,4 +3160,4 @@ def test_art() -> str: finally: session._run_query(f"drop function if exists {temp_func_name}(int)") - session.sql(f"drop schema {temp_schema}").collect() + session.sql(f"drop database {temp_database}").collect() From 1bdf013a091920ffa8c5467fbba60c8e47b164e5 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 24 Feb 2026 13:13:17 -0800 Subject: [PATCH 5/5] test optional --- tests/unit/test_session.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index ed7cb44427..69cf7b6ac2 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -348,6 +348,26 @@ def assert_packages(packages): assert_packages(packages) +def test_resolve_packages_optional_artifact_repository(mock_server_connection): + session = Session(mock_server_connection) + session._get_default_artifact_repository = MagicMock( + return_value="snowflake.snowpark.pypi_shared_repository" + ) + session._artifact_repository_packages = { + "snowflake.snowpark.pypi_shared_repository": { + "numpy": "numpy==1.0.0", + } + } + result = session._resolve_packages( + ["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"], + ) + assert sorted(result) == [ + "cloudpickle==1.0.0", + "numpy==1.0.0", + "snowflake-snowpark-python==1.0.0", + ] + + @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas") def test_write_pandas_wrong_table_type(mock_server_connection): session = Session(mock_server_connection)