Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8aac920
initial implementation
sfc-gh-bkogan Feb 5, 2026
e481ef8
remove cloudpickle todo
sfc-gh-bkogan Feb 5, 2026
8e3950f
simple int test
sfc-gh-bkogan Feb 5, 2026
62cf6a0
update tests + fix bug
sfc-gh-bkogan Feb 5, 2026
e8dd42d
non conda resolve test
sfc-gh-bkogan Feb 5, 2026
97ce30d
remove _packages, merge with _artifact_repository_packages
sfc-gh-bkogan Feb 6, 2026
24a0185
remove more _packages
sfc-gh-bkogan Feb 6, 2026
07cc8b9
add default cache
sfc-gh-bkogan Feb 6, 2026
bde687f
fix UT
sfc-gh-bkogan Feb 6, 2026
97ef1aa
doc udpates
sfc-gh-bkogan Feb 6, 2026
1f45379
fix more tests
sfc-gh-bkogan Feb 6, 2026
44ee0e4
remove debug log
sfc-gh-bkogan Feb 6, 2026
e2b71f0
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 6, 2026
026c024
cleanup
sfc-gh-bkogan Feb 6, 2026
eea6081
wrap getting the default in a lock
sfc-gh-bkogan Feb 6, 2026
44f7114
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 6, 2026
9746f7f
update comment
sfc-gh-bkogan Feb 10, 2026
8562a72
update test
sfc-gh-bkogan Feb 10, 2026
4a51691
try using test artifact repo
sfc-gh-bkogan Feb 10, 2026
273b028
filter out system call checks in modin
sfc-gh-bkogan Feb 10, 2026
1bbd89e
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 17, 2026
4f3d174
private
sfc-gh-bkogan Feb 17, 2026
70e97e7
move constants to context + temp schema in test
sfc-gh-bkogan Feb 18, 2026
348210f
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#### New Features

- Added support for the `DECFLOAT` data type that allows users to represent decimal numbers exactly with 38 digits of precision and a dynamic base-10 exponent.
- Added support for the `DEFAULT_PYTHON_ARTIFACT_REPOSITORY` parameter that allows users to configure the default artifact repository at the account, database, and schema level.

#### Bug Fixes

Expand Down
40 changes: 32 additions & 8 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
)
from snowflake.snowpark.types import DataType, StructField, StructType
from snowflake.snowpark.version import VERSION
from snowflake.snowpark.context import (
_ANACONDA_SHARED_REPOSITORY,
_DEFAULT_ARTIFACT_REPOSITORY,
)

if installed_pandas:
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -1122,6 +1126,7 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}):
def add_snowpark_package_to_sproc_packages(
session: Optional["snowflake.snowpark.Session"],
packages: Optional[List[Union[str, ModuleType]]],
artifact_repository: str,
) -> List[Union[str, ModuleType]]:
major, minor, patch = VERSION
package_name = "snowflake-snowpark-python"
Expand All @@ -1137,8 +1142,11 @@ def add_snowpark_package_to_sproc_packages(
packages = [this_package]
else:
with session._package_lock:
if package_name not in session._packages:
packages = list(session._packages.values()) + [this_package]
existing_packages = session._artifact_repository_packages[
artifact_repository
]
if package_name not in existing_packages:
packages = list(existing_packages.values()) + [this_package]
return packages

return add_package_to_existing_packages(packages, package_name, this_package)
Expand Down Expand Up @@ -1223,17 +1231,30 @@ def resolve_imports_and_packages(
Optional[str],
bool,
]:
if artifact_repository and artifact_repository != "conda":
# Artifact Repository packages are not resolved
if artifact_repository is None:
artifact_repository = (
session._get_default_artifact_repository()
if session
else _DEFAULT_ARTIFACT_REPOSITORY
)

existing_packages_dict = (
session._artifact_repository_packages[artifact_repository] if session else {}
)

if artifact_repository != _ANACONDA_SHARED_REPOSITORY:
# Non-conda artifact repository - skip conda-based package resolution
resolved_packages = []
if not packages and session:
resolved_packages = list(
session._resolve_packages([], artifact_repository=artifact_repository)
session._resolve_packages(
[], artifact_repository, existing_packages_dict
)
)
elif packages:
if not all(isinstance(package, str) for package in packages):
raise TypeError(
"Artifact repository requires that all packages be passed as str."
"Non-conda artifact repository requires that all packages be passed as str."
)
try:
has_cloudpickle = bool(
Expand All @@ -1256,7 +1277,7 @@ def resolve_imports_and_packages(
)

else:
# resolve packages
# resolve packages using conda channel
if session is None: # In case of sandbox
resolved_packages = resolve_packages_in_client_side_sandbox(
packages=packages
Expand All @@ -1265,14 +1286,17 @@ def resolve_imports_and_packages(
resolved_packages = (
session._resolve_packages(
packages,
artifact_repository,
{}, # ignore session packages if passed in explicitly
include_pandas=is_pandas_udf,
statement_params=statement_params,
_suppress_local_package_warnings=_suppress_local_package_warnings,
)
if packages is not None
else session._resolve_packages(
[],
session._packages,
artifact_repository,
existing_packages_dict,
validate_package=False,
include_pandas=is_pandas_udf,
statement_params=statement_params,
Expand Down
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
# example: _integral_type_default_precision = {IntegerType: 9}, IntegerType default _precision is 9 now
_integral_type_default_precision = {}

# The fully qualified name of the Anaconda shared repository (conda channel).
_ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository"
# Fallback default artifact repository: used when the configured default cannot be
# fetched (e.g. the system-function query fails) or was never set.
_DEFAULT_ARTIFACT_REPOSITORY = _ANACONDA_SHARED_REPOSITORY


def configure_development_features(
*,
Expand Down
119 changes: 85 additions & 34 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@
from snowflake.snowpark.context import (
_is_execution_environment_sandboxed_for_client,
_use_scoped_temp_objects,
_ANACONDA_SHARED_REPOSITORY,
_DEFAULT_ARTIFACT_REPOSITORY,
)
from snowflake.snowpark.dataframe import DataFrame
from snowflake.snowpark.dataframe_reader import DataFrameReader
Expand Down Expand Up @@ -599,10 +601,17 @@ def __init__(
self._conn = conn
self._query_tag = None
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
self._packages: Dict[str, str] = {}
# map of artifact repository name -> packages that should be added to functions under that repository
self._artifact_repository_packages: DefaultDict[
str, Dict[str, str]
] = defaultdict(dict)
# Single-entry cache for the default artifact repository value.
# Stores a tuple of ((database, schema), cached_value). Only one entry is
# kept at a time – switching to a different database/schema will evict the old
# value and trigger a fresh query on the next call.
self._default_artifact_repository_cache: Optional[
Tuple[Tuple[Optional[str], Optional[str]], str]
] = None
self._session_id = self._conn.get_session_id()
self._session_info = f"""
"version" : {get_version()},
Expand Down Expand Up @@ -1598,11 +1607,13 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s

Args:
artifact_repository: When set, this function will return the packages for a specific artifact repository.
Otherwise, uses the default artifact repository configured in the current context.
"""
if artifact_repository is None:
artifact_repository = self._get_default_artifact_repository()

with self._package_lock:
if artifact_repository:
return self._artifact_repository_packages[artifact_repository].copy()
return self._packages.copy()
return self._artifact_repository_packages[artifact_repository].copy()

def add_packages(
self,
Expand All @@ -1629,7 +1640,8 @@ def add_packages(
for this argument. If a ``module`` object is provided, the package will be
installed with the version in the local environment.
artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions
using that repository will use the packages. (Default None)
using that repository will use the packages. (Default None). Otherwise, uses the default artifact repository configured in the
current context.

Example::

Expand Down Expand Up @@ -1669,10 +1681,13 @@ def add_packages(
to ensure the consistent experience of a UDF between your local environment
and the Snowflake server.
"""
if artifact_repository is None:
artifact_repository = self._get_default_artifact_repository()

self._resolve_packages(
parse_positional_args_to_list(*packages),
self._packages,
artifact_repository=artifact_repository,
artifact_repository,
self._artifact_repository_packages[artifact_repository],
)

def remove_package(
Expand All @@ -1686,7 +1701,8 @@ def remove_package(
Args:
package: The package name.
artifact_repository: When set this parameter specifies that the package should be removed
from the default packages for a specific artifact repository.
from the default packages for a specific artifact repository. Otherwise, uses the default
artifact repository configured in the current context.

Examples::

Expand All @@ -1704,17 +1720,13 @@ def remove_package(
0
"""
package_name = Requirement(package).name
if artifact_repository is None:
artifact_repository = self._get_default_artifact_repository()

with self._package_lock:
if (
artifact_repository is not None
and package_name
in self._artifact_repository_packages.get(artifact_repository, {})
):
self._artifact_repository_packages[artifact_repository].pop(
package_name
)
elif package_name in self._packages:
self._packages.pop(package_name)
packages = self._artifact_repository_packages[artifact_repository]
if package_name in packages:
packages.pop(package_name)
else:
raise ValueError(f"{package_name} is not in the package list")

Expand All @@ -1726,11 +1738,11 @@ def clear_packages(
Clears all third-party packages of a user-defined function (UDF). When artifact_repository
is set, packages are only cleared from the specified repository.
"""
if artifact_repository is None:
artifact_repository = self._get_default_artifact_repository()

with self._package_lock:
if artifact_repository is not None:
self._artifact_repository_packages.get(artifact_repository, {}).clear()
else:
self._packages.clear()
self._artifact_repository_packages[artifact_repository].clear()

def add_requirements(
self,
Expand All @@ -1747,7 +1759,8 @@ def add_requirements(
Args:
file_path: The path of a local requirement file.
artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions
using that repository will use the packages. (Default None)
using that repository will use the packages. (Default None). Otherwise, uses the default artifact repository configured in
the current context.

Example::

Expand Down Expand Up @@ -2097,11 +2110,11 @@ def _get_req_identifiers_list(
def _resolve_packages(
self,
packages: List[Union[str, ModuleType]],
existing_packages_dict: Optional[Dict[str, str]] = None,
artifact_repository: str,
existing_packages_dict: Dict[str, str],
validate_package: bool = True,
include_pandas: bool = False,
statement_params: Optional[Dict[str, str]] = None,
artifact_repository: Optional[str] = None,
**kwargs,
) -> List[str]:
"""
Expand All @@ -2128,18 +2141,12 @@ def _resolve_packages(
package_dict = self._parse_packages(packages)
if (
isinstance(self._conn, MockServerConnection)
or artifact_repository is not None
or artifact_repository != _ANACONDA_SHARED_REPOSITORY
):
# in local testing we don't resolve the packages, we just return what is added
# in local testing or non-conda, we don't resolve the packages, we just return what is added
errors = []
with self._package_lock:
if artifact_repository is None:
result_dict = self._packages
else:
result_dict = self._artifact_repository_packages[
artifact_repository
]

result_dict = existing_packages_dict
for pkg_name, _, pkg_req in package_dict.values():
if (
pkg_name in result_dict
Expand Down Expand Up @@ -2377,6 +2384,50 @@ def _upload_unsupported_packages(

return supported_dependencies + new_dependencies

def _get_default_artifact_repository(self) -> str:
    """
    Resolve the default artifact repository for the session's current context
    by querying SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY.

    A single-entry cache keyed on (database, schema) avoids issuing the
    system-function query again while the context is unchanged; switching to a
    different database or schema evicts the entry and forces a fresh lookup on
    the next call.

    The Snowflake default artifact repository is returned instead when:
    - the session uses a mock connection (local testing), or
    - the system function is unavailable / the query raises, or
    - the system function returns NULL (no default was ever configured).
    """
    with self._package_lock:
        # Local testing never reaches the server; short-circuit to the fallback.
        if isinstance(self._conn, MockServerConnection):
            return _DEFAULT_ARTIFACT_REPOSITORY

        context = (self.get_current_database(), self.get_current_schema())

        cached = self._default_artifact_repository_cache
        if cached is not None and cached[0] == context:
            return cached[1]

        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
        try:
            rows = self._run_query(
                f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')"
            )
            fetched = rows[0][0] if rows else None
            # NULL / empty result means no default was configured; use the fallback.
            repository = fetched if fetched else _DEFAULT_ARTIFACT_REPOSITORY
        except Exception as e:
            _logger.warning(
                f"Error getting default artifact repository: {e}. Using fallback: {_DEFAULT_ARTIFACT_REPOSITORY}."
            )
            repository = _DEFAULT_ARTIFACT_REPOSITORY

        # Note: a failure result is cached too, until the (database, schema)
        # context changes — matches the single-entry cache contract above.
        self._default_artifact_repository_cache = (context, repository)
        return repository

def _is_anaconda_terms_acknowledged(self) -> bool:
return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0]

Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/snowpark/stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,10 +939,21 @@ def _do_register_sp(
UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names[1:])
]

effective_artifact_repository = artifact_repository
if effective_artifact_repository is None:
from snowflake.snowpark.session import _DEFAULT_ARTIFACT_REPOSITORY

effective_artifact_repository = (
self._session._get_default_artifact_repository()
if self._session
else _DEFAULT_ARTIFACT_REPOSITORY
)

# Add in snowflake-snowpark-python if it is not already in the package list.
packages = add_snowpark_package_to_sproc_packages(
session=self._session,
packages=packages,
artifact_repository=effective_artifact_repository,
)

(
Expand All @@ -967,7 +978,7 @@ def _do_register_sp(
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
force_inline_code=force_inline_code,
artifact_repository=artifact_repository,
artifact_repository=effective_artifact_repository,
_suppress_local_package_warnings=kwargs.get(
"_suppress_local_package_warnings", False
),
Expand Down
Loading
Loading