From 8aac920b5c55dd51855c84427c9c0d0035199680 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Thu, 5 Feb 2026 11:16:54 -0800 Subject: [PATCH 01/20] initial implementation --- src/snowflake/snowpark/_internal/udf_utils.py | 39 ++++++++++-- src/snowflake/snowpark/session.py | 59 +++++++++++++++---- tests/unit/test_session.py | 36 ++++++++++- 3 files changed, 114 insertions(+), 20 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index b6139fde7c..062c69e6f7 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1222,21 +1222,45 @@ def resolve_imports_and_packages( Optional[str], bool, ]: - if artifact_repository and artifact_repository != "conda": - # Artifact Repository packages are not resolved + from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY + + use_default_artifact_repository = artifact_repository is None + if use_default_artifact_repository: + artifact_repository = ( + session._get_default_artifact_repository() + if session is not None + else ANACONDA_SHARED_REPOSITORY + ) + + # TODO: if the user explicitly passes in the current default, should we use self._packages? 
+ # note that the current default could change after calling session.add_packages, so it's hard + # to know what the intended default is + existing_packages_dict = {} + if session: + existing_packages_dict = ( + session._packages + if use_default_artifact_repository + else session._artifact_repository_packages[artifact_repository] + ) + + if artifact_repository != ANACONDA_SHARED_REPOSITORY: + # Non-conda artifact repository - skip conda-based package resolution resolved_packages = [] if not packages and session: resolved_packages = list( - session._resolve_packages([], artifact_repository=artifact_repository) + session._resolve_packages( + [], artifact_repository, existing_packages_dict + ) ) elif packages: if not all(isinstance(package, str) for package in packages): raise TypeError( - "Artifact repository requires that all packages be passed as str." + "Non-conda artifact repository requires that all packages be passed as str." ) + # TODO: this will not automatically add required packages like cloudpickle, is that ok? 
resolved_packages = packages else: - # resolve packages + # resolve packages using conda channel if session is None: # In case of sandbox resolved_packages = resolve_packages_in_client_side_sandbox( packages=packages @@ -1245,6 +1269,8 @@ def resolve_imports_and_packages( resolved_packages = ( session._resolve_packages( packages, + artifact_repository, + {}, # ignore session packages if passed in explicitly include_pandas=is_pandas_udf, statement_params=statement_params, _suppress_local_package_warnings=_suppress_local_package_warnings, @@ -1252,7 +1278,8 @@ def resolve_imports_and_packages( if packages is not None else session._resolve_packages( [], - session._packages, + artifact_repository, + existing_packages_dict, validate_package=False, include_pandas=is_pandas_udf, statement_params=statement_params, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 99293569ff..e9d50c2fa6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -321,6 +321,10 @@ WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None WRITE_ARROW_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None +# The fully qualified name of the Anaconda shared repository (conda channel). +# Used as the fallback/default when the system function is unavailable or returns NULL. +ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" + def _get_active_session() -> "Session": with _session_management_lock: @@ -599,7 +603,10 @@ def __init__( self._conn = conn self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} + # packages that should be added under the default artifact repository + # TODO: now that we have dynamic defaults, should we remove this and just use _artifact_repository_packages always? 
self._packages: Dict[str, str] = {} + # packages that should be added under an explicit artifact repository self._artifact_repository_packages: DefaultDict[ str, Dict[str, str] ] = defaultdict(dict) @@ -1669,10 +1676,20 @@ def add_packages( to ensure the consistent experience of a UDF between your local environment and the Snowflake server. """ + use_default_artifact_repository = artifact_repository is None + if use_default_artifact_repository: + artifact_repository = self._get_default_artifact_repository() + + existing_packages_dict = ( + self._packages + if use_default_artifact_repository + else self._artifact_repository_packages[artifact_repository] + ) + self._resolve_packages( parse_positional_args_to_list(*packages), - self._packages, - artifact_repository=artifact_repository, + existing_packages_dict, + artifact_repository, ) def remove_package( @@ -2097,11 +2114,11 @@ def _get_req_identifiers_list( def _resolve_packages( self, packages: List[Union[str, ModuleType]], - existing_packages_dict: Optional[Dict[str, str]] = None, + artifact_repository: str, + existing_packages_dict: Dict[str, str], validate_package: bool = True, include_pandas: bool = False, statement_params: Optional[Dict[str, str]] = None, - artifact_repository: Optional[str] = None, **kwargs, ) -> List[str]: """ @@ -2128,18 +2145,12 @@ def _resolve_packages( package_dict = self._parse_packages(packages) if ( isinstance(self._conn, MockServerConnection) - or artifact_repository is not None + or artifact_repository != ANACONDA_SHARED_REPOSITORY ): - # in local testing we don't resolve the packages, we just return what is added + # in local testing or non-conda, we don't resolve the packages, we just return what is added errors = [] with self._package_lock: - if artifact_repository is None: - result_dict = self._packages - else: - result_dict = self._artifact_repository_packages[ - artifact_repository - ] - + result_dict = existing_packages_dict # assumption: packages is empty for pkg_name, _, 
pkg_req in package_dict.values(): if ( pkg_name in result_dict @@ -2377,6 +2388,28 @@ def _upload_unsupported_packages( return supported_dependencies + new_dependencies + def _get_default_artifact_repository(self) -> str: + """ + Returns the default artifact repository for the current session context + by calling SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY. + + Falls back to the Anaconda shared repository (conda) if: + - the system function is not available / fails, or + - the system function returns NULL (value was never set). + """ + if isinstance(self._conn, MockServerConnection): + return ANACONDA_SHARED_REPOSITORY + + try: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + result = self._run_query( + f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" + ) + value = result[0][0] if result else None + return value or ANACONDA_SHARED_REPOSITORY + except Exception: + return ANACONDA_SHARED_REPOSITORY + def _is_anaconda_terms_acknowledged(self) -> bool: return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 20651499d3..9ecf5b4fe5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -28,7 +28,10 @@ SnowparkInvalidObjectNameException, SnowparkSessionException, ) -from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING +from snowflake.snowpark.session import ( + _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, + ANACONDA_SHARED_REPOSITORY, +) from snowflake.snowpark.types import StructField, StructType @@ -674,3 +677,34 @@ def test_parameter_version(version_value, expected_parameter_value, parameter_na ) session = Session(fake_server_connection) assert getattr(session, parameter_name, None) is expected_parameter_value + + +def test_get_default_artifact_repository(): + fake_server_connection = mock.create_autospec(ServerConnection) + 
fake_server_connection._thread_safe_session_enabled = True + session = Session(fake_server_connection) + + with mock.patch.object( + session, + "_run_query", + return_value=[["snowflake.snowpark.pypi_shared_repository"]], + ): + result = session._get_default_artifact_repository() + assert result == "snowflake.snowpark.pypi_shared_repository" + + with mock.patch.object( + session, + "_run_query", + return_value=[[None]], + ): + result = session._get_default_artifact_repository() + assert result == ANACONDA_SHARED_REPOSITORY + + # throws error + with mock.patch.object( + session, + "_run_query", + side_effect=ProgrammingError("Function not found"), + ): + result = session._get_default_artifact_repository() + assert result == ANACONDA_SHARED_REPOSITORY From e481ef85f1e958af4e59ab064c3ce45c2a346565 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Thu, 5 Feb 2026 11:36:11 -0800 Subject: [PATCH 02/20] remove cloudpickle todo --- src/snowflake/snowpark/_internal/udf_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 062c69e6f7..03107a44e7 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1257,7 +1257,6 @@ def resolve_imports_and_packages( raise TypeError( "Non-conda artifact repository requires that all packages be passed as str." ) - # TODO: this will not automatically add required packages like cloudpickle, is that ok? 
resolved_packages = packages else: # resolve packages using conda channel From 8e3950f30723b3af3de7555a096bd945f6c5a9c8 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Thu, 5 Feb 2026 12:03:02 -0800 Subject: [PATCH 03/20] simple int test --- tests/integ/conftest.py | 4 ++++ tests/integ/test_udf.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 9df63dcc92..4166dd1eb9 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -357,6 +357,10 @@ def session( "alter session set ENABLE_EXTRACTION_PUSHDOWN_EXTERNAL_PARQUET_FOR_COPY_PHASE_I='Track';" ).collect() session.sql("alter session set ENABLE_ROW_ACCESS_POLICY=true").collect() + # TODO: remove + session.sql( + "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" + ).collect() try: yield session diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 7e1c07b456..7b460a0859 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3009,3 +3009,40 @@ def test_urllib() -> str: ) df = session.create_dataframe([1]).to_df(["a"]) Utils.check_answer(df.select(ar_udf()), [Row("test")]) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="artifact repository not supported in local testing", +) +# @pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" +) +def test_use_default_artifact_repository(session): + session.sql( + "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository" + ).collect() + + def test_art() -> str: + import art # art is not available in the conda channel, but is in pypi + + return "art works!" if art.text2art("test") else "art does not work!" 
+ + temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION) + + try: + # Test function registration + udf( + func=test_art, + name=temp_func_name, + packages=["art", "cloudpickle"], + ) + + # Test UDF call + df = session.create_dataframe([1]).to_df(["a"]) + Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("art works!")]) + finally: + session._run_query(f"drop function if exists {temp_func_name}(int)") + + session.sql("ALTER schema unset DEFAULT_PYTHON_ARTIFACT_REPOSITORY").collect() From 62cf6a03d0eff4a4a333fc2c8aa202a6a913f521 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Thu, 5 Feb 2026 13:08:55 -0800 Subject: [PATCH 04/20] update tests + fix bug --- src/snowflake/snowpark/session.py | 2 +- tests/integ/test_udf.py | 7 +++++-- tests/unit/test_session.py | 17 ++++++++++++++--- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index e9d50c2fa6..62bab4277a 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1688,8 +1688,8 @@ def add_packages( self._resolve_packages( parse_positional_args_to_list(*packages), - existing_packages_dict, artifact_repository, + existing_packages_dict, ) def remove_package( diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 7b460a0859..484cec8a53 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3020,14 +3020,18 @@ def test_urllib() -> str: sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" ) def test_use_default_artifact_repository(session): + # TODO: is this safe with parallel testing? session.sql( "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository" ).collect() + session.add_packages("art", "cloudpickle") + def test_art() -> str: import art # art is not available in the conda channel, but is in pypi - return "art works!" if art.text2art("test") else "art does not work!" 
+ _ = art.text2art("test") + return "art works!" temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION) @@ -3036,7 +3040,6 @@ def test_art() -> str: udf( func=test_art, name=temp_func_name, - packages=["art", "cloudpickle"], ) # Test UDF call diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9ecf5b4fe5..a4cda2afe2 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -214,7 +214,11 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True session.table.side_effect = mock_get_information_schema_packages session._resolve_packages( - ["random_package_name"], validate_package=True, include_pandas=False + ["random_package_name"], + ANACONDA_SHARED_REPOSITORY, + {}, + validate_package=True, + include_pandas=False, ) @@ -245,7 +249,11 @@ def run_query(sql: str): "#using-third-party-packages-from-anaconda.", ): session._resolve_packages( - ["random_package_name"], validate_package=True, include_pandas=False + ["random_package_name"], + ANACONDA_SHARED_REPOSITORY, + {}, + validate_package=True, + include_pandas=False, ) @@ -267,7 +275,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True resolved_packages = session._resolve_packages( ["random_package_name"], - existing_packages_dict=existing_packages, + ANACONDA_SHARED_REPOSITORY, + existing_packages, validate_package=True, include_pandas=False, ) @@ -298,6 +307,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True ): session._resolve_packages( ["snowflake-snowpark-python"], + ANACONDA_SHARED_REPOSITORY, + {}, validate_package=True, include_pandas=False, _suppress_local_package_warnings=True, From e8dd42dff0ded1badabc683a6db905a07ae55cf0 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Thu, 5 Feb 2026 13:21:00 -0800 Subject: [PATCH 05/20] non conda resolve test --- src/snowflake/snowpark/session.py | 2 +- tests/unit/test_session.py | 32 +++++++++++++++++++++++++++++++ 2 
files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 62bab4277a..8d9e9e842b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2150,7 +2150,7 @@ def _resolve_packages( # in local testing or non-conda, we don't resolve the packages, we just return what is added errors = [] with self._package_lock: - result_dict = existing_packages_dict # assumption: packages is empty + result_dict = existing_packages_dict for pkg_name, _, pkg_req in package_dict.values(): if ( pkg_name in result_dict diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index a4cda2afe2..14743ca966 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -318,6 +318,38 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True assert caplog.text == "" +def test_resolve_packages_non_conda_artifact_repository(mock_server_connection): + session = Session(mock_server_connection) + + existing_packages = {} + + def assert_packages(packages): + assert sorted(packages) == [ + "cloudpickle==1.0.0", + "snowflake-snowpark-python==1.0.0", + ] + assert existing_packages == { + "snowflake-snowpark-python": "snowflake-snowpark-python==1.0.0", + "cloudpickle": "cloudpickle==1.0.0", + } + + packages = session._resolve_packages( + ["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"], + "snowflake.snowpark.pypi_shared_repository", + existing_packages, + ) + + assert_packages(packages) + + packages = session._resolve_packages( + [], + "snowflake.snowpark.pypi_shared_repository", + existing_packages, + ) + + assert_packages(packages) + + @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas") def test_write_pandas_wrong_table_type(mock_server_connection): session = Session(mock_server_connection) From 97ce30dc771c298b2db62301922e0879feedd1c3 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Thu, 5 Feb 2026 
17:52:43 -0800 Subject: [PATCH 06/20] remove _packages, merge with _artifact_repository_packages --- src/snowflake/snowpark/_internal/udf_utils.py | 19 +++---- src/snowflake/snowpark/session.py | 48 +++++++---------- src/snowflake/snowpark/stored_procedure.py | 4 ++ tests/integ/test_packaging.py | 51 ++++++++++++++++--- tests/unit/test_stored_procedure.py | 2 - tests/unit/test_udaf.py | 1 - tests/unit/test_udf_utils.py | 5 +- tests/unit/test_udtf.py | 1 - 8 files changed, 75 insertions(+), 56 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 03107a44e7..5907e928b9 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1121,6 +1121,7 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}): def add_snowpark_package_to_sproc_packages( session: Optional["snowflake.snowpark.Session"], packages: Optional[List[Union[str, ModuleType]]], + artifact_repository: Optional[str], ) -> List[Union[str, ModuleType]]: major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -1224,24 +1225,16 @@ def resolve_imports_and_packages( ]: from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY - use_default_artifact_repository = artifact_repository is None - if use_default_artifact_repository: + if artifact_repository is None: artifact_repository = ( session._get_default_artifact_repository() - if session is not None + if session else ANACONDA_SHARED_REPOSITORY ) - # TODO: if the user explicitly passes in the current default, should we use self._packages? 
- # note that the current default could change after calling session.add_packages, so it's hard - # to know what the intended default is - existing_packages_dict = {} - if session: - existing_packages_dict = ( - session._packages - if use_default_artifact_repository - else session._artifact_repository_packages[artifact_repository] - ) + existing_packages_dict = ( + session._artifact_repository_packages[artifact_repository] if session else {} + ) if artifact_repository != ANACONDA_SHARED_REPOSITORY: # Non-conda artifact repository - skip conda-based package resolution diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 8d9e9e842b..e824b432ad 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -603,10 +603,7 @@ def __init__( self._conn = conn self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} - # packages that should be added under the default artifact repository - # TODO: now that we have dynamic defaults, should we remove this and just use _artifact_repository_packages always? - self._packages: Dict[str, str] = {} - # packages that should be added under an explicit artifact repository + # map of artifact repository name -> packages that should be added to functions under that repository self._artifact_repository_packages: DefaultDict[ str, Dict[str, str] ] = defaultdict(dict) @@ -1605,11 +1602,13 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s Args: artifact_repository: When set this will function will return the packages for a specific artifact repository. + Otherwise, uses the default artifact repository configured in the current context. 
""" + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + with self._package_lock: - if artifact_repository: - return self._artifact_repository_packages[artifact_repository].copy() - return self._packages.copy() + return self._artifact_repository_packages[artifact_repository].copy() def add_packages( self, @@ -1676,20 +1675,13 @@ def add_packages( to ensure the consistent experience of a UDF between your local environment and the Snowflake server. """ - use_default_artifact_repository = artifact_repository is None - if use_default_artifact_repository: + if artifact_repository is None: artifact_repository = self._get_default_artifact_repository() - existing_packages_dict = ( - self._packages - if use_default_artifact_repository - else self._artifact_repository_packages[artifact_repository] - ) - self._resolve_packages( parse_positional_args_to_list(*packages), artifact_repository, - existing_packages_dict, + self._artifact_repository_packages[artifact_repository], ) def remove_package( @@ -1721,17 +1713,13 @@ def remove_package( 0 """ package_name = Requirement(package).name + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + with self._package_lock: - if ( - artifact_repository is not None - and package_name - in self._artifact_repository_packages.get(artifact_repository, {}) - ): - self._artifact_repository_packages[artifact_repository].pop( - package_name - ) - elif package_name in self._packages: - self._packages.pop(package_name) + packages = self._artifact_repository_packages[artifact_repository] + if package_name in packages: + packages.pop(package_name) else: raise ValueError(f"{package_name} is not in the package list") @@ -1743,11 +1731,11 @@ def clear_packages( Clears all third-party packages of a user-defined function (UDF). When artifact_repository is set packages are only clear from the specified repository. 
""" + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + with self._package_lock: - if artifact_repository is not None: - self._artifact_repository_packages.get(artifact_repository, {}).clear() - else: - self._packages.clear() + self._artifact_repository_packages[artifact_repository].clear() def add_requirements( self, diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index 5a1a6c81d8..cd5abc1eb1 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -939,10 +939,14 @@ def _do_register_sp( UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names[1:]) ] + if artifact_repository is None: + artifact_repository = self._session._get_default_artifact_repository() + # Add in snowflake-snowpark-python if it is not already in the package list. packages = add_snowpark_package_to_sproc_packages( session=self._session, packages=packages, + artifact_repository=artifact_repository, ) ( diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 64b8fee018..c1c4d12cb4 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -20,6 +20,7 @@ get_signature, ) from snowflake.snowpark.functions import call_udf, col, count_distinct, sproc, udf +from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import DateType, StringType from tests.utils import IS_IN_STORED_PROC, TempObjectType, TestFiles, Utils @@ -1200,10 +1201,17 @@ def test_replicate_local_environment(session): "force_push": True, } - assert not any([package.startswith("cloudpickle") for package in session._packages]) + assert not any( + [ + package.startswith("cloudpickle") + for package in session._artifact_repository_packages[ + ANACONDA_SHARED_REPOSITORY + ] + ] + ) def naive_add_packages(self, packages): - self._packages = packages + 
self._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY] = packages with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True): with patch.object(Session, "add_packages", new=naive_add_packages): @@ -1217,10 +1225,22 @@ def naive_add_packages(self, packages): }, ) - assert any([package.startswith("cloudpickle==") for package in session._packages]) + assert any( + [ + package.startswith("cloudpickle==") + for package in session._artifact_repository_packages[ + ANACONDA_SHARED_REPOSITORY + ] + ] + ) for default_package in DEFAULT_PACKAGES: assert not any( - [package.startswith(default_package) for package in session._packages] + [ + package.startswith(default_package) + for package in session._artifact_repository_packages[ + ANACONDA_SHARED_REPOSITORY + ] + ] ) session.clear_packages() @@ -1239,12 +1259,29 @@ def naive_add_packages(self, packages): ignore_packages=ignored_packages, relax=True ) - assert any([package == "cloudpickle" for package in session._packages]) + assert any( + [ + package == "cloudpickle" + for package in session._artifact_repository_packages[ + ANACONDA_SHARED_REPOSITORY + ] + ] + ) for default_package in DEFAULT_PACKAGES: assert not any( - [package.startswith(default_package) for package in session._packages] + [ + package.startswith(default_package) + for package in session._artifact_repository_packages[ + ANACONDA_SHARED_REPOSITORY + ] + ] ) for ignored_package in ignored_packages: assert not any( - [package.startswith(ignored_package) for package in session._packages] + [ + package.startswith(ignored_package) + for package in session._artifact_repository_packages[ + ANACONDA_SHARED_REPOSITORY + ] + ] ) diff --git a/tests/unit/test_stored_procedure.py b/tests/unit/test_stored_procedure.py index 638bae84d5..f8db1e5026 100644 --- a/tests/unit/test_stored_procedure.py +++ b/tests/unit/test_stored_procedure.py @@ -42,7 +42,6 @@ def test_stored_procedure_execute_as(execute_as): fake_session._plan_builder = 
SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session) fake_session._runtime_version_from_requirement = None - fake_session._packages = {} def return1(_): return 1 @@ -90,7 +89,6 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.sproc = StoredProcedureRegistration(fake_session) - fake_session._packages = {} with pytest.raises(SnowparkSQLException) as ex_info: sproc(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udaf.py b/tests/unit/test_udaf.py index 20da11c7f4..702bfb713c 100644 --- a/tests/unit/test_udaf.py +++ b/tests/unit/test_udaf.py @@ -56,7 +56,6 @@ def test_do_register_udaf_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session._runtime_version_from_requirement = None - fake_session._packages = [] fake_session.udaf = UDAFRegistration(fake_session) with pytest.raises(SnowparkSQLException) as ex_info: diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 623738c567..6df775164f 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -25,6 +25,7 @@ resolve_packages_in_client_side_sandbox, ) from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import StringType from snowflake.snowpark.version import VERSION @@ -248,7 +249,7 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): def test_add_snowpark_package_to_sproc_packages_to_session(): fake_session = mock.create_autospec(Session) - fake_session._packages = { + fake_session._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY] = { "random_package_one": "random_package_one", "random_package_two": "random_package_two", } @@ -261,7 +262,7 @@ def 
test_add_snowpark_package_to_sproc_packages_to_session(): assert len(result) == 3 assert final_name in result - fake_session._packages[ + fake_session._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY][ "snowflake-snowpark-python" ] = "snowflake-snowpark-python==1.12.0" result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) diff --git a/tests/unit/test_udtf.py b/tests/unit/test_udtf.py index 794fa82b8c..7b6f32cc70 100644 --- a/tests/unit/test_udtf.py +++ b/tests/unit/test_udtf.py @@ -39,7 +39,6 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session._runtime_version_from_requirement = None - fake_session._packages = [] fake_session.udtf = UDTFRegistration(fake_session) with pytest.raises(SnowparkSQLException) as ex_info: From 24a01855a191a85ef622a8f50aee6dcfb7ac2be2 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 10:01:29 -0800 Subject: [PATCH 07/20] remove more _packages --- src/snowflake/snowpark/_internal/udf_utils.py | 9 ++++++--- src/snowflake/snowpark/stored_procedure.py | 11 +++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 5907e928b9..77f8ad1af2 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1121,7 +1121,7 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}): def add_snowpark_package_to_sproc_packages( session: Optional["snowflake.snowpark.Session"], packages: Optional[List[Union[str, ModuleType]]], - artifact_repository: Optional[str], + artifact_repository: str, ) -> List[Union[str, ModuleType]]: major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -1137,8 +1137,11 @@ def add_snowpark_package_to_sproc_packages( packages = [this_package] else: with session._package_lock: - if package_name not in 
session._packages: - packages = list(session._packages.values()) + [this_package] + existing_packages = session._artifact_repository_packages[ + artifact_repository + ] + if package_name not in existing_packages: + packages = list(existing_packages.values()) + [this_package] return packages return add_package_to_existing_packages(packages, package_name, this_package) diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index cd5abc1eb1..1f207e1df9 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -939,14 +939,17 @@ def _do_register_sp( UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names[1:]) ] - if artifact_repository is None: - artifact_repository = self._session._get_default_artifact_repository() + effective_artifact_repository = artifact_repository + if effective_artifact_repository is None: + effective_artifact_repository = ( + self._session._get_default_artifact_repository() + ) # Add in snowflake-snowpark-python if it is not already in the package list. 
packages = add_snowpark_package_to_sproc_packages( session=self._session, packages=packages, - artifact_repository=artifact_repository, + artifact_repository=effective_artifact_repository, ) ( @@ -971,7 +974,7 @@ def _do_register_sp( skip_upload_on_content_match=skip_upload_on_content_match, is_permanent=is_permanent, force_inline_code=force_inline_code, - artifact_repository=artifact_repository, + artifact_repository=effective_artifact_repository, _suppress_local_package_warnings=kwargs.get( "_suppress_local_package_warnings", False ), From 07cc8b930da4e6208dbbbb77aaac840d385453e4 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 10:44:12 -0800 Subject: [PATCH 08/20] add default cache --- src/snowflake/snowpark/session.py | 30 ++++++++++++++++++++++++++++-- tests/unit/test_session.py | 30 +++++++++++++++++++++++------- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index e824b432ad..8d515e57b8 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -607,6 +607,13 @@ def __init__( self._artifact_repository_packages: DefaultDict[ str, Dict[str, str] ] = defaultdict(dict) + # Single-entry cache for the default artifact repository value. + # Stores a tuple of ((database, schema), cached_value). Only one entry is + # kept at a time – switching to a different database/schema evicts the old + # value and triggers a fresh query on the next call. + self._default_artifact_repository_cache: Optional[ + Tuple[Tuple[Optional[str], Optional[str]], str] + ] = None self._session_id = self._conn.get_session_id() self._session_info = f""" "version" : {get_version()}, @@ -2381,22 +2388,41 @@ def _get_default_artifact_repository(self) -> str: Returns the default artifact repository for the current session context by calling SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY. 
+ The result is cached per (database, schema) pair so that + repeated invocations in the same context do not issue + redundant system-function queries. Only one cache entry is kept at + a time; switching to a different database or schema evicts the + previous entry and triggers a fresh query on the next call. + Falls back to the Anaconda shared repository (conda) if: + - the session uses a mock connection (local testing), or - the system function is not available / fails, or - the system function returns NULL (value was never set). """ if isinstance(self._conn, MockServerConnection): return ANACONDA_SHARED_REPOSITORY + cache_key = (self.get_current_database(), self.get_current_schema()) + + if ( + self._default_artifact_repository_cache is not None + and self._default_artifact_repository_cache[0] == cache_key + ): + return self._default_artifact_repository_cache[1] + try: python_version = f"{sys.version_info.major}.{sys.version_info.minor}" result = self._run_query( f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" ) value = result[0][0] if result else None - return value or ANACONDA_SHARED_REPOSITORY + resolved = value or ANACONDA_SHARED_REPOSITORY except Exception: - return ANACONDA_SHARED_REPOSITORY + resolved = ANACONDA_SHARED_REPOSITORY + + with self._package_lock: + self._default_artifact_repository_cache = (cache_key, resolved) + return resolved def _is_anaconda_terms_acknowledged(self) -> bool: return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 14743ca966..6180822507 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -731,23 +731,39 @@ def test_get_default_artifact_repository(): session, "_run_query", return_value=[["snowflake.snowpark.pypi_shared_repository"]], + ) as mocked_run_query, mock.patch.object( + session, "get_current_database", return_value="DB1" + ), mock.patch.object( + session, 
"get_current_schema", return_value="SCHEMA1" ): result = session._get_default_artifact_repository() assert result == "snowflake.snowpark.pypi_shared_repository" + result = session._get_default_artifact_repository() + assert result == "snowflake.snowpark.pypi_shared_repository" + + assert mocked_run_query.call_count == 1 + with mock.patch.object( - session, - "_run_query", - return_value=[[None]], + session, "_run_query", return_value=[[None]] + ) as mocked_run_query, mock.patch.object( + session, "get_current_database", return_value="DB2" + ), mock.patch.object( + session, "get_current_schema", return_value="SCHEMA2" ): result = session._get_default_artifact_repository() assert result == ANACONDA_SHARED_REPOSITORY - # throws error + assert mocked_run_query.call_count == 1 + with mock.patch.object( - session, - "_run_query", - side_effect=ProgrammingError("Function not found"), + session, "_run_query", side_effect=ProgrammingError("Not found") + ) as mocked_run_query, mock.patch.object( + session, "get_current_database", return_value="DB1" + ), mock.patch.object( + session, "get_current_schema", return_value="SCHEMA1" ): result = session._get_default_artifact_repository() assert result == ANACONDA_SHARED_REPOSITORY + + assert mocked_run_query.call_count == 1 From bde687f2f1c3ac52fdae00788a7d90ffb9a63d7c Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 11:07:12 -0800 Subject: [PATCH 09/20] fix UT --- src/snowflake/snowpark/stored_procedure.py | 4 ++++ tests/unit/test_stored_procedure.py | 3 +++ tests/unit/test_udaf.py | 2 ++ tests/unit/test_udf.py | 2 ++ tests/unit/test_udf_utils.py | 20 ++++++++++++++++---- tests/unit/test_udtf.py | 2 ++ 6 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index 1f207e1df9..a42f739beb 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -941,8 +941,12 @@ def 
_do_register_sp( effective_artifact_repository = artifact_repository if effective_artifact_repository is None: + from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY + effective_artifact_repository = ( self._session._get_default_artifact_repository() + if self._session + else ANACONDA_SHARED_REPOSITORY ) # Add in snowflake-snowpark-python if it is not already in the package list. diff --git a/tests/unit/test_stored_procedure.py b/tests/unit/test_stored_procedure.py index f8db1e5026..2e6aff999a 100644 --- a/tests/unit/test_stored_procedure.py +++ b/tests/unit/test_stored_procedure.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import sys from unittest import mock @@ -42,6 +43,7 @@ def test_stored_procedure_execute_as(execute_as): fake_session._plan_builder = SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session) fake_session._runtime_version_from_requirement = None + fake_session._artifact_repository_packages = defaultdict(dict) def return1(_): return 1 @@ -89,6 +91,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.sproc = StoredProcedureRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: sproc(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udaf.py b/tests/unit/test_udaf.py index 702bfb713c..803f8dc886 100644 --- a/tests/unit/test_udaf.py +++ b/tests/unit/test_udaf.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +from collections import defaultdict import sys from unittest import mock @@ -57,6 +58,7 @@ def test_do_register_udaf_negative(cleanup_registration_patch): fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session._runtime_version_from_requirement = None fake_session.udaf = UDAFRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: @udaf(session=fake_session) diff --git a/tests/unit/test_udf.py b/tests/unit/test_udf.py index 469cde1a7c..030e233a1a 100644 --- a/tests/unit/test_udf.py +++ b/tests/unit/test_udf.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import sys from unittest import mock @@ -32,6 +33,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.udf = UDFRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: udf(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 6df775164f..fb4f8a14db 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -225,7 +225,9 @@ def test_resolve_imports_and_packages_imports_as_str(tmp_path_factory): ) def test_add_snowpark_package_to_sproc_packages_add_package(packages): old_packages_length = len(packages) if packages else 0 - result = add_snowpark_package_to_sproc_packages(session=None, packages=packages) + result = add_snowpark_package_to_sproc_packages( + session=None, packages=packages, artifact_repository=ANACONDA_SHARED_REPOSITORY + ) major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -241,7 +243,9 @@ def 
test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): "random_package_two", "snowflake-snowpark-python==1.12.0", ] - result = add_snowpark_package_to_sproc_packages(session=None, packages=packages) + result = add_snowpark_package_to_sproc_packages( + session=None, packages=packages, artifact_repository=ANACONDA_SHARED_REPOSITORY + ) assert len(result) == len(packages) assert "snowflake-snowpark-python==1.12.0" in result @@ -254,7 +258,11 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): "random_package_two": "random_package_two", } fake_session._package_lock = threading.RLock() - result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) + result = add_snowpark_package_to_sproc_packages( + session=fake_session, + packages=None, + artifact_repository=ANACONDA_SHARED_REPOSITORY, + ) major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -265,7 +273,11 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): fake_session._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY][ "snowflake-snowpark-python" ] = "snowflake-snowpark-python==1.12.0" - result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) + result = add_snowpark_package_to_sproc_packages( + session=fake_session, + packages=None, + artifact_repository=ANACONDA_SHARED_REPOSITORY, + ) assert result is None diff --git a/tests/unit/test_udtf.py b/tests/unit/test_udtf.py index 7b6f32cc70..ba986364ed 100644 --- a/tests/unit/test_udtf.py +++ b/tests/unit/test_udtf.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +from collections import defaultdict import sys from typing import Tuple from unittest import mock @@ -40,6 +41,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session._runtime_version_from_requirement = None fake_session.udtf = UDTFRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: @udtf(output_schema=["num"], session=fake_session) From 97ef1aafbe76283405d17e35f474323f76dfa79c Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 11:33:34 -0800 Subject: [PATCH 10/20] doc udpates --- src/snowflake/snowpark/session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 8d515e57b8..7310f8cfa9 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -322,7 +322,6 @@ WRITE_ARROW_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None # The fully qualified name of the Anaconda shared repository (conda channel). -# Used as the fallback/default when the system function is unavailable or returns NULL. ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" @@ -2398,6 +2397,9 @@ def _get_default_artifact_repository(self) -> str: - the session uses a mock connection (local testing), or - the system function is not available / fails, or - the system function returns NULL (value was never set). + + If the Snowflake default artifact repository changes in the future, the + fallback needs to be updated here. 
""" if isinstance(self._conn, MockServerConnection): return ANACONDA_SHARED_REPOSITORY From 1f45379ee8c3ddfd66dcc927f29f9ee5a5f09fd6 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 12:06:28 -0800 Subject: [PATCH 11/20] fix more tests --- src/snowflake/snowpark/session.py | 6 +++- tests/integ/conftest.py | 4 --- tests/integ/test_packaging.py | 5 ++- tests/integ/test_udf.py | 54 +++++++++++++++++-------------- tests/unit/test_udf_utils.py | 2 ++ 5 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 7310f8cfa9..960d16b34c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2418,8 +2418,12 @@ def _get_default_artifact_repository(self) -> str: f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" ) value = result[0][0] if result else None + print("RESULT:", result) resolved = value or ANACONDA_SHARED_REPOSITORY - except Exception: + except Exception as e: + _logger.warning( + f"Error getting default artifact repository: {e}. Using fallback." 
+ ) resolved = ANACONDA_SHARED_REPOSITORY with self._package_lock: diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 4166dd1eb9..9df63dcc92 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -357,10 +357,6 @@ def session( "alter session set ENABLE_EXTRACTION_PUSHDOWN_EXTERNAL_PARQUET_FOR_COPY_PHASE_I='Track';" ).collect() session.sql("alter session set ENABLE_ROW_ACCESS_POLICY=true").collect() - # TODO: remove - session.sql( - "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" - ).collect() try: yield session diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index c1c4d12cb4..0a358d8ae5 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -270,7 +270,10 @@ def extract_major_minor_patch(version_string): return match.group(1) if match else version_string resolved_packages = session._resolve_packages( - [numpy, pandas, dateutil], validate_package=False + [numpy, pandas, dateutil], + ANACONDA_SHARED_REPOSITORY, + {}, + validate_package=False, ) # resolved_packages is a list of strings like # ['numpy==2.0.2', 'pandas==2.3.0', 'python-dateutil==2.9.0.post0', 'cloudpickle==3.0.0'] diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 484cec8a53..453418f59f 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3015,37 +3015,43 @@ def test_urllib() -> str: "config.getoption('local_testing_mode', default=False)", reason="artifact repository not supported in local testing", ) -# @pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") @pytest.mark.skipif( sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" ) -def test_use_default_artifact_repository(session): - # TODO: is this safe with parallel testing? 
- session.sql( - "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository" - ).collect() +def test_use_default_artifact_repository(db_parameters): + with Session.builder.configs(db_parameters).create() as session: + session.use_schema("public") + session.sql( + "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" + ).collect() + session.sql( + "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository" + ).collect() - session.add_packages("art", "cloudpickle") + session.add_packages("art", "cloudpickle") - def test_art() -> str: - import art # art is not available in the conda channel, but is in pypi + def test_art() -> str: + import art # art is not available in the conda channel, but is in pypi - _ = art.text2art("test") - return "art works!" + _ = art.text2art("test") + return "art works!" - temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION) + temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION) - try: - # Test function registration - udf( - func=test_art, - name=temp_func_name, - ) + try: + # Test function registration + udf( + session=session, + func=test_art, + name=temp_func_name, + ) - # Test UDF call - df = session.create_dataframe([1]).to_df(["a"]) - Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("art works!")]) - finally: - session._run_query(f"drop function if exists {temp_func_name}(int)") + # Test UDF call + df = session.create_dataframe([1]).to_df(["a"]) + Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("art works!")]) + finally: + session._run_query(f"drop function if exists {temp_func_name}(int)") - session.sql("ALTER schema unset DEFAULT_PYTHON_ARTIFACT_REPOSITORY").collect() + session.sql("ALTER schema unset DEFAULT_PYTHON_ARTIFACT_REPOSITORY").collect() diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index fb4f8a14db..f68e662728 100644 --- 
a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import logging import os import pickle @@ -253,6 +254,7 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): def test_add_snowpark_package_to_sproc_packages_to_session(): fake_session = mock.create_autospec(Session) + fake_session._artifact_repository_packages = defaultdict(dict) fake_session._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY] = { "random_package_one": "random_package_one", "random_package_two": "random_package_two", From 44ee0e4ca376d1bd98dce6544f47f99f91ed583b Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 12:50:55 -0800 Subject: [PATCH 12/20] remove debug log --- src/snowflake/snowpark/session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 960d16b34c..2e6d77fb87 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2418,7 +2418,6 @@ def _get_default_artifact_repository(self) -> str: f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" ) value = result[0][0] if result else None - print("RESULT:", result) resolved = value or ANACONDA_SHARED_REPOSITORY except Exception as e: _logger.warning( From 026c024c0884d310f00686e9c3fd8f28a93252a9 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 13:11:58 -0800 Subject: [PATCH 13/20] cleanup --- CHANGELOG.md | 1 + src/snowflake/snowpark/_internal/udf_utils.py | 7 +++++-- src/snowflake/snowpark/session.py | 21 ++++++++++++------- src/snowflake/snowpark/stored_procedure.py | 4 ++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 042714b559..577866d1dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### New Features - Added support for the `DECFLOAT` data type 
that allows users to represent decimal numbers exactly with 38 digits of precision and a dynamic base-10 exponent. +- Added support for the `DEFAULT_PYTHON_ARTIFACT_REPOSITORY` parameter that allows users to configure the default artifact repository at the account, database, and schema level. #### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 0989ef83d9..baf14981e8 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1227,13 +1227,16 @@ def resolve_imports_and_packages( Optional[str], bool, ]: - from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY + from snowflake.snowpark.session import ( + ANACONDA_SHARED_REPOSITORY, + DEFAULT_ARTIFACT_REPOSITORY, + ) if artifact_repository is None: artifact_repository = ( session._get_default_artifact_repository() if session - else ANACONDA_SHARED_REPOSITORY + else DEFAULT_ARTIFACT_REPOSITORY ) existing_packages_dict = ( diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 2e6d77fb87..56ee967d70 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -323,6 +323,8 @@ # The fully qualified name of the Anaconda shared repository (conda channel). ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" +# In case of failures or the current default artifact repository is unset, we fallback to this +DEFAULT_ARTIFACT_REPOSITORY = ANACONDA_SHARED_REPOSITORY def _get_active_session() -> "Session": @@ -608,8 +610,8 @@ def __init__( ] = defaultdict(dict) # Single-entry cache for the default artifact repository value. # Stores a tuple of ((database, schema), cached_value). Only one entry is - # kept at a time – switching to a different database/schema evicts the old - # value and triggers a fresh query on the next call. 
+ # kept at a time – switching to a different database/schema will evict the old + # value and trigger a fresh query on the next call. self._default_artifact_repository_cache: Optional[ Tuple[Tuple[Optional[str], Optional[str]], str] ] = None @@ -1641,7 +1643,8 @@ def add_packages( for this argument. If a ``module`` object is provided, the package will be installed with the version in the local environment. artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions - using that repository will use the packages. (Default None) + using that repository will use the packages. (Default None). Otherwise, uses the default artifact repository configured in the + current context. Example:: @@ -1701,7 +1704,8 @@ def remove_package( Args: package: The package name. artifact_repository: When set this parameter specifies that the package should be removed - from the default packages for a specific artifact repository. + from the default packages for a specific artifact repository. Otherwise, uses the default + artifact repository configured in the current context. Examples:: @@ -1758,7 +1762,8 @@ def add_requirements( Args: file_path: The path of a local requirement file. artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions - using that repository will use the packages. (Default None) + using that repository will use the packages. (Default None). Otherwise, uses the default artifact repository configured in + the current context. Example:: @@ -2402,7 +2407,7 @@ def _get_default_artifact_repository(self) -> str: fallback needs to be updated here. 
""" if isinstance(self._conn, MockServerConnection): - return ANACONDA_SHARED_REPOSITORY + return DEFAULT_ARTIFACT_REPOSITORY cache_key = (self.get_current_database(), self.get_current_schema()) @@ -2418,12 +2423,12 @@ def _get_default_artifact_repository(self) -> str: f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" ) value = result[0][0] if result else None - resolved = value or ANACONDA_SHARED_REPOSITORY + resolved = value or DEFAULT_ARTIFACT_REPOSITORY except Exception as e: _logger.warning( f"Error getting default artifact repository: {e}. Using fallback." ) - resolved = ANACONDA_SHARED_REPOSITORY + resolved = DEFAULT_ARTIFACT_REPOSITORY with self._package_lock: self._default_artifact_repository_cache = (cache_key, resolved) diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index a42f739beb..623f3dc88f 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -941,12 +941,12 @@ def _do_register_sp( effective_artifact_repository = artifact_repository if effective_artifact_repository is None: - from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY + from snowflake.snowpark.session import DEFAULT_ARTIFACT_REPOSITORY effective_artifact_repository = ( self._session._get_default_artifact_repository() if self._session - else ANACONDA_SHARED_REPOSITORY + else DEFAULT_ARTIFACT_REPOSITORY ) # Add in snowflake-snowpark-python if it is not already in the package list. 
From eea6081f8f2d16e38ed7a74ff13dd8920b52f055 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Fri, 6 Feb 2026 13:49:39 -0800 Subject: [PATCH 14/20] wrap getting the default in a lock --- src/snowflake/snowpark/session.py | 49 +++++++++++++++---------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 56ee967d70..072626fafb 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2398,41 +2398,38 @@ def _get_default_artifact_repository(self) -> str: a time; switching to a different database or schema evicts the previous entry and triggers a fresh query on the next call. - Falls back to the Anaconda shared repository (conda) if: + Falls back to the Snowflake default artifact repository if: - the session uses a mock connection (local testing), or - the system function is not available / fails, or - the system function returns NULL (value was never set). - - If the Snowflake default artifact repository changes in the future, the - fallback needs to be updated here. 
""" - if isinstance(self._conn, MockServerConnection): - return DEFAULT_ARTIFACT_REPOSITORY + with self._package_lock: + if isinstance(self._conn, MockServerConnection): + return DEFAULT_ARTIFACT_REPOSITORY - cache_key = (self.get_current_database(), self.get_current_schema()) + cache_key = (self.get_current_database(), self.get_current_schema()) - if ( - self._default_artifact_repository_cache is not None - and self._default_artifact_repository_cache[0] == cache_key - ): - return self._default_artifact_repository_cache[1] + if ( + self._default_artifact_repository_cache is not None + and self._default_artifact_repository_cache[0] == cache_key + ): + return self._default_artifact_repository_cache[1] - try: - python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - result = self._run_query( - f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" - ) - value = result[0][0] if result else None - resolved = value or DEFAULT_ARTIFACT_REPOSITORY - except Exception as e: - _logger.warning( - f"Error getting default artifact repository: {e}. Using fallback." - ) - resolved = DEFAULT_ARTIFACT_REPOSITORY + try: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + result = self._run_query( + f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" + ) + value = result[0][0] if result else None + resolved = value or DEFAULT_ARTIFACT_REPOSITORY + except Exception as e: + _logger.warning( + f"Error getting default artifact repository: {e}. Using fallback." 
+ ) + resolved = DEFAULT_ARTIFACT_REPOSITORY - with self._package_lock: self._default_artifact_repository_cache = (cache_key, resolved) - return resolved + return resolved def _is_anaconda_terms_acknowledged(self) -> bool: return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0] From 9746f7f15bff932e948a3489927466f894b42f84 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Mon, 9 Feb 2026 16:48:11 -0800 Subject: [PATCH 15/20] update comment --- src/snowflake/snowpark/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 072626fafb..d8b7c1fc92 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2424,7 +2424,7 @@ def _get_default_artifact_repository(self) -> str: resolved = value or DEFAULT_ARTIFACT_REPOSITORY except Exception as e: _logger.warning( - f"Error getting default artifact repository: {e}. Using fallback." + f"Error getting default artifact repository: {e}. Using fallback: {DEFAULT_ARTIFACT_REPOSITORY}." 
) resolved = DEFAULT_ARTIFACT_REPOSITORY From 8562a724a8781ffae3cdfb46afd94f0b622ec51c Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 10 Feb 2026 10:55:38 -0800 Subject: [PATCH 16/20] update test --- tests/integ/test_udf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 1bfb50af2e..ce3a5a18a2 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3053,7 +3053,7 @@ def test_urllib() -> str: reason="artifact repository not supported in local testing", ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") +# @pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") @pytest.mark.skipif( sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" ) @@ -3063,6 +3063,7 @@ def test_use_default_artifact_repository(db_parameters): session.sql( "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" ).collect() + session.sql("ALTER SESSION SET ENABLE_PYPI_SHARED_REPOSITORY = true").collect() session.sql( "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository" ).collect() From 4a51691887282d59483a08fe1018629d67513fb8 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 10 Feb 2026 13:25:20 -0800 Subject: [PATCH 17/20] try using test artifact repo --- tests/integ/test_udf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index ce3a5a18a2..360d00e2ae 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3053,7 +3053,7 @@ def test_urllib() -> str: reason="artifact repository not supported in local testing", ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -# @pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") +@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") @pytest.mark.skipif( 
sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" ) @@ -3063,9 +3063,8 @@ def test_use_default_artifact_repository(db_parameters): session.sql( "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" ).collect() - session.sql("ALTER SESSION SET ENABLE_PYPI_SHARED_REPOSITORY = true").collect() session.sql( - "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository" + "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY" ).collect() session.add_packages("art", "cloudpickle") From 273b028a0ab5b41219dfd2d3c5bb787488e0d894 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 10 Feb 2026 15:19:53 -0800 Subject: [PATCH 18/20] filter out system call checks in modin --- tests/integ/utils/sql_counter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/integ/utils/sql_counter.py b/tests/integ/utils/sql_counter.py index c8c3134ff9..42175dda77 100644 --- a/tests/integ/utils/sql_counter.py +++ b/tests/integ/utils/sql_counter.py @@ -87,6 +87,7 @@ # that this parameter is unset, as currently required by Snowpark pandas. # 8. SHOW OBJECTS LIKE [TABLE_NAME] IN SCHEMA [SCHEMA] LIMIT 1 ... is to check the row count of a table we are reading # from, if it exists +# 9. SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY is cached per DB/schema context, so it's not consistent when it is executed FILTER_OUT_QUERIES = [ ["create SCOPED TEMPORARY", "stage if not exists"], ["PUT", "file:///tmp/placeholder/snowpark.zip"], @@ -96,6 +97,7 @@ ["drop table if exists", "/* internal query to drop unused temp table */"], ["SHOW PARAMETERS LIKE", "QUOTED_IDENTIFIERS_IGNORE_CASE"], ["SHOW OBJECTS LIKE", "IN SCHEMA", "LIMIT 1"], + ["SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY"], ] # define global at module-level @@ -485,7 +487,7 @@ def sql_count_checker( } # also look into kwargs for count configuration. 
Right now, describe_count and window_count are the # counts can be passed optionally - for (key, value) in kwargs.items(): + for key, value in kwargs.items(): if key.endswith("_count"): count_kwargs.update({key: value}) @@ -528,7 +530,7 @@ def get_readable_sql_count_values(tr): def update_test_code_with_sql_counts( - sql_count_records: Dict[str, Dict[str, List[Dict[str, Optional[PythonScalar]]]]] + sql_count_records: Dict[str, Dict[str, List[Dict[str, Optional[PythonScalar]]]]], ): """This helper takes sql count records and rewrites the source test files to validate sql counts where possible. From 4f3d174e3fac8317a9302ab69ad76da1cd99a429 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Tue, 17 Feb 2026 13:17:23 -0800 Subject: [PATCH 19/20] private --- src/snowflake/snowpark/_internal/udf_utils.py | 8 ++++---- src/snowflake/snowpark/session.py | 14 +++++++------- src/snowflake/snowpark/stored_procedure.py | 4 ++-- tests/integ/test_packaging.py | 18 +++++++++--------- tests/unit/test_session.py | 14 +++++++------- tests/unit/test_udf_utils.py | 14 +++++++------- 6 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index baf14981e8..c1043304af 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1228,22 +1228,22 @@ def resolve_imports_and_packages( bool, ]: from snowflake.snowpark.session import ( - ANACONDA_SHARED_REPOSITORY, - DEFAULT_ARTIFACT_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, + _DEFAULT_ARTIFACT_REPOSITORY, ) if artifact_repository is None: artifact_repository = ( session._get_default_artifact_repository() if session - else DEFAULT_ARTIFACT_REPOSITORY + else _DEFAULT_ARTIFACT_REPOSITORY ) existing_packages_dict = ( session._artifact_repository_packages[artifact_repository] if session else {} ) - if artifact_repository != ANACONDA_SHARED_REPOSITORY: + if artifact_repository != 
_ANACONDA_SHARED_REPOSITORY: # Non-conda artifact repository - skip conda-based package resolution resolved_packages = [] if not packages and session: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index d8b7c1fc92..b462e91d26 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -322,9 +322,9 @@ WRITE_ARROW_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None # The fully qualified name of the Anaconda shared repository (conda channel). -ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" +_ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" # In case of failures or the current default artifact repository is unset, we fallback to this -DEFAULT_ARTIFACT_REPOSITORY = ANACONDA_SHARED_REPOSITORY +_DEFAULT_ARTIFACT_REPOSITORY = _ANACONDA_SHARED_REPOSITORY def _get_active_session() -> "Session": @@ -2144,7 +2144,7 @@ def _resolve_packages( package_dict = self._parse_packages(packages) if ( isinstance(self._conn, MockServerConnection) - or artifact_repository != ANACONDA_SHARED_REPOSITORY + or artifact_repository != _ANACONDA_SHARED_REPOSITORY ): # in local testing or non-conda, we don't resolve the packages, we just return what is added errors = [] @@ -2405,7 +2405,7 @@ def _get_default_artifact_repository(self) -> str: """ with self._package_lock: if isinstance(self._conn, MockServerConnection): - return DEFAULT_ARTIFACT_REPOSITORY + return _DEFAULT_ARTIFACT_REPOSITORY cache_key = (self.get_current_database(), self.get_current_schema()) @@ -2421,12 +2421,12 @@ def _get_default_artifact_repository(self) -> str: f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" ) value = result[0][0] if result else None - resolved = value or DEFAULT_ARTIFACT_REPOSITORY + resolved = value or _DEFAULT_ARTIFACT_REPOSITORY except Exception as e: _logger.warning( - f"Error getting default artifact repository: {e}. 
Using fallback: {DEFAULT_ARTIFACT_REPOSITORY}." + f"Error getting default artifact repository: {e}. Using fallback: {_DEFAULT_ARTIFACT_REPOSITORY}." ) - resolved = DEFAULT_ARTIFACT_REPOSITORY + resolved = _DEFAULT_ARTIFACT_REPOSITORY self._default_artifact_repository_cache = (cache_key, resolved) return resolved diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index 623f3dc88f..469855def5 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -941,12 +941,12 @@ def _do_register_sp( effective_artifact_repository = artifact_repository if effective_artifact_repository is None: - from snowflake.snowpark.session import DEFAULT_ARTIFACT_REPOSITORY + from snowflake.snowpark.session import _DEFAULT_ARTIFACT_REPOSITORY effective_artifact_repository = ( self._session._get_default_artifact_repository() if self._session - else DEFAULT_ARTIFACT_REPOSITORY + else _DEFAULT_ARTIFACT_REPOSITORY ) # Add in snowflake-snowpark-python if it is not already in the package list. 
diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 0a358d8ae5..8f58b5ef2b 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -20,7 +20,7 @@ get_signature, ) from snowflake.snowpark.functions import call_udf, col, count_distinct, sproc, udf -from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY +from snowflake.snowpark.session import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import DateType, StringType from tests.utils import IS_IN_STORED_PROC, TempObjectType, TestFiles, Utils @@ -271,7 +271,7 @@ def extract_major_minor_patch(version_string): resolved_packages = session._resolve_packages( [numpy, pandas, dateutil], - ANACONDA_SHARED_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, {}, validate_package=False, ) @@ -1208,13 +1208,13 @@ def test_replicate_local_environment(session): [ package.startswith("cloudpickle") for package in session._artifact_repository_packages[ - ANACONDA_SHARED_REPOSITORY + _ANACONDA_SHARED_REPOSITORY ] ] ) def naive_add_packages(self, packages): - self._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY] = packages + self._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = packages with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True): with patch.object(Session, "add_packages", new=naive_add_packages): @@ -1232,7 +1232,7 @@ def naive_add_packages(self, packages): [ package.startswith("cloudpickle==") for package in session._artifact_repository_packages[ - ANACONDA_SHARED_REPOSITORY + _ANACONDA_SHARED_REPOSITORY ] ] ) @@ -1241,7 +1241,7 @@ def naive_add_packages(self, packages): [ package.startswith(default_package) for package in session._artifact_repository_packages[ - ANACONDA_SHARED_REPOSITORY + _ANACONDA_SHARED_REPOSITORY ] ] ) @@ -1266,7 +1266,7 @@ def naive_add_packages(self, packages): [ package == "cloudpickle" for package in session._artifact_repository_packages[ - ANACONDA_SHARED_REPOSITORY + 
_ANACONDA_SHARED_REPOSITORY ] ] ) @@ -1275,7 +1275,7 @@ def naive_add_packages(self, packages): [ package.startswith(default_package) for package in session._artifact_repository_packages[ - ANACONDA_SHARED_REPOSITORY + _ANACONDA_SHARED_REPOSITORY ] ] ) @@ -1284,7 +1284,7 @@ def naive_add_packages(self, packages): [ package.startswith(ignored_package) for package in session._artifact_repository_packages[ - ANACONDA_SHARED_REPOSITORY + _ANACONDA_SHARED_REPOSITORY ] ] ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6180822507..7adc1c11b0 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -30,7 +30,7 @@ ) from snowflake.snowpark.session import ( _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, - ANACONDA_SHARED_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, ) from snowflake.snowpark.types import StructField, StructType @@ -215,7 +215,7 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True session._resolve_packages( ["random_package_name"], - ANACONDA_SHARED_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, {}, validate_package=True, include_pandas=False, @@ -250,7 +250,7 @@ def run_query(sql: str): ): session._resolve_packages( ["random_package_name"], - ANACONDA_SHARED_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, {}, validate_package=True, include_pandas=False, @@ -275,7 +275,7 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True resolved_packages = session._resolve_packages( ["random_package_name"], - ANACONDA_SHARED_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, existing_packages, validate_package=True, include_pandas=False, @@ -307,7 +307,7 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True ): session._resolve_packages( ["snowflake-snowpark-python"], - ANACONDA_SHARED_REPOSITORY, + _ANACONDA_SHARED_REPOSITORY, {}, validate_package=True, include_pandas=False, @@ -752,7 +752,7 @@ def test_get_default_artifact_repository(): session, 
"get_current_schema", return_value="SCHEMA2" ): result = session._get_default_artifact_repository() - assert result == ANACONDA_SHARED_REPOSITORY + assert result == _ANACONDA_SHARED_REPOSITORY assert mocked_run_query.call_count == 1 @@ -764,6 +764,6 @@ def test_get_default_artifact_repository(): session, "get_current_schema", return_value="SCHEMA1" ): result = session._get_default_artifact_repository() - assert result == ANACONDA_SHARED_REPOSITORY + assert result == _ANACONDA_SHARED_REPOSITORY assert mocked_run_query.call_count == 1 diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index f68e662728..a3f9e03fc3 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -26,7 +26,7 @@ resolve_packages_in_client_side_sandbox, ) from snowflake.snowpark._internal.utils import TempObjectType -from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY +from snowflake.snowpark.session import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import StringType from snowflake.snowpark.version import VERSION @@ -227,7 +227,7 @@ def test_resolve_imports_and_packages_imports_as_str(tmp_path_factory): def test_add_snowpark_package_to_sproc_packages_add_package(packages): old_packages_length = len(packages) if packages else 0 result = add_snowpark_package_to_sproc_packages( - session=None, packages=packages, artifact_repository=ANACONDA_SHARED_REPOSITORY + session=None, packages=packages, artifact_repository=_ANACONDA_SHARED_REPOSITORY ) major, minor, patch = VERSION @@ -245,7 +245,7 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): "snowflake-snowpark-python==1.12.0", ] result = add_snowpark_package_to_sproc_packages( - session=None, packages=packages, artifact_repository=ANACONDA_SHARED_REPOSITORY + session=None, packages=packages, artifact_repository=_ANACONDA_SHARED_REPOSITORY ) assert len(result) == len(packages) @@ -255,7 +255,7 @@ def 
test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): def test_add_snowpark_package_to_sproc_packages_to_session(): fake_session = mock.create_autospec(Session) fake_session._artifact_repository_packages = defaultdict(dict) - fake_session._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY] = { + fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = { "random_package_one": "random_package_one", "random_package_two": "random_package_two", } @@ -263,7 +263,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): result = add_snowpark_package_to_sproc_packages( session=fake_session, packages=None, - artifact_repository=ANACONDA_SHARED_REPOSITORY, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, ) major, minor, patch = VERSION @@ -272,13 +272,13 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): assert len(result) == 3 assert final_name in result - fake_session._artifact_repository_packages[ANACONDA_SHARED_REPOSITORY][ + fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY][ "snowflake-snowpark-python" ] = "snowflake-snowpark-python==1.12.0" result = add_snowpark_package_to_sproc_packages( session=fake_session, packages=None, - artifact_repository=ANACONDA_SHARED_REPOSITORY, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, ) assert result is None From 70e97e77d5064d63aba796423d55325d31888191 Mon Sep 17 00:00:00 2001 From: Ben Kogan Date: Wed, 18 Feb 2026 14:28:59 -0800 Subject: [PATCH 20/20] move constants to context + temp schema in test --- src/snowflake/snowpark/_internal/udf_utils.py | 9 ++++----- src/snowflake/snowpark/context.py | 5 +++++ src/snowflake/snowpark/session.py | 7 ++----- tests/integ/test_packaging.py | 2 +- tests/integ/test_udf.py | 6 ++++-- tests/unit/test_session.py | 6 ++---- tests/unit/test_udf_utils.py | 2 +- 7 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py 
b/src/snowflake/snowpark/_internal/udf_utils.py index c1043304af..7c0f9e01f4 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -59,6 +59,10 @@ ) from snowflake.snowpark.types import DataType, StructField, StructType from snowflake.snowpark.version import VERSION +from snowflake.snowpark.context import ( + _ANACONDA_SHARED_REPOSITORY, + _DEFAULT_ARTIFACT_REPOSITORY, +) if installed_pandas: from snowflake.snowpark.types import ( @@ -1227,11 +1231,6 @@ def resolve_imports_and_packages( Optional[str], bool, ]: - from snowflake.snowpark.session import ( - _ANACONDA_SHARED_REPOSITORY, - _DEFAULT_ARTIFACT_REPOSITORY, - ) - if artifact_repository is None: artifact_repository = ( session._get_default_artifact_repository() diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index a9e585d74b..c21ddec5eb 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -48,6 +48,11 @@ # example: _integral_type_default_precision = {IntegerType: 9}, IntegerType default _precision is 9 now _integral_type_default_precision = {} +# The fully qualified name of the Anaconda shared repository (conda channel). 
+_ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" +# In case of failures or the current default artifact repository is unset, we fall back to this +_DEFAULT_ARTIFACT_REPOSITORY = _ANACONDA_SHARED_REPOSITORY def configure_development_features( *, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index b462e91d26..ad583da1a7 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -156,6 +156,8 @@ from snowflake.snowpark.context import ( _is_execution_environment_sandboxed_for_client, _use_scoped_temp_objects, + _ANACONDA_SHARED_REPOSITORY, + _DEFAULT_ARTIFACT_REPOSITORY, ) from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.dataframe_reader import DataFrameReader @@ -321,11 +323,6 @@ WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None WRITE_ARROW_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None -# The fully qualified name of the Anaconda shared repository (conda channel). 
-_ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" -# In case of failures or the current default artifact repository is unset, we fallback to this -_DEFAULT_ARTIFACT_REPOSITORY = _ANACONDA_SHARED_REPOSITORY - def _get_active_session() -> "Session": with _session_management_lock: diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 8f58b5ef2b..02998f4a3a 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -20,7 +20,7 @@ get_signature, ) from snowflake.snowpark.functions import call_udf, col, count_distinct, sproc, udf -from snowflake.snowpark.session import _ANACONDA_SHARED_REPOSITORY +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import DateType, StringType from tests.utils import IS_IN_STORED_PROC, TempObjectType, TestFiles, Utils diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 55dd6af154..bd42a3b334 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3083,7 +3083,9 @@ def test_urllib() -> str: ) def test_use_default_artifact_repository(db_parameters): with Session.builder.configs(db_parameters).create() as session: - session.use_schema("public") + temp_schema = Utils.random_temp_schema() + session.sql(f"create schema {temp_schema}").collect() + session.sql(f"use schema {temp_schema}").collect() session.sql( "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" ).collect() @@ -3115,4 +3117,4 @@ def test_art() -> str: finally: session._run_query(f"drop function if exists {temp_func_name}(int)") - session.sql("ALTER schema unset DEFAULT_PYTHON_ARTIFACT_REPOSITORY").collect() + session.sql(f"drop schema {temp_schema}").collect() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 7adc1c11b0..7a64f052d4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -28,10 +28,8 @@ SnowparkInvalidObjectNameException, SnowparkSessionException, ) 
-from snowflake.snowpark.session import ( - _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, - _ANACONDA_SHARED_REPOSITORY, -) +from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import StructField, StructType diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index a3f9e03fc3..a301d8769e 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -26,7 +26,7 @@ resolve_packages_in_client_side_sandbox, ) from snowflake.snowpark._internal.utils import TempObjectType -from snowflake.snowpark.session import _ANACONDA_SHARED_REPOSITORY +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import StringType from snowflake.snowpark.version import VERSION