diff --git a/app/core/security.py b/app/core/security.py index f9cf508..e38762e 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -350,7 +350,17 @@ async def dev_generate_token(user: str, include_models: list[str] = None) -> dic # include only specific models user_models = include_models else: - user_models = list(JOB_MANIFESTS.keys()) + # aggregate all models inference names for available models + all_models = list(JOB_MANIFESTS.keys()) + user_models = [] + for model in all_models: + inference_model_name = ( + JOB_MANIFESTS.get(model) + .model_fields.get("inference_name", None) + .get_default() + ) + if inference_model_name and inference_model_name not in user_models: + user_models.append(inference_model_name) try: # Verify secret key is configured verify_secret_key() diff --git a/app/main.py b/app/main.py index 370184f..b3aeb9f 100644 --- a/app/main.py +++ b/app/main.py @@ -203,7 +203,9 @@ async def generate_token_auth( request: Request, response: Response, user: str = Query("default_user"), - include_models: str = Query(""), + include_models: str = Query( + "", description="Comma separated list of models to include" + ), ): """Generate JWT @@ -1310,15 +1312,19 @@ async def csv_sample(): def user_available_models(token: UserJWT | None = None): """Get all available models""" all_models: list = list(JOB_MANIFESTS.keys()) - # Filter models based on user available models. - # If the user has specific available models, only include those. - # get all models if running local for testing + # !important: the users available models are set in the JWT token, the finetune models `inference_name` must match the inference model name in the token + # If the user has specific available models in subscription jwt, only include those. if token: - user_models = [ - user_model - for user_model in token.available_models - if user_model in all_models - ] + user_models = [] + for model in all_models: + # match a finetune model with the available inference model subscription name + inference_model_name = ( + JOB_MANIFESTS.get(model) + .model_fields.get("inference_name", None) + .get_default() + ) + if inference_model_name and inference_model_name in token.available_models: + user_models.append(model) return user_models # Access all models if jwt not set. Authorization handled by Middleware return all_models diff --git a/app/models/base/finetuning.py b/app/models/base/finetuning.py index a9128db..6683382 100644 --- a/app/models/base/finetuning.py +++ b/app/models/base/finetuning.py @@ -55,6 +55,9 @@ class BaseFineTuneModel(BaseModel, ABC): ) # model setup name: str = Field(..., min_length=4, pattern=r"^[a-zA-Z0-9._@]+$") + inference_name: str | None = Field( + default=None, description="Name of the model to be used for inference" + ) image: str image_pull_secret: str | None = None command: list[str] diff --git a/app/models/examples/mnist.py b/app/models/examples/mnist.py index 7c70407..4c5eddc 100644 --- a/app/models/examples/mnist.py +++ b/app/models/examples/mnist.py @@ -41,7 +41,8 @@ class MNISTConfig(TrainingArguments): class MNIST(BaseFineTuneModel): """Finetune Job Spec for MNIST""" - name: str = "MNIST" # model name must match inference name to work + name: str = "MNIST" # model name in the frontend + inference_name: str | None = "MNIST" # name must match inference name to work description: str = "Example MNIST model for fine-tuning" project_url: str = "https://github.com/acceleratedscience/model-foobar" image: str = "quay.io/brian_duenas/mnist:latest"