diff --git a/CHANGELOG.md b/CHANGELOG.md index 86d839b42d..bbbf449201 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ #### New Features - Added support for the `DECFLOAT` data type that allows users to represent decimal numbers exactly with 38 digits of precision and a dynamic base-10 exponent. +- Added support for the `DEFAULT_PYTHON_ARTIFACT_REPOSITORY` parameter that allows users to configure the default artifact repository at the account, database, and schema level. #### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 1d80eea873..7c0f9e01f4 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -59,6 +59,10 @@ ) from snowflake.snowpark.types import DataType, StructField, StructType from snowflake.snowpark.version import VERSION +from snowflake.snowpark.context import ( + _ANACONDA_SHARED_REPOSITORY, + _DEFAULT_ARTIFACT_REPOSITORY, +) if installed_pandas: from snowflake.snowpark.types import ( @@ -1122,6 +1126,7 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}): def add_snowpark_package_to_sproc_packages( session: Optional["snowflake.snowpark.Session"], packages: Optional[List[Union[str, ModuleType]]], + artifact_repository: str, ) -> List[Union[str, ModuleType]]: major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -1137,8 +1142,11 @@ def add_snowpark_package_to_sproc_packages( packages = [this_package] else: with session._package_lock: - if package_name not in session._packages: - packages = list(session._packages.values()) + [this_package] + existing_packages = session._artifact_repository_packages[ + artifact_repository + ] + if package_name not in existing_packages: + packages = list(existing_packages.values()) + [this_package] return packages return add_package_to_existing_packages(packages, package_name, this_package) @@ -1223,17 +1231,30 @@ def resolve_imports_and_packages( 
Optional[str], bool, ]: - if artifact_repository and artifact_repository != "conda": - # Artifact Repository packages are not resolved + if artifact_repository is None: + artifact_repository = ( + session._get_default_artifact_repository() + if session + else _DEFAULT_ARTIFACT_REPOSITORY + ) + + existing_packages_dict = ( + session._artifact_repository_packages[artifact_repository] if session else {} + ) + + if artifact_repository != _ANACONDA_SHARED_REPOSITORY: + # Non-conda artifact repository - skip conda-based package resolution resolved_packages = [] if not packages and session: resolved_packages = list( - session._resolve_packages([], artifact_repository=artifact_repository) + session._resolve_packages( + [], artifact_repository, existing_packages_dict + ) ) elif packages: if not all(isinstance(package, str) for package in packages): raise TypeError( - "Artifact repository requires that all packages be passed as str." + "Non-conda artifact repository requires that all packages be passed as str." 
) try: has_cloudpickle = bool( @@ -1256,7 +1277,7 @@ def resolve_imports_and_packages( ) else: - # resolve packages + # resolve packages using conda channel if session is None: # In case of sandbox resolved_packages = resolve_packages_in_client_side_sandbox( packages=packages @@ -1265,6 +1286,8 @@ def resolve_imports_and_packages( resolved_packages = ( session._resolve_packages( packages, + artifact_repository, + {}, # ignore session packages if passed in explicitly include_pandas=is_pandas_udf, statement_params=statement_params, _suppress_local_package_warnings=_suppress_local_package_warnings, @@ -1272,7 +1295,8 @@ def resolve_imports_and_packages( if packages is not None else session._resolve_packages( [], - session._packages, + artifact_repository, + existing_packages_dict, validate_package=False, include_pandas=is_pandas_udf, statement_params=statement_params, diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index a9e585d74b..c21ddec5eb 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -48,6 +48,11 @@ # example: _integral_type_default_precision = {IntegerType: 9}, IntegerType default _precision is 9 now _integral_type_default_precision = {} +# The fully qualified name of the Anaconda shared repository (conda channel). 
+_ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository" +# In case of failures, or if the current default artifact repository is unset, we fall back to this +_DEFAULT_ARTIFACT_REPOSITORY = _ANACONDA_SHARED_REPOSITORY + def configure_development_features( *, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 99293569ff..ad583da1a7 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -156,6 +156,8 @@ from snowflake.snowpark.context import ( _is_execution_environment_sandboxed_for_client, _use_scoped_temp_objects, + _ANACONDA_SHARED_REPOSITORY, + _DEFAULT_ARTIFACT_REPOSITORY, ) from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.dataframe_reader import DataFrameReader @@ -599,10 +601,17 @@ def __init__( self._conn = conn self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} - self._packages: Dict[str, str] = {} + # map of artifact repository name -> packages that should be added to functions under that repository self._artifact_repository_packages: DefaultDict[ str, Dict[str, str] ] = defaultdict(dict) + # Single-entry cache for the default artifact repository value. + # Stores a tuple of ((database, schema), cached_value). Only one entry is + # kept at a time – switching to a different database/schema will evict the old + # value and trigger a fresh query on the next call. + self._default_artifact_repository_cache: Optional[ + Tuple[Tuple[Optional[str], Optional[str]], str] + ] = None self._session_id = self._conn.get_session_id() self._session_info = f""" "version" : {get_version()}, @@ -1598,11 +1607,13 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s Args: artifact_repository: When set this will function will return the packages for a specific artifact repository. + Otherwise, uses the default artifact repository configured in the current context. 
""" + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + with self._package_lock: - if artifact_repository: - return self._artifact_repository_packages[artifact_repository].copy() - return self._packages.copy() + return self._artifact_repository_packages[artifact_repository].copy() def add_packages( self, @@ -1629,7 +1640,8 @@ def add_packages( for this argument. If a ``module`` object is provided, the package will be installed with the version in the local environment. artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions - using that repository will use the packages. (Default None) + using that repository will use the packages. (Default None). Otherwise, uses the default artifact repository configured in the + current context. Example:: @@ -1669,10 +1681,13 @@ def add_packages( to ensure the consistent experience of a UDF between your local environment and the Snowflake server. """ + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + self._resolve_packages( parse_positional_args_to_list(*packages), - self._packages, - artifact_repository=artifact_repository, + artifact_repository, + self._artifact_repository_packages[artifact_repository], ) def remove_package( @@ -1686,7 +1701,8 @@ def remove_package( Args: package: The package name. artifact_repository: When set this parameter specifies that the package should be removed - from the default packages for a specific artifact repository. + from the default packages for a specific artifact repository. Otherwise, uses the default + artifact repository configured in the current context. 
Examples:: @@ -1704,17 +1720,13 @@ def remove_package( 0 """ package_name = Requirement(package).name + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + with self._package_lock: - if ( - artifact_repository is not None - and package_name - in self._artifact_repository_packages.get(artifact_repository, {}) - ): - self._artifact_repository_packages[artifact_repository].pop( - package_name - ) - elif package_name in self._packages: - self._packages.pop(package_name) + packages = self._artifact_repository_packages[artifact_repository] + if package_name in packages: + packages.pop(package_name) else: raise ValueError(f"{package_name} is not in the package list") @@ -1726,11 +1738,11 @@ def clear_packages( Clears all third-party packages of a user-defined function (UDF). When artifact_repository is set packages are only clear from the specified repository. """ + if artifact_repository is None: + artifact_repository = self._get_default_artifact_repository() + with self._package_lock: - if artifact_repository is not None: - self._artifact_repository_packages.get(artifact_repository, {}).clear() - else: - self._packages.clear() + self._artifact_repository_packages[artifact_repository].clear() def add_requirements( self, @@ -1747,7 +1759,8 @@ def add_requirements( Args: file_path: The path of a local requirement file. artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions - using that repository will use the packages. (Default None) + using that repository will use the packages. (Default None). Otherwise, uses the default artifact repository configured in + the current context. 
Example:: @@ -2097,11 +2110,11 @@ def _get_req_identifiers_list( def _resolve_packages( self, packages: List[Union[str, ModuleType]], - existing_packages_dict: Optional[Dict[str, str]] = None, + artifact_repository: str, + existing_packages_dict: Dict[str, str], validate_package: bool = True, include_pandas: bool = False, statement_params: Optional[Dict[str, str]] = None, - artifact_repository: Optional[str] = None, **kwargs, ) -> List[str]: """ @@ -2128,18 +2141,12 @@ def _resolve_packages( package_dict = self._parse_packages(packages) if ( isinstance(self._conn, MockServerConnection) - or artifact_repository is not None + or artifact_repository != _ANACONDA_SHARED_REPOSITORY ): - # in local testing we don't resolve the packages, we just return what is added + # in local testing or non-conda, we don't resolve the packages, we just return what is added errors = [] with self._package_lock: - if artifact_repository is None: - result_dict = self._packages - else: - result_dict = self._artifact_repository_packages[ - artifact_repository - ] - + result_dict = existing_packages_dict for pkg_name, _, pkg_req in package_dict.values(): if ( pkg_name in result_dict @@ -2377,6 +2384,50 @@ def _upload_unsupported_packages( return supported_dependencies + new_dependencies + def _get_default_artifact_repository(self) -> str: + """ + Returns the default artifact repository for the current session context + by calling SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY. + + The result is cached per (database, schema) pair so that + repeated invocations in the same context do not issue + redundant system-function queries. Only one cache entry is kept at + a time; switching to a different database or schema evicts the + previous entry and triggers a fresh query on the next call. 
+ + Falls back to the Snowflake default artifact repository if: + - the session uses a mock connection (local testing), or + - the system function is not available / fails, or + - the system function returns NULL (value was never set). + """ + with self._package_lock: + if isinstance(self._conn, MockServerConnection): + return _DEFAULT_ARTIFACT_REPOSITORY + + cache_key = (self.get_current_database(), self.get_current_schema()) + + if ( + self._default_artifact_repository_cache is not None + and self._default_artifact_repository_cache[0] == cache_key + ): + return self._default_artifact_repository_cache[1] + + try: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + result = self._run_query( + f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')" + ) + value = result[0][0] if result else None + resolved = value or _DEFAULT_ARTIFACT_REPOSITORY + except Exception as e: + _logger.warning( + f"Error getting default artifact repository: {e}. Using fallback: {_DEFAULT_ARTIFACT_REPOSITORY}." 
+ ) + resolved = _DEFAULT_ARTIFACT_REPOSITORY + + self._default_artifact_repository_cache = (cache_key, resolved) + return resolved + def _is_anaconda_terms_acknowledged(self) -> bool: return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0] diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index 5a1a6c81d8..469855def5 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -939,10 +939,21 @@ def _do_register_sp( UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names[1:]) ] + effective_artifact_repository = artifact_repository + if effective_artifact_repository is None: + from snowflake.snowpark.session import _DEFAULT_ARTIFACT_REPOSITORY + + effective_artifact_repository = ( + self._session._get_default_artifact_repository() + if self._session + else _DEFAULT_ARTIFACT_REPOSITORY + ) + # Add in snowflake-snowpark-python if it is not already in the package list. 
packages = add_snowpark_package_to_sproc_packages( session=self._session, packages=packages, + artifact_repository=effective_artifact_repository, ) ( @@ -967,7 +978,7 @@ def _do_register_sp( skip_upload_on_content_match=skip_upload_on_content_match, is_permanent=is_permanent, force_inline_code=force_inline_code, - artifact_repository=artifact_repository, + artifact_repository=effective_artifact_repository, _suppress_local_package_warnings=kwargs.get( "_suppress_local_package_warnings", False ), diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 64b8fee018..02998f4a3a 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -20,6 +20,7 @@ get_signature, ) from snowflake.snowpark.functions import call_udf, col, count_distinct, sproc, udf +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import DateType, StringType from tests.utils import IS_IN_STORED_PROC, TempObjectType, TestFiles, Utils @@ -269,7 +270,10 @@ def extract_major_minor_patch(version_string): return match.group(1) if match else version_string resolved_packages = session._resolve_packages( - [numpy, pandas, dateutil], validate_package=False + [numpy, pandas, dateutil], + _ANACONDA_SHARED_REPOSITORY, + {}, + validate_package=False, ) # resolved_packages is a list of strings like # ['numpy==2.0.2', 'pandas==2.3.0', 'python-dateutil==2.9.0.post0', 'cloudpickle==3.0.0'] @@ -1200,10 +1204,17 @@ def test_replicate_local_environment(session): "force_push": True, } - assert not any([package.startswith("cloudpickle") for package in session._packages]) + assert not any( + [ + package.startswith("cloudpickle") + for package in session._artifact_repository_packages[ + _ANACONDA_SHARED_REPOSITORY + ] + ] + ) def naive_add_packages(self, packages): - self._packages = packages + self._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = packages with patch.object(session, "_is_anaconda_terms_acknowledged", 
lambda: True): with patch.object(Session, "add_packages", new=naive_add_packages): @@ -1217,10 +1228,22 @@ def naive_add_packages(self, packages): }, ) - assert any([package.startswith("cloudpickle==") for package in session._packages]) + assert any( + [ + package.startswith("cloudpickle==") + for package in session._artifact_repository_packages[ + _ANACONDA_SHARED_REPOSITORY + ] + ] + ) for default_package in DEFAULT_PACKAGES: assert not any( - [package.startswith(default_package) for package in session._packages] + [ + package.startswith(default_package) + for package in session._artifact_repository_packages[ + _ANACONDA_SHARED_REPOSITORY + ] + ] ) session.clear_packages() @@ -1239,12 +1262,29 @@ def naive_add_packages(self, packages): ignore_packages=ignored_packages, relax=True ) - assert any([package == "cloudpickle" for package in session._packages]) + assert any( + [ + package == "cloudpickle" + for package in session._artifact_repository_packages[ + _ANACONDA_SHARED_REPOSITORY + ] + ] + ) for default_package in DEFAULT_PACKAGES: assert not any( - [package.startswith(default_package) for package in session._packages] + [ + package.startswith(default_package) + for package in session._artifact_repository_packages[ + _ANACONDA_SHARED_REPOSITORY + ] + ] ) for ignored_package in ignored_packages: assert not any( - [package.startswith(ignored_package) for package in session._packages] + [ + package.startswith(ignored_package) + for package in session._artifact_repository_packages[ + _ANACONDA_SHARED_REPOSITORY + ] + ] ) diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 5cfbd26b19..0499a01a3d 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -3095,3 +3095,51 @@ def test_urllib() -> str: ) df = session.create_dataframe([1]).to_df(["a"]) Utils.check_answer(df.select(ar_udf()), [Row("test")]) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="artifact repository not supported in local 
testing", +) +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources") +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" +) +def test_use_default_artifact_repository(db_parameters): + with Session.builder.configs(db_parameters).create() as session: + temp_schema = Utils.random_temp_schema() + session.sql(f"create schema {temp_schema}").collect() + session.sql(f"use schema {temp_schema}").collect() + session.sql( + "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true" + ).collect() + session.sql( + "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY" + ).collect() + + session.add_packages("art", "cloudpickle") + + def test_art() -> str: + import art # art is not available in the conda channel, but is in pypi + + _ = art.text2art("test") + return "art works!" + + temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION) + + try: + # Test function registration + udf( + session=session, + func=test_art, + name=temp_func_name, + ) + + # Test UDF call + df = session.create_dataframe([1]).to_df(["a"]) + Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("art works!")]) + finally: + session._run_query(f"drop function if exists {temp_func_name}(int)") + + session.sql(f"drop schema {temp_schema}").collect() diff --git a/tests/integ/utils/sql_counter.py b/tests/integ/utils/sql_counter.py index 9872c87cfb..e44b4c5957 100644 --- a/tests/integ/utils/sql_counter.py +++ b/tests/integ/utils/sql_counter.py @@ -87,6 +87,7 @@ # that this parameter is unset, as currently required by Snowpark pandas. # 8. SHOW OBJECTS LIKE [TABLE_NAME] IN SCHEMA [SCHEMA] LIMIT 1 ... is to check the row count of a table we are reading # from, if it exists +# 9. 
SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY is cached per DB/schema context, so it's not consistent when it is executed FILTER_OUT_QUERIES = [ ["create SCOPED TEMPORARY", "stage if not exists"], ["PUT", "file:///tmp/placeholder/snowpark.zip"], @@ -96,6 +97,7 @@ ["drop table if exists", "/* internal query to drop unused temp table */"], ["SHOW PARAMETERS LIKE", "QUOTED_IDENTIFIERS_IGNORE_CASE"], ["SHOW OBJECTS LIKE", "IN SCHEMA", "LIMIT 1"], + ["SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY"], ] # define global at module-level @@ -486,7 +488,7 @@ def sql_count_checker( } # also look into kwargs for count configuration. Right now, describe_count and window_count are the # counts can be passed optionally - for (key, value) in kwargs.items(): + for key, value in kwargs.items(): if key.endswith("_count"): count_kwargs.update({key: value}) @@ -530,7 +532,7 @@ def get_readable_sql_count_values(tr): def update_test_code_with_sql_counts( - sql_count_records: Dict[str, Dict[str, List[Dict[str, Optional[PythonScalar]]]]] + sql_count_records: Dict[str, Dict[str, List[Dict[str, Optional[PythonScalar]]]]], ): """This helper takes sql count records and rewrites the source test files to validate sql counts where possible. 
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 20651499d3..7a64f052d4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -29,6 +29,7 @@ SnowparkSessionException, ) from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import StructField, StructType @@ -211,7 +212,11 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True session.table.side_effect = mock_get_information_schema_packages session._resolve_packages( - ["random_package_name"], validate_package=True, include_pandas=False + ["random_package_name"], + _ANACONDA_SHARED_REPOSITORY, + {}, + validate_package=True, + include_pandas=False, ) @@ -242,7 +247,11 @@ def run_query(sql: str): "#using-third-party-packages-from-anaconda.", ): session._resolve_packages( - ["random_package_name"], validate_package=True, include_pandas=False + ["random_package_name"], + _ANACONDA_SHARED_REPOSITORY, + {}, + validate_package=True, + include_pandas=False, ) @@ -264,7 +273,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True resolved_packages = session._resolve_packages( ["random_package_name"], - existing_packages_dict=existing_packages, + _ANACONDA_SHARED_REPOSITORY, + existing_packages, validate_package=True, include_pandas=False, ) @@ -295,6 +305,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True ): session._resolve_packages( ["snowflake-snowpark-python"], + _ANACONDA_SHARED_REPOSITORY, + {}, validate_package=True, include_pandas=False, _suppress_local_package_warnings=True, @@ -304,6 +316,38 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True assert caplog.text == "" +def test_resolve_packages_non_conda_artifact_repository(mock_server_connection): + session = Session(mock_server_connection) + + existing_packages = {} 
+ + def assert_packages(packages): + assert sorted(packages) == [ + "cloudpickle==1.0.0", + "snowflake-snowpark-python==1.0.0", + ] + assert existing_packages == { + "snowflake-snowpark-python": "snowflake-snowpark-python==1.0.0", + "cloudpickle": "cloudpickle==1.0.0", + } + + packages = session._resolve_packages( + ["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"], + "snowflake.snowpark.pypi_shared_repository", + existing_packages, + ) + + assert_packages(packages) + + packages = session._resolve_packages( + [], + "snowflake.snowpark.pypi_shared_repository", + existing_packages, + ) + + assert_packages(packages) + + @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas") def test_write_pandas_wrong_table_type(mock_server_connection): session = Session(mock_server_connection) @@ -674,3 +718,50 @@ def test_parameter_version(version_value, expected_parameter_value, parameter_na ) session = Session(fake_server_connection) assert getattr(session, parameter_name, None) is expected_parameter_value + + +def test_get_default_artifact_repository(): + fake_server_connection = mock.create_autospec(ServerConnection) + fake_server_connection._thread_safe_session_enabled = True + session = Session(fake_server_connection) + + with mock.patch.object( + session, + "_run_query", + return_value=[["snowflake.snowpark.pypi_shared_repository"]], + ) as mocked_run_query, mock.patch.object( + session, "get_current_database", return_value="DB1" + ), mock.patch.object( + session, "get_current_schema", return_value="SCHEMA1" + ): + result = session._get_default_artifact_repository() + assert result == "snowflake.snowpark.pypi_shared_repository" + + result = session._get_default_artifact_repository() + assert result == "snowflake.snowpark.pypi_shared_repository" + + assert mocked_run_query.call_count == 1 + + with mock.patch.object( + session, "_run_query", return_value=[[None]] + ) as mocked_run_query, mock.patch.object( + session, 
"get_current_database", return_value="DB2" + ), mock.patch.object( + session, "get_current_schema", return_value="SCHEMA2" + ): + result = session._get_default_artifact_repository() + assert result == _ANACONDA_SHARED_REPOSITORY + + assert mocked_run_query.call_count == 1 + + with mock.patch.object( + session, "_run_query", side_effect=ProgrammingError("Not found") + ) as mocked_run_query, mock.patch.object( + session, "get_current_database", return_value="DB1" + ), mock.patch.object( + session, "get_current_schema", return_value="SCHEMA1" + ): + result = session._get_default_artifact_repository() + assert result == _ANACONDA_SHARED_REPOSITORY + + assert mocked_run_query.call_count == 1 diff --git a/tests/unit/test_stored_procedure.py b/tests/unit/test_stored_procedure.py index 638bae84d5..2e6aff999a 100644 --- a/tests/unit/test_stored_procedure.py +++ b/tests/unit/test_stored_procedure.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import sys from unittest import mock @@ -42,7 +43,7 @@ def test_stored_procedure_execute_as(execute_as): fake_session._plan_builder = SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session) fake_session._runtime_version_from_requirement = None - fake_session._packages = {} + fake_session._artifact_repository_packages = defaultdict(dict) def return1(_): return 1 @@ -90,7 +91,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.sproc = StoredProcedureRegistration(fake_session) - fake_session._packages = {} + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: sproc(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udaf.py b/tests/unit/test_udaf.py index 20da11c7f4..803f8dc886 
100644 --- a/tests/unit/test_udaf.py +++ b/tests/unit/test_udaf.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import sys from unittest import mock @@ -56,8 +57,8 @@ def test_do_register_udaf_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session._runtime_version_from_requirement = None - fake_session._packages = [] fake_session.udaf = UDAFRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: @udaf(session=fake_session) diff --git a/tests/unit/test_udf.py b/tests/unit/test_udf.py index 9a7b458f7f..d20c53d048 100644 --- a/tests/unit/test_udf.py +++ b/tests/unit/test_udf.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import sys from unittest import mock @@ -32,6 +33,7 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session.udf = UDFRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: udf(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[]) assert ex_info.value.error_code == "1304" diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 623738c567..a301d8769e 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +from collections import defaultdict import logging import os import pickle @@ -25,6 +26,7 @@ resolve_packages_in_client_side_sandbox, ) from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from snowflake.snowpark.types import StringType from snowflake.snowpark.version import VERSION @@ -224,7 +226,9 @@ def test_resolve_imports_and_packages_imports_as_str(tmp_path_factory): ) def test_add_snowpark_package_to_sproc_packages_add_package(packages): old_packages_length = len(packages) if packages else 0 - result = add_snowpark_package_to_sproc_packages(session=None, packages=packages) + result = add_snowpark_package_to_sproc_packages( + session=None, packages=packages, artifact_repository=_ANACONDA_SHARED_REPOSITORY + ) major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -240,7 +244,9 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): "random_package_two", "snowflake-snowpark-python==1.12.0", ] - result = add_snowpark_package_to_sproc_packages(session=None, packages=packages) + result = add_snowpark_package_to_sproc_packages( + session=None, packages=packages, artifact_repository=_ANACONDA_SHARED_REPOSITORY + ) assert len(result) == len(packages) assert "snowflake-snowpark-python==1.12.0" in result @@ -248,12 +254,17 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package(): def test_add_snowpark_package_to_sproc_packages_to_session(): fake_session = mock.create_autospec(Session) - fake_session._packages = { + fake_session._artifact_repository_packages = defaultdict(dict) + fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = { "random_package_one": "random_package_one", "random_package_two": "random_package_two", } fake_session._package_lock = threading.RLock() - result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) + result = add_snowpark_package_to_sproc_packages( + 
session=fake_session, + packages=None, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + ) major, minor, patch = VERSION package_name = "snowflake-snowpark-python" @@ -261,10 +272,14 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): assert len(result) == 3 assert final_name in result - fake_session._packages[ + fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY][ "snowflake-snowpark-python" ] = "snowflake-snowpark-python==1.12.0" - result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) + result = add_snowpark_package_to_sproc_packages( + session=fake_session, + packages=None, + artifact_repository=_ANACONDA_SHARED_REPOSITORY, + ) assert result is None diff --git a/tests/unit/test_udtf.py b/tests/unit/test_udtf.py index 794fa82b8c..ba986364ed 100644 --- a/tests/unit/test_udtf.py +++ b/tests/unit/test_udtf.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict import sys from typing import Tuple from unittest import mock @@ -39,8 +40,8 @@ def test_do_register_sp_negative(cleanup_registration_patch): ) fake_session._run_query = mock.Mock(side_effect=ProgrammingError()) fake_session._runtime_version_from_requirement = None - fake_session._packages = [] fake_session.udtf = UDTFRegistration(fake_session) + fake_session._artifact_repository_packages = defaultdict(dict) with pytest.raises(SnowparkSQLException) as ex_info: @udtf(output_schema=["num"], session=fake_session)