Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,17 @@ async def dev_generate_token(user: str, include_models: list[str] = None) -> dic
# include only specific models
user_models = include_models
else:
user_models = list(JOB_MANIFESTS.keys())
# aggregate all models inference names for available models
all_models = list(JOB_MANIFESTS.keys())
user_models = []
for model in all_models:
inference_model_name = (
JOB_MANIFESTS.get(model)
.model_fields.get("inference_name", None)
.get_default()
)
if inference_model_name and inference_model_name not in user_models:
user_models.append(inference_model_name)
try:
# Verify secret key is configured
verify_secret_key()
Expand Down
24 changes: 15 additions & 9 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ async def generate_token_auth(
request: Request,
response: Response,
user: str = Query("default_user"),
include_models: str = Query(""),
include_models: str = Query(
"", description="Comma separated list of models to include"
),
):
"""Generate JWT

Expand Down Expand Up @@ -1310,15 +1312,19 @@ async def csv_sample():
def user_available_models(token: UserJWT | None = None):
"""Get all available models"""
all_models: list = list(JOB_MANIFESTS.keys())
# Filter models based on user available models.
# If the user has specific available models, only include those.
# get all models if running local for testing
# !important: the users available models are set in the JWT token, the finetune models `inference_name` must match the inference model name in the token
# If the user has specific available models in subscription jwt, only include those.
if token:
user_models = [
user_model
for user_model in token.available_models
if user_model in all_models
]
user_models = []
for model in all_models:
# match a finetune model with the available inference model subscription name
inference_model_name = (
JOB_MANIFESTS.get(model)
.model_fields.get("inference_name", None)
.get_default()
)
if inference_model_name and inference_model_name in token.available_models:
user_models.append(model)
return user_models
# Access all models if jwt not set. Authorization handled by Middleware
return all_models
Expand Down
3 changes: 3 additions & 0 deletions app/models/base/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class BaseFineTuneModel(BaseModel, ABC):
)
# model setup
name: str = Field(..., min_length=4, pattern=r"^[a-zA-Z0-9._@]+$")
inference_name: str | None = Field(
default=None, description="Name of the model to be used for inference"
)
image: str
image_pull_secret: str | None = None
command: list[str]
Expand Down
3 changes: 2 additions & 1 deletion app/models/examples/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class MNISTConfig(TrainingArguments):
class MNIST(BaseFineTuneModel):
"""Finetune Job Spec for MNIST"""

name: str = "MNIST" # model name must match inference name to work
name: str = "MNIST" # model name in the frontend
inference_name: str | None = "MNIST" # name must match inference name to work
description: str = "Example MNIST model for fine-tuning"
project_url: str = "https://github.com/acceleratedscience/model-foobar"
image: str = "quay.io/brian_duenas/mnist:latest"
Expand Down