Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,9 +1142,9 @@ def add_snowpark_package_to_sproc_packages(
packages = [this_package]
else:
with session._package_lock:
existing_packages = session._artifact_repository_packages[
existing_packages = session._get_packages_by_artifact_repository(
artifact_repository
]
)
if package_name not in existing_packages:
packages = list(existing_packages.values()) + [this_package]
return packages
Expand Down Expand Up @@ -1239,7 +1239,9 @@ def resolve_imports_and_packages(
)

existing_packages_dict = (
session._artifact_repository_packages[artifact_repository] if session else {}
session._get_packages_by_artifact_repository(artifact_repository)
if session
else {}
)

if artifact_repository != _ANACONDA_SHARED_REPOSITORY:
Expand All @@ -1248,7 +1250,9 @@ def resolve_imports_and_packages(
if not packages and session:
resolved_packages = list(
session._resolve_packages(
[], artifact_repository, existing_packages_dict
[],
artifact_repository=artifact_repository,
existing_packages_dict=existing_packages_dict,
)
)
elif packages:
Expand Down Expand Up @@ -1286,17 +1290,17 @@ def resolve_imports_and_packages(
resolved_packages = (
session._resolve_packages(
packages,
artifact_repository,
{}, # ignore session packages if passed in explicitly
artifact_repository=artifact_repository,
existing_packages_dict={}, # ignore session packages if passed in explicitly
include_pandas=is_pandas_udf,
statement_params=statement_params,
_suppress_local_package_warnings=_suppress_local_package_warnings,
)
if packages is not None
else session._resolve_packages(
[],
artifact_repository,
existing_packages_dict,
artifact_repository=artifact_repository,
existing_packages_dict=existing_packages_dict,
validate_package=False,
include_pandas=is_pandas_udf,
statement_params=statement_params,
Expand Down
48 changes: 39 additions & 9 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,9 @@ def __init__(
self._conn = conn
self._query_tag = None
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
# packages under the DEFAULT_ARTIFACT_REPOSITORY
# due to server side accessing private session members, this cannot be merged with _artifact_repository_packages
self._packages: Dict[str, str] = {}
# map of artifact repository name -> packages that should be added to functions under that repository
self._artifact_repository_packages: DefaultDict[
str, Dict[str, str]
Expand Down Expand Up @@ -1599,6 +1602,14 @@ def _list_files_in_stage(
prefix_length = get_stage_file_prefix_length(stage_location)
return {str(row[0])[prefix_length:] for row in file_list}

def _get_packages_by_artifact_repository(
self, artifact_repository: str
) -> Dict[str, str]:
if artifact_repository == _DEFAULT_ARTIFACT_REPOSITORY:
return self._packages
else:
return self._artifact_repository_packages[artifact_repository]

def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, str]:
"""
Returns a ``dict`` of packages added for user-defined functions (UDFs).
Expand All @@ -1613,7 +1624,7 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s
artifact_repository = self._get_default_artifact_repository()

with self._package_lock:
return self._artifact_repository_packages[artifact_repository].copy()
return self._get_packages_by_artifact_repository(artifact_repository).copy()

def add_packages(
self,
Expand Down Expand Up @@ -1686,8 +1697,10 @@ def add_packages(

self._resolve_packages(
parse_positional_args_to_list(*packages),
artifact_repository,
self._artifact_repository_packages[artifact_repository],
artifact_repository=artifact_repository,
existing_packages_dict=self._get_packages_by_artifact_repository(
artifact_repository
),
)

def remove_package(
Expand Down Expand Up @@ -1724,7 +1737,7 @@ def remove_package(
artifact_repository = self._get_default_artifact_repository()

with self._package_lock:
packages = self._artifact_repository_packages[artifact_repository]
packages = self._get_packages_by_artifact_repository(artifact_repository)
if package_name in packages:
packages.pop(package_name)
else:
Expand All @@ -1742,7 +1755,7 @@ def clear_packages(
artifact_repository = self._get_default_artifact_repository()

with self._package_lock:
self._artifact_repository_packages[artifact_repository].clear()
self._get_packages_by_artifact_repository(artifact_repository).clear()

def add_requirements(
self,
Expand Down Expand Up @@ -2110,11 +2123,11 @@ def _get_req_identifiers_list(
def _resolve_packages(
self,
packages: List[Union[str, ModuleType]],
artifact_repository: str,
existing_packages_dict: Dict[str, str],
existing_packages_dict: Dict[str, str] = None,
validate_package: bool = True,
include_pandas: bool = False,
statement_params: Optional[Dict[str, str]] = None,
artifact_repository: str = None,
**kwargs,
) -> List[str]:
"""
Expand All @@ -2132,6 +2145,13 @@ def _resolve_packages(
Returns:
List[str]: List of package specifiers
"""
if artifact_repository is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit on styling: can this be written as a = a or b?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, after formatting it will get turned into

artifact_repository = (
            artifact_repository or self._get_default_artifact_repository()
        )

which seems about the same tbh

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually I think you would have very subtle perf gain if you write it as a = a or b (you can test it by using timeit)
but since this is not on the hot path, and compared to what the rest code is doing, this overhead is negligible.

artifact_repository = self._get_default_artifact_repository()
if existing_packages_dict is None:
existing_packages_dict = self._get_packages_by_artifact_repository(
artifact_repository
)

# Always include cloudpickle
extra_modules = [cloudpickle]
if include_pandas:
Expand Down Expand Up @@ -2404,7 +2424,10 @@ def _get_default_artifact_repository(self) -> str:
if isinstance(self._conn, MockServerConnection):
return _DEFAULT_ARTIFACT_REPOSITORY

cache_key = (self.get_current_database(), self.get_current_schema())
account = self.get_current_account()
database = self.get_current_database()
schema = self.get_current_schema()
cache_key = (database, schema)

if (
self._default_artifact_repository_cache is not None
Expand All @@ -2414,8 +2437,15 @@ def _get_default_artifact_repository(self) -> str:

try:
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
entity_selector_args = (
f"'schema', '{schema}'"
if schema
else f"'database', '{database}'"
if database
else f"'account', '{account}'"
)
result = self._run_query(
f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')"
f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}', {entity_selector_args})"
)
value = result[0][0] if result else None
resolved = value or _DEFAULT_ARTIFACT_REPOSITORY
Expand Down
54 changes: 9 additions & 45 deletions tests/integ/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def extract_major_minor_patch(version_string):

resolved_packages = session._resolve_packages(
[numpy, pandas, dateutil],
_ANACONDA_SHARED_REPOSITORY,
{},
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
existing_packages_dict={},
validate_package=False,
)
# resolved_packages is a list of strings like
Expand Down Expand Up @@ -1204,17 +1204,10 @@ def test_replicate_local_environment(session):
"force_push": True,
}

assert not any(
[
package.startswith("cloudpickle")
for package in session._artifact_repository_packages[
_ANACONDA_SHARED_REPOSITORY
]
]
)
assert not any([package.startswith("cloudpickle") for package in session._packages])

def naive_add_packages(self, packages):
self._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = packages
self._packages = packages

with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True):
with patch.object(Session, "add_packages", new=naive_add_packages):
Expand All @@ -1228,22 +1221,10 @@ def naive_add_packages(self, packages):
},
)

assert any(
[
package.startswith("cloudpickle==")
for package in session._artifact_repository_packages[
_ANACONDA_SHARED_REPOSITORY
]
]
)
assert any([package.startswith("cloudpickle==") for package in session._packages])
for default_package in DEFAULT_PACKAGES:
assert not any(
[
package.startswith(default_package)
for package in session._artifact_repository_packages[
_ANACONDA_SHARED_REPOSITORY
]
]
[package.startswith(default_package) for package in session._packages]
)

session.clear_packages()
Expand All @@ -1262,29 +1243,12 @@ def naive_add_packages(self, packages):
ignore_packages=ignored_packages, relax=True
)

assert any(
[
package == "cloudpickle"
for package in session._artifact_repository_packages[
_ANACONDA_SHARED_REPOSITORY
]
]
)
assert any([package == "cloudpickle" for package in session._packages])
for default_package in DEFAULT_PACKAGES:
assert not any(
[
package.startswith(default_package)
for package in session._artifact_repository_packages[
_ANACONDA_SHARED_REPOSITORY
]
]
[package.startswith(default_package) for package in session._packages]
)
for ignored_package in ignored_packages:
assert not any(
[
package.startswith(ignored_package)
for package in session._artifact_repository_packages[
_ANACONDA_SHARED_REPOSITORY
]
]
[package.startswith(ignored_package) for package in session._packages]
)
30 changes: 24 additions & 6 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3108,18 +3108,17 @@ def test_urllib() -> str:
)
def test_use_default_artifact_repository(db_parameters):
with Session.builder.configs(db_parameters).create() as session:
temp_database = Utils.random_temp_database()
temp_schema = Utils.random_temp_schema()
session.sql(f"create schema {temp_schema}").collect()
session.sql(f"use schema {temp_schema}").collect()
session.sql(f"create database {temp_database}").collect()
session.sql(f"use database {temp_database}").collect()
session.sql(
"ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true"
).collect()
session.sql(
"ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY"
"ALTER database set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.anaconda_shared_repository"
).collect()

session.add_packages("art", "cloudpickle")

def test_art() -> str:
import art # art is not available in the conda channel, but is in pypi

Expand All @@ -3128,6 +3127,25 @@ def test_art() -> str:

temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)

# should not work in the database where the default is anaconda
with pytest.raises(
Exception,
match="Cannot add package art because it is not available in Snowflake",
):
udf(
session=session,
func=test_art,
name=temp_func_name,
packages=["art", "cloudpickle"],
)

session.sql(f"create schema {temp_schema}").collect()
session.use_schema(temp_schema)
session.sql(
"ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY"
).collect()
session.add_packages("art", "cloudpickle")

try:
# Test function registration
udf(
Expand All @@ -3142,4 +3160,4 @@ def test_art() -> str:
finally:
session._run_query(f"drop function if exists {temp_func_name}(int)")

session.sql(f"drop schema {temp_schema}").collect()
session.sql(f"drop database {temp_database}").collect()
24 changes: 12 additions & 12 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True

session._resolve_packages(
["random_package_name"],
_ANACONDA_SHARED_REPOSITORY,
{},
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
existing_packages_dict={},
validate_package=True,
include_pandas=False,
)
Expand Down Expand Up @@ -248,8 +248,8 @@ def run_query(sql: str):
):
session._resolve_packages(
["random_package_name"],
_ANACONDA_SHARED_REPOSITORY,
{},
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
existing_packages_dict={},
validate_package=True,
include_pandas=False,
)
Expand All @@ -273,8 +273,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True

resolved_packages = session._resolve_packages(
["random_package_name"],
_ANACONDA_SHARED_REPOSITORY,
existing_packages,
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
existing_packages_dict=existing_packages,
validate_package=True,
include_pandas=False,
)
Expand Down Expand Up @@ -305,8 +305,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
):
session._resolve_packages(
["snowflake-snowpark-python"],
_ANACONDA_SHARED_REPOSITORY,
{},
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
existing_packages_dict={},
validate_package=True,
include_pandas=False,
_suppress_local_package_warnings=True,
Expand All @@ -333,16 +333,16 @@ def assert_packages(packages):

packages = session._resolve_packages(
["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"],
"snowflake.snowpark.pypi_shared_repository",
existing_packages,
artifact_repository="snowflake.snowpark.pypi_shared_repository",
existing_packages_dict=existing_packages,
)

assert_packages(packages)

packages = session._resolve_packages(
[],
"snowflake.snowpark.pypi_shared_repository",
existing_packages,
artifact_repository="snowflake.snowpark.pypi_shared_repository",
existing_packages_dict=existing_packages,
)

assert_packages(packages)
Expand Down
Loading