Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 159 additions & 68 deletions kubeflow/hub/api/model_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,29 @@ def __init__(

Args:
base_url: Base URL of the model registry server including scheme.
Examples: "https://registry.example.com", "http://localhost"
The scheme is used to infer is_secure and port if not explicitly provided.
Examples: "https://registry.example.com", "http://localhost".

Keyword Args:
port: Server port. If not provided, inferred from base_url scheme:
- https:// defaults to 443
- http:// defaults to 8080
- no scheme defaults to 443
port: Server port. If not provided, inferred from `base_url` scheme:
- https:// defaults to 443
- http:// defaults to 8080
- no scheme defaults to 443
author: Name of the author.
is_secure: Whether to use a secure connection. If not provided, inferred from base_url:
- https:// sets is_secure=True
- http:// sets is_secure=False
- no scheme defaults to True
is_secure: Whether to use a secure connection. If not provided,
inferred from `base_url`:
- https:// sets is_secure=True
- http:// sets is_secure=False
- no scheme defaults to True
user_token: The PEM-encoded user token as a string.
custom_ca: Path to the PEM-encoded root certificates as a string.

Raises:
ImportError: If model-registry is not installed.
ImportError: If the `model-registry` package is not installed.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
"""
try:
from model_registry import ModelRegistry
Expand Down Expand Up @@ -104,33 +109,46 @@ def register_model(
version_description: str | None = None,
metadata: Mapping[str, SupportedTypes] | None = None,
) -> RegisteredModel:
"""Register a model.
"""Register a new model or a new version of an existing model.

This registers a model in the model registry. The model is not downloaded,
and has to be stored prior to registration.

Most models can be registered using their URI, along with optional
connection-specific parameters, `storage_key` and `storage_path` or,
simply a `service_account_name`. URI builder utilities are recommended
when referring to specialized storage; for example `utils.s3_uri_from`
helper when using S3 object storage data connections.
The model must be stored in external storage (e.g., S3, GCS) before
registration. The URI should point to the model artifacts.

Args:
name: Name of the model.
uri: URI of the model.
name: The name of the registered model.
uri: The URI where the model artifacts are stored.

Keyword Args:
version: Version of the model. Has to be unique.
model_format_name: Name of the model format (e.g., "pytorch", "tensorflow", "onnx").
Used by KServe to select the appropriate serving runtime.
model_format_version: Version of the model format (e.g., "2.0", "1.15").
author: Author of the model. Defaults to the client author.
owner: Owner of the model. Defaults to the client author.
version_description: Description of the model version.
metadata: Additional version metadata.
version: The version string for this registration. Must be unique
for this model name.
model_format_name: The format of the model (e.g., "pytorch",
"tensorflow"). Used by KServe for inference.
model_format_version: The version of the model format (e.g., "2.0").
author: The author of the model. Defaults to the client author.
owner: The owner of the model. Defaults to the client author.
version_description: A description of this specific model version.
metadata: A dictionary of additional metadata to store with the version.

Returns:
Registered model.
model_registry.types.RegisteredModel: The registered model object.

Raises:
model_registry.exceptions.StoreError: If the registry backend fails
to register the model.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")

model = client.register_model(
name="mnist-classifier",
uri="s3://my-bucket/models/mnist/v1/",
version="v1.0.0",
model_format_name="pytorch",
version_description="Initial release of MNIST model"
)
print(f"Registered model ID: {model.id}")
"""
return self._registry.register_model(
name=name,
Expand All @@ -145,17 +163,29 @@ def register_model(
)

def update_model(self, model: RegisteredModel) -> RegisteredModel:
"""Update a registered model.
"""Update the metadata of an existing registered model.

Args:
model: The registered model to update. Must have an ID.
model: The `RegisteredModel` instance to update. It must have
a valid ID.

Returns:
Updated registered model.
model_registry.types.RegisteredModel: The updated registered model.

Raises:
TypeError: If model is not a RegisteredModel instance.
model_registry.exceptions.StoreError: If model does not have an ID.
TypeError: If the input is not a `RegisteredModel` instance.
model_registry.exceptions.StoreError: If the registered model does
not have an ID.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
model = client.get_model(name="mnist-classifier")

# Update description
model.description = "Updated description"
updated_model = client.update_model(model)
"""
from model_registry.types import RegisteredModel

Expand All @@ -164,17 +194,28 @@ def update_model(self, model: RegisteredModel) -> RegisteredModel:
return self._registry.update(model)

def update_model_version(self, model_version: ModelVersion) -> ModelVersion:
"""Update a model version.
"""Update an existing model version's metadata.

Args:
model_version: The model version to update. Must have an ID.
model_version: The `ModelVersion` instance to update. It must have
a valid ID.

Returns:
Updated model version.
model_registry.types.ModelVersion: The updated model version.

Raises:
TypeError: If model_version is not a ModelVersion instance.
model_registry.exceptions.StoreError: If model version does not have an ID.
TypeError: If the input is not a `ModelVersion` instance.
model_registry.exceptions.StoreError: If the version does not have an ID.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
version = client.get_model_version(name="mnist", version="v1.0.0")

# Update metadata
version.metadata["accuracy"] = 0.98
client.update_model_version(version)
"""
from model_registry.types import ModelVersion

Expand All @@ -183,17 +224,28 @@ def update_model_version(self, model_version: ModelVersion) -> ModelVersion:
return self._registry.update(model_version)

def update_model_artifact(self, model_artifact: ModelArtifact) -> ModelArtifact:
"""Update a model artifact.
"""Update an existing model artifact's metadata.

Args:
model_artifact: The model artifact to update. Must have an ID.
model_artifact: The `ModelArtifact` instance to update. It must
have a valid ID.

Returns:
Updated model artifact.
model_registry.types.ModelArtifact: The updated model artifact.

Raises:
TypeError: If model_artifact is not a ModelArtifact instance.
model_registry.exceptions.StoreError: If model artifact does not have an ID.
TypeError: If the input is not a `ModelArtifact` instance.
model_registry.exceptions.StoreError: If the artifact does not have an ID.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
artifact = client.get_model_artifact(name="mnist", version="v1.0.0")

# Update artifact description
artifact.description = "Production-ready weights"
client.update_model_artifact(artifact)
"""
from model_registry.types import ModelArtifact

Expand All @@ -202,78 +254,117 @@ def update_model_artifact(self, model_artifact: ModelArtifact) -> ModelArtifact:
return self._registry.update(model_artifact)

def get_model(self, name: str) -> RegisteredModel:
"""Get a registered model.
"""Get a specific registered model by name.

Args:
name: Name of the model.
name: The name of the registered model.

Returns:
Registered model.
model_registry.types.RegisteredModel: The registered model object.

Raises:
ValueError: If the model does not exist.
ValueError: If a registered model with the given `name` is not found.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
model = client.get_model(name="mnist-classifier")
print(f"Model ID: {model.id}")
"""
model = self._registry.get_registered_model(name)
if model is None:
raise ValueError(f"Model {name!r} not found")
return model

def get_model_version(self, name: str, version: str) -> ModelVersion:
"""Get a model version.
"""Get a specific model version.

Args:
name: Name of the model.
version: Version of the model.
name: The name of the registered model.
version: The version string to retrieve.

Returns:
Model version.
model_registry.types.ModelVersion: The model version object.

Raises:
model_registry.exceptions.StoreError: If the model does not exist.
ValueError: If the version does not exist.
model_registry.exceptions.StoreError: If the registered model does
not exist.
ValueError: If the version string is not found for the given
registered model.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
version = client.get_model_version(name="mnist", version="v1.0.0")
print(f"Version ID: {version.id}")
"""
model_version = self._registry.get_model_version(name, version)
if model_version is None:
raise ValueError(f"Model version {version!r} not found for model {name!r}")
return model_version

def get_model_artifact(self, name: str, version: str) -> ModelArtifact:
"""Get a model artifact.
"""Get the artifact associated with a specific model version.

Args:
name: Name of the model.
version: Version of the model.
name: The name of the registered model.
version: The version of the registered model.

Returns:
Model artifact.
model_registry.types.ModelArtifact: The model artifact object.

Raises:
model_registry.exceptions.StoreError: If either the model or the version don't exist.
ValueError: If the artifact does not exist.
model_registry.exceptions.StoreError: If either the registered
model or version does not exist.
ValueError: If the artifact is not found.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
artifact = client.get_model_artifact(name="mnist", version="v1.0.0")
print(f"Artifact URI: {artifact.uri}")
"""
artifact = self._registry.get_model_artifact(name, version)
if artifact is None:
raise ValueError(f"Model artifact not found for model {name!r} version {version!r}")
return artifact

def list_models(self) -> Iterator[RegisteredModel]:
"""Get an iterator for registered models.
"""Get an iterator for all registered models.

Yields:
Registered models.
model_registry.types.RegisteredModel: The next registered model.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
for model in client.list_models():
print(f"Model: {model.name}")
"""
yield from self._registry.get_registered_models()

def list_model_versions(self, name: str) -> Iterator[ModelVersion]:
"""Get an iterator for model versions.
"""Get an iterator for all versions of a specific registered model.

Args:
name: Name of the model.
name: The name of the registered model.

Yields:
Model versions.
model_registry.types.ModelVersion: The next model version.

Raises:
model_registry.exceptions.StoreError: If the model does not exist.
model_registry.exceptions.StoreError: If the registered model does
not exist.

Example:
from kubeflow.hub import ModelRegistryClient

client = ModelRegistryClient(base_url="http://localhost:8080")
for version in client.list_model_versions(name="mnist"):
print(f"Version: {version.version}")
"""
yield from self._registry.get_model_versions(name)
Loading
Loading