diff --git a/app/api/routers/invocation.py b/app/api/routers/invocation.py
index 47c6828..3c6f5c1 100644
--- a/app/api/routers/invocation.py
+++ b/app/api/routers/invocation.py
@@ -7,12 +7,14 @@
import hashlib
import logging
import pandas as pd
+from fastapi.encoders import jsonable_encoder
+
import app.api.globals as cms_globals
from typing import Dict, List, Union, Iterator, Any
from collections import defaultdict
from io import BytesIO
-from starlette.status import HTTP_400_BAD_REQUEST
+from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from typing_extensions import Annotated
from fastapi import APIRouter, Depends, Body, UploadFile, File, Request, Query, Response
from fastapi.responses import StreamingResponse, PlainTextResponse, JSONResponse
@@ -22,7 +24,7 @@
TextWithAnnotations,
TextWithPublicKey,
TextStreamItem,
- Tags,
+    Tags,
+    OpenAIEmbeddingsRequest,
+    OpenAIEmbeddingsResponse,
)
from app.model_services.base import AbstractModelService
from app.utils import get_settings, load_pydantic_object_from_dict
@@ -43,6 +45,7 @@
PATH_PROCESS_BULK_FILE = "/process_bulk_file"
PATH_REDACT = "/redact"
PATH_REDACT_WITH_ENCRYPTION = "/redact_with_encryption"
+PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"
router = APIRouter()
config = get_settings()
@@ -355,6 +358,93 @@ def get_redacted_text_with_encryption(
return JSONResponse(content=content)
+@router.post(
+ PATH_OPENAI_EMBEDDINGS,
+ tags=[Tags.OpenAICompatible.name],
+ response_model=None,
+ dependencies=[Depends(cms_globals.props.current_active_user)],
+    description="Create embeddings for the given text(s), similar to OpenAI's /v1/embeddings endpoint",
+)
+def embed_texts(
+ request: Request,
+ request_data: Annotated[OpenAIEmbeddingsRequest, Body(
+ description="Text(s) to be embedded", media_type="application/json"
+ )],
+ tracking_id: Union[str, None] = Depends(validate_tracking_id),
+ model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
+) -> JSONResponse:
+ """
+ Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.
+
+ Args:
+ request (Request): The request object.
+ request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s).
+ tracking_id (Union[str, None]): An optional tracking ID of the requested task.
+ model_service (AbstractModelService): The model service dependency.
+
+ Returns:
+ JSONResponse: A response containing the embeddings of the text(s).
+ """
+ tracking_id = tracking_id or str(uuid.uuid4())
+
+ if not hasattr(model_service, "create_embeddings"):
+ error_response = {
+ "error": {
+ "message": "Model does not support embeddings",
+ "type": "invalid_request_error",
+ "param": "model",
+ "code": "model_not_supported",
+ }
+ }
+ return JSONResponse(
+ content=error_response,
+            status_code=HTTP_400_BAD_REQUEST,
+ headers={"x-cms-tracking-id": tracking_id},
+ )
+
+ input_text = request_data.input
+    # The response always reports the serving model's name.
+    model = model_service.model_name
+
+ if isinstance(input_text, str):
+ input_texts = [input_text]
+ else:
+ input_texts = input_text
+
+ try:
+ embeddings_data = []
+
+ for i, embedding in enumerate(model_service.create_embeddings(input_texts)):
+ embeddings_data.append({
+ "object": "embedding",
+ "embedding": embedding,
+ "index": i,
+ })
+
+ response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)
+
+ return JSONResponse(
+ content=jsonable_encoder(response),
+ headers={"x-cms-tracking-id": tracking_id},
+ )
+
+ except Exception as e:
+        logger.exception("Failed to create embeddings")
+ error_response = {
+ "error": {
+ "message": f"Failed to create embeddings: {str(e)}",
+ "type": "server_error",
+ "code": "internal_error",
+ }
+ }
+ return JSONResponse(
+ content=error_response,
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+ headers={"x-cms-tracking-id": tracking_id},
+ )
+
+
def _send_annotation_num_metric(annotation_num: int, handler: str) -> None:
cms_doc_annotations.labels(handler=handler).observe(annotation_num)
diff --git a/app/api/routers/stream.py b/app/api/routers/stream.py
index edba5e8..c44ecf8 100644
--- a/app/api/routers/stream.py
+++ b/app/api/routers/stream.py
@@ -11,7 +11,7 @@
from starlette.types import Receive, Scope, Send
from starlette.background import BackgroundTask
from fastapi import APIRouter, Depends, Request, Response, WebSocket, WebSocketException
-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel
from app.domain import Tags, TextStreamItem
from app.model_services.base import AbstractModelService
from app.utils import get_settings
@@ -20,7 +20,6 @@
PATH_STREAM_PROCESS = "/process"
PATH_WS = "/ws"
-PATH_GENERATE= "/generate"
router = APIRouter()
config = get_settings()
@@ -57,6 +56,22 @@ async def get_entities_stream_from_jsonlines_stream(
return _LocalStreamingResponse(annotation_stream, media_type="application/x-ndjson; charset=utf-8")
+@router.get(
+ PATH_WS,
+ tags=[Tags.Annotations.name],
+ dependencies=[Depends(cms_globals.props.current_active_user)],
+ description="WebSocket info endpoint for real-time NER entity extraction. Use ws://host:port/stream/ws to establish an actual WebSocket connection.",
+ include_in_schema=True,
+)
+async def get_inline_annotations_from_websocket_info() -> "_WebSocketInfo":
+ """
+ Information about the WebSocket endpoint for real-time NER entity extraction.
+
+ This endpoint provides documentation for the WebSocket connection available at the same path.
+ Connect to ws://host:port/stream/ws and send texts to retrieve annotated results.
+ """
+ return _WebSocketInfo()
+
@router.websocket(PATH_WS)
# @limiter.limit(config.PROCESS_BULK_RATE_LIMIT) # Not supported yet
async def get_inline_annotations_from_websocket(
@@ -189,6 +204,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.background()
+class _WebSocketInfo(BaseModel):
+    message: str = "WebSocket endpoint for real-time NER entity extraction"
+    example: str = "Connect to ws://host:port/stream/ws and send plain text; the annotated results are returned on the same connection."
+    protocol: str = "WebSocket"
+
+
async def _annotation_async_gen(request: Request, model_service: AbstractModelService) -> AsyncGenerator:
try:
buffer = ""
diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py
index bc28bfd..4eafe8b 100644
--- a/app/model_services/huggingface_llm_model.py
+++ b/app/model_services/huggingface_llm_model.py
@@ -15,7 +15,7 @@
from app import __version__ as app_version
from app.exception import ConfigurationException
from app.model_services.base import AbstractModelService
-from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer
+from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer, HuggingFaceLlmUnsupervisedTrainer
from app.domain import ModelCard, ModelType, Annotation, Device
from app.config import Settings
from app.utils import (
@@ -62,6 +62,7 @@ def __init__(
self._multi_label_threshold = 0.5
self._text_generator = ThreadPoolExecutor(max_workers=50)
self.model_name = model_name or "HuggingFace LLM model"
+ self.is_4bit_quantised = False
@property
def model(self) -> PreTrainedModel:
@@ -206,6 +207,8 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N
self._model.to(get_settings().DEVICE)
if self._enable_trainer:
self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self)
+ self._unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(self)
+ self.is_4bit_quantised = load_in_4bit
def info(self) -> ModelCard:
"""
@@ -396,29 +399,47 @@ def create_embeddings(
self.model.eval()
- inputs = self.tokenizer(
- text,
- add_special_tokens=False,
- return_tensors="pt",
- padding=True,
- truncation=True,
- )
-
- inputs.to(self.model.device)
-
- with torch.no_grad():
- outputs = self.model(**inputs, output_hidden_states=True)
+ texts = [text] if isinstance(text, str) else text
+ all_embeddings = []
+
+ for txt in texts:
+ inputs = self.tokenizer(txt, add_special_tokens=False, truncation=False, padding=False)
+ input_ids = inputs["input_ids"]
+ attention_mask = inputs["attention_mask"]
+ window_size = max(self.model.config.max_position_embeddings - 2, 1)
+ stride = window_size
+ chunk_embeddings = []
+
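+            # Slide a non-overlapping window over the token sequence so that texts
+            # longer than the model's context length are embedded chunk by chunk.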
+ for start in range(0, len(input_ids), stride):
+ end = min(start + window_size, len(input_ids))
+ chunk_inputs = {
+ "input_ids": torch.tensor(
+ [input_ids[start:end]], dtype=torch.long
+ ).to(self.model.device),
+ "attention_mask": torch.tensor(
+ [attention_mask[start:end]], dtype=torch.long
+ ).to(self.model.device),
+ }
+
+ with torch.no_grad():
+ outputs = self.model(**chunk_inputs, output_hidden_states=True)
+
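+                # Masked mean pooling: average the last hidden states over the
+                # attended tokens to get one vector per chunk.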
+ last_hidden_state = outputs.hidden_states[-1]
+ chunk_attention_mask = chunk_inputs["attention_mask"]
+ masked_hidden_states = last_hidden_state * chunk_attention_mask.unsqueeze(-1)
+ sum_hidden_states = masked_hidden_states.sum(dim=1)
+ num_tokens = chunk_attention_mask.sum(dim=1, keepdim=True)
+ chunk_embedding = sum_hidden_states / num_tokens
+ chunk_embeddings.append(chunk_embedding)
+
+ if end >= len(input_ids):
+ break
- last_hidden_state = outputs.hidden_states[-1]
- attention_mask = inputs["attention_mask"]
- masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1)
- sum_hidden_states = masked_hidden_states.sum(dim=1)
- num_tokens = attention_mask.sum(dim=1, keepdim=True)
- embeddings = sum_hidden_states / num_tokens
- l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1)
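+            # Average the chunk vectors into a single embedding and L2-normalise it.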
+ final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True)
+ l2_normalised = torch.nn.functional.normalize(final_embedding, p=2, dim=1)
+ all_embeddings.append(l2_normalised.cpu().numpy().tolist()[0])
- results = l2_normalised.cpu().numpy().tolist()
- return results[0] if isinstance(text, str) else results
+ return all_embeddings[0] if isinstance(text, str) else all_embeddings
def train_supervised(
self,
@@ -465,3 +486,49 @@ def train_supervised(
synchronised,
**hyperparams,
)
+
+ def train_unsupervised(
+ self,
+ data_file: TextIO,
+ epochs: int,
+ log_frequency: int,
+ training_id: str,
+ input_file_name: str,
+ raw_data_files: Optional[List[TextIO]] = None,
+ description: Optional[str] = None,
+ synchronised: bool = False,
+ **hyperparams: Dict[str, Any],
+ ) -> Tuple[bool, str, str]:
+ """
+ Initiates unsupervised training on the model.
+
+ Args:
+ data_file (TextIO): The file containing a JSON list of texts.
+ epochs (int): The number of training epochs.
+ log_frequency (int): The number of epochs after which training metrics will be logged.
+ training_id (str): A unique identifier for the training process.
+ input_file_name (str): The name of the input file to be logged.
+ raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None.
+ description (Optional[str]): The description of the training or change logs. Defaults to empty.
+ synchronised (bool): Whether to wait for the training to complete.
+ **hyperparams (Dict[str, Any]): Additional hyperparameters for training.
+
+ Returns:
+ Tuple[bool, str, str]: A tuple with the first element indicating success or failure.
+
+ Raises:
+ ConfigurationException: If the unsupervised trainer is not enabled.
+ """
+ if self._unsupervised_trainer is None:
+ raise ConfigurationException("The unsupervised trainer is not enabled")
+ return self._unsupervised_trainer.train(
+ data_file,
+ epochs,
+ log_frequency,
+ training_id,
+ input_file_name,
+ raw_data_files,
+ description,
+ synchronised,
+ **hyperparams,
+ )
diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py
index 98e55f9..a6eeb8c 100644
--- a/app/model_services/huggingface_ner_model.py
+++ b/app/model_services/huggingface_ner_model.py
@@ -1,9 +1,10 @@
import os
import logging
+import torch
import pandas as pd
from functools import partial
-from typing import Dict, List, Optional, Tuple, Any, TextIO
+from typing import Dict, List, Optional, Tuple, Any, TextIO, Union
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
@@ -276,6 +277,83 @@ def annotate(self, text: str) -> List[Annotation]:
def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
raise NotImplementedError("Batch annotation is not yet implemented for HuggingFace NER models")
+ def create_embeddings(
+ self,
+ text: Union[str, List[str]],
+ *args: Any,
+ **kwargs: Any
+ ) -> Union[List[float], List[List[float]]]:
+ """
+ Creates embeddings for a given text or list of texts using the model's hidden states.
+
+ Args:
+ text (Union[str, List[str]]): The text(s) to be embedded.
+ *args (Any): Additional positional arguments to be passed to this method.
+ **kwargs (Any): Additional keyword arguments to be passed to this method.
+
+ Returns:
+            Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s).
+ """
+
+ self.model.eval()
+
+ texts = [text] if isinstance(text, str) else text
+ all_embeddings = []
+
+ max_len = self.model.config.max_position_embeddings
+
+ for txt in texts:
+ encoded = self.tokenizer(
+ txt,
+ add_special_tokens=True,
+ truncation=False,
+ return_attention_mask=True,
+ )
+
+ input_ids = encoded["input_ids"]
+ chunk_embeddings = []
+ window_size = max_len - 2
+ stride = window_size
+
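+            # Slide a non-overlapping window over the token sequence so that texts
+            # longer than max_position_embeddings are embedded chunk by chunk.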
+ for start in range(0, len(input_ids), stride):
+ end = min(start + window_size, len(input_ids))
+
+ chunk = self.tokenizer.prepare_for_model(
+ input_ids[start:end],
+ add_special_tokens=True,
+ return_attention_mask=True,
+ truncation=True,
+ max_length=max_len,
+ padding="max_length",
+ )
+
+ chunk_inputs = {
+ "input_ids": torch.tensor([chunk["input_ids"]], device=self.model.device),
+ "attention_mask": torch.tensor([chunk["attention_mask"]], device=self.model.device),
+ }
+
+ with torch.no_grad():
+ outputs = self.model(**chunk_inputs, output_hidden_states=True)
+
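+                # Masked mean pooling over the chunk's non-padding tokens.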
+ last_hidden_state = outputs.hidden_states[-1]
+ mask = chunk_inputs["attention_mask"].unsqueeze(-1)
+ summed = (last_hidden_state * mask).sum(dim=1)
+ counts = mask.sum(dim=1).clamp(min=1)
+ chunk_embedding = summed / counts
+ chunk_embeddings.append(chunk_embedding)
+
+ if end >= len(input_ids):
+ break
+
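+            # Average the chunk vectors and L2-normalise the final embedding.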
+ final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True)
+ final_embedding = torch.nn.functional.normalize(final_embedding, p=2, dim=1)
+ all_embeddings.append(final_embedding.cpu().numpy()[0].tolist())
+
+ return all_embeddings[0] if isinstance(text, str) else all_embeddings
+
def train_supervised(
self,
data_file: TextIO,
diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py
index 0ef00d1..800e68d 100644
--- a/app/model_services/medcat_model.py
+++ b/app/model_services/medcat_model.py
@@ -1,10 +1,15 @@
import os
import logging
+import torch
+import numpy as np
import pandas as pd
from multiprocessing import cpu_count
from typing import Dict, List, Optional, TextIO, Tuple, Any, Set, Union
from medcat.cat import CAT
+from medcat.components.linking.embedding_linker import Linker
+from medcat.components.linking.vector_context_model import PerDocumentTokenCache
+from medcat.components.types import CoreComponentType
from medcat.data.entities import Entities, OnlyCUIEntities
from app import __version__ as app_version
from app.model_services.base import AbstractModelService
@@ -20,7 +25,7 @@
get_model_data_package_base_name,
load_pydantic_object_from_dict,
)
-from app.exception import ConfigurationException
+from app.exception import ConfigurationException, ManagedModelException
logger = logging.getLogger("cms")
@@ -198,7 +203,7 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
annotations_list = []
for _, doc in docs.items():
annotations_list.append([
- load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc)
+ load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc) # type: ignore
])
return annotations_list
@@ -370,6 +375,79 @@ def get_records_from_doc(self, doc: Union[Dict, Entities, OnlyCUIEntities]) -> L
records = df.to_dict("records")
return records
+ def create_embeddings(
+ self,
+ text: Union[str, List[str]],
+ *args: Any,
+ model_name: Optional[str] = None,
+ max_length: Optional[int] = None,
+ **kwargs: Any
+ ) -> Union[List[float], List[List[float]]]:
+ """
+ Creates embeddings for a given text or list of texts using MedCAT's embedding linker.
+
+ Args:
+ text (Union[str, List[str]]): The text(s) to be embedded.
+ model_name (Optional[str]): The name of the embedding model to use.
+ max_length (Optional[int]): Maximum sequence length for tokenization.
+ *args (Any): Additional positional arguments.
+ **kwargs (Any): Additional keyword arguments.
+
+ Returns:
+ Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s).
+ """
+
+ assert self._model is not None, "Model is not initialised"
+ texts = [text] if isinstance(text, str) else text
+ linker = self._model.pipe.get_component(CoreComponentType.linking)
+
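+        # Two paths: an embedding Linker embeds the texts directly with its
+        # transformer, while other linkers fall back to weighted context vectors.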
+ if isinstance(linker, Linker):
+ embedding_model_name = getattr(linker.cnf_l, "embedding_model_name", None)
+ if embedding_model_name is None:
+ raise ManagedModelException("Embedding linker present but no embedding_model_name found in config.")
+ linker._load_transformers(embedding_model_name)
+ with torch.no_grad():
+ emb_tensor = linker._embed(texts, linker.device)
+ embeddings = emb_tensor.cpu().numpy().tolist()
+ return embeddings[0] if isinstance(text, str) else embeddings
+ else:
+ all_embeddings = []
+ ctx_model = getattr(linker, "context_model", None)
+ if ctx_model is None:
+ raise ManagedModelException(
+ "Linker does not expose context_model so cannot compute context-based embeddings."
+ )
+ tokenizer = self._model.pipe.tokenizer
+ for txt in texts:
+ doc = tokenizer(txt)
+ if hasattr(tokenizer, "entity_from_tokens"):
+ entity = tokenizer.entity_from_tokens(list(doc))
+ else:
+ raise ManagedModelException(
+ "Tokenizer does not support entity_from_tokens so cannot build entity for context model"
+ )
+
+ cache = PerDocumentTokenCache()
+ vectors = ctx_model.get_context_vectors(entity, doc, cache)
+ weights = getattr(linker.config.components.linking, "context_vector_weights", None) # type: ignore
+ if not weights:
+ weights = {k: 1.0 for k in vectors.keys()}
+
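+            # Combine the per-context-size vectors into a single weighted sum.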
+ combined = None
+ for size, vec in vectors.items():
+ w = weights.get(size, 1.0)
+ if combined is None:
+ combined = w * vec
+ else:
+ combined = combined + w * vec
+            if combined is None:
+                raise ManagedModelException("No context vectors were produced for the input text")
+            norm = np.linalg.norm(combined)
+            if norm > 0:
+                combined = combined / norm
+            all_embeddings.append(combined.tolist())
+
+ return all_embeddings[0] if isinstance(text, str) else all_embeddings
+
@staticmethod
def _retrieve_meta_annotations(df: pd.DataFrame) -> pd.DataFrame:
meta_annotations = []
diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py
index 547f225..c401d7e 100644
--- a/app/model_services/medcat_model_deid.py
+++ b/app/model_services/medcat_model_deid.py
@@ -2,7 +2,7 @@
import inspect
import threading
import torch
-from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable, cast
+from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable, cast, Union
from functools import partial
from transformers import pipeline
from medcat.cat import CAT
@@ -91,7 +91,7 @@ def annotate(self, text: str) -> List[Annotation]:
for _, entity in doc["entities"].items():
entity["type_ids"] = ["PII"]
- records = self.get_records_from_doc({"entities": doc["entities"]})
+ records = self.get_records_from_doc({"entities": doc["entities"]}) # type: ignore
return [load_pydantic_object_from_dict(Annotation, record) for record in records]
def annotate_with_local_chunking(self, text: str) -> List[Annotation]:
@@ -179,11 +179,101 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
entity = cast(Dict[str, Any], entity)
entity["type_ids"] = ["PII"]
annotations_list.append([
- load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(entities)
+ load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(entities) # type: ignore
])
return annotations_list
+ def create_embeddings(
+ self,
+ text: Union[str, List[str]],
+ *args: Any,
+ **kwargs: Any
+ ) -> Union[List[float], List[List[float]]]:
+ """
+ Creates embeddings for a given text or list of texts using the model's hidden states.
+
+ Args:
+ text (Union[str, List[str]]): The text(s) to be embedded.
+ *args (Any): Additional positional arguments to be passed to this method.
+ **kwargs (Any): Additional keyword arguments to be passed to this method.
+
+ Returns:
+            Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s).
+ """
+
+ assert self.model is not None, "Model is not initialised"
+ ner = self.model.pipe.get_component(CoreComponentType.ner)._component # type: ignore
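+        # Backfill tokenizer attributes that may be missing on models serialised
+        # with older transformers versions, before calling the tokenizer directly.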
+ ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr(
+ ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False
+ )
+ ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr(
+ ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None
+ )
+ ner.tokenizer.hf_tokenizer.split_special_tokens = getattr(
+ ner.tokenizer.hf_tokenizer, "split_special_tokens", False
+ )
+ tokenizer = ner.tokenizer.hf_tokenizer
+ model = ner.model
+ model.eval()
+
+ texts = [text] if isinstance(text, str) else text
+ all_embeddings = []
+
+ max_len = model.config.max_position_embeddings
+
+ for txt in texts:
+ encoded = tokenizer(
+ txt,
+ add_special_tokens=True,
+ truncation=False,
+ return_attention_mask=True,
+ )
+
+ input_ids = encoded["input_ids"]
+ chunk_embeddings = []
+ window_size = max_len - 2
+ stride = window_size
+
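+            # Slide a non-overlapping window over the token sequence so that texts
+            # longer than the model's context length are embedded chunk by chunk.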
+ for start in range(0, len(input_ids), stride):
+ end = min(start + window_size, len(input_ids))
+
+ chunk = tokenizer.prepare_for_model(
+ input_ids[start:end],
+ add_special_tokens=True,
+ return_attention_mask=True,
+ truncation=True,
+ max_length=max_len,
+ padding="max_length",
+ )
+
+ chunk_inputs = {
+ "input_ids": torch.tensor([chunk["input_ids"]], device=model.device),
+ "attention_mask": torch.tensor([chunk["attention_mask"]], device=model.device),
+ }
+
+ with torch.no_grad():
+ outputs = model(**chunk_inputs, output_hidden_states=True)
+
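+                # Masked mean pooling over the chunk's non-padding tokens.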
+ last_hidden_state = outputs.hidden_states[-1]
+ mask = chunk_inputs["attention_mask"].unsqueeze(-1)
+ summed = (last_hidden_state * mask).sum(dim=1)
+ counts = mask.sum(dim=1).clamp(min=1)
+ chunk_embedding = summed / counts
+ chunk_embeddings.append(chunk_embedding)
+
+ if end >= len(input_ids):
+ break
+
+ final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True)
+ final_embedding = torch.nn.functional.normalize(final_embedding, p=2, dim=1)
+ all_embeddings.append(final_embedding.cpu().numpy()[0].tolist())
+
+ return all_embeddings[0] if isinstance(text, str) else all_embeddings
+
def init_model(self, *args: Any, **kwargs: Any) -> None:
"""Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration.
diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py
index 5350422..89404fa 100644
--- a/app/trainers/huggingface_llm_trainer.py
+++ b/app/trainers/huggingface_llm_trainer.py
@@ -1,10 +1,12 @@
import os
import logging
import math
+import tempfile
import torch
import gc
import datasets
import re
+import random
import threading
import json
import inspect
@@ -18,6 +20,8 @@
TrainerCallback,
TrainerState,
TrainerControl,
+ DataCollatorForLanguageModeling,
+ Trainer,
)
from peft import LoraConfig, get_peft_model # type: ignore
from app.management.model_manager import ModelManager
@@ -32,13 +36,22 @@
get_default_system_prompt,
get_model_data_package_base_name,
)
-from app.trainers.base import SupervisedTrainer
-from app.domain import ModelType, TrainerBackend, LlmRole, LlmTrainerType, LlmDatasetType, PromptMessage
+from app.trainers.base import SupervisedTrainer, UnsupervisedTrainer
+from app.domain import (
+ ModelType,
+ TrainerBackend,
+ LlmRole,
+ LlmTrainerType,
+ LlmDatasetType,
+ PromptMessage,
+ DatasetSplit,
+)
from app.exception import (
TrainingCancelledException,
DatasetException,
ConfigurationException,
ExtraDependencyRequiredException,
+ ManagedModelException,
)
if TYPE_CHECKING:
from app.model_services.huggingface_llm_model import HuggingFaceLlmModel
@@ -719,6 +732,325 @@ def _evaluate_with_rewards(
return reward_avgs
+@final
+class HuggingFaceLlmUnsupervisedTrainer(UnsupervisedTrainer, _HuggingFaceLlmTrainerCommon):
+ """
+ An unsupervised trainer class for HuggingFace LLM models.
+
+ Args:
+ model_service (HuggingFaceLlmModel): An instance of the HuggingFace LLM model service.
+ """
+
+ def __init__(self, model_service: "HuggingFaceLlmModel") -> None:
+ UnsupervisedTrainer.__init__(self, model_service._config, model_service.model_name)
+ self._model_service = model_service
+ self._model_name = model_service.model_name
+ self._model_pack_path = model_service._model_pack_path
+ self._retrained_models_dir = os.path.join(
+ model_service._model_parent_dir,
+ "retrained",
+ self._model_name.replace(" ", "_"),
+ )
+ self._model_manager = ModelManager(type(model_service), model_service._config)
+ self._max_length = model_service.model.config.max_position_embeddings
+ os.makedirs(self._retrained_models_dir, exist_ok=True)
+
+ def run(
+ self,
+ training_params: Dict,
+ data_file: TextIO,
+ log_frequency: int,
+ run_id: str,
+ description: Optional[str] = None,
+ ) -> None:
+ """
+ Runs the unsupervised training loop for HuggingFace LLM models.
+
+ Args:
+ training_params (Dict): A dictionary containing parameters for the training.
+ data_file (TextIO): The file-like object containing the training data.
+ log_frequency (int): The frequency at which logs should be recorded.
+ run_id (str): The run ID of the training job.
+ description (Optional[str]): The optional description of the training.
+ """
+ eval_mode = training_params["nepochs"] == 0
+        retrained_model_pack_path = None
+ redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true"
+ skip_save_model = self._config.SKIP_SAVE_MODEL == "true"
+ results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results"))
+ logs_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "logs"))
+ reset_random_seed()
+ trainer = None
+
+ if not eval_mode:
+ try:
+ copied_model_directory = None
+ if self._model_service.is_4bit_quantised:
+                    logger.info("Using the LoRA adapter for the quantised model...")
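+                    # LoRA trains small adapter matrices on the attention and MLP
+                    # projections, since 4-bit quantised weights cannot be updated directly.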
+ lora_config = LoraConfig(
+ task_type="CAUSAL_LM",
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.1,
+ target_modules=[
+ "q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj",
+ ],
+ )
+ model = get_peft_model(self._model_service.model, lora_config)
+ tokenizer = self._model_service.tokenizer
+ else:
+ logger.info("Loading a new model copy for training...")
+ copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id)
+ model, tokenizer = self._model_service.load_model(copied_model_pack_path)
+ copied_model_directory = os.path.join(
+ os.path.dirname(copied_model_pack_path),
+ get_model_data_package_base_name(copied_model_pack_path),
+ )
+
+ if non_default_device_is_available(self._config.DEVICE):
+ model.to(self._config.DEVICE)
+
+ test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"]
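+                # The data file is either a datasets directory with predefined splits
+                # or a JSON list of texts, which is shuffled and split by test_size.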
+ if isinstance(data_file, tempfile.TemporaryDirectory):
+ raw_dataset = datasets.load_from_disk(data_file.name)
+ if DatasetSplit.VALIDATION.value in raw_dataset.keys():
+ train_texts = raw_dataset[DatasetSplit.TRAIN.value]["text"]
+ eval_texts = raw_dataset[DatasetSplit.VALIDATION.value]["text"]
+ elif DatasetSplit.TEST.value in raw_dataset.keys():
+ train_texts = raw_dataset[DatasetSplit.TRAIN.value]["text"]
+ eval_texts = raw_dataset[DatasetSplit.TEST.value]["text"]
+ else:
+ lines = raw_dataset[DatasetSplit.TRAIN.value]["text"]
+ random.shuffle(lines)
+ train_texts = [line.strip() for line in lines[:int(len(lines) * (1 - test_size))]]
+ eval_texts = [line.strip() for line in lines[int(len(lines) * (1 - test_size)):]]
+ else:
+ with open(data_file.name, "r") as f:
+ lines = json.load(f)
+ random.shuffle(lines)
+ train_texts = [line.strip() for line in lines[:int(len(lines) * (1 - test_size))]]
+ eval_texts = [line.strip() for line in lines[int(len(lines) * (1 - test_size)):]]
+
+ train_dataset = datasets.Dataset.from_dict({"text": train_texts})
+ eval_dataset = datasets.Dataset.from_dict({"text": eval_texts})
+
+ train_dataset = train_dataset.map(
+ lambda examples: tokenizer(examples["text"], truncation=True, max_length=self._max_length),
+ batched=True,
+ remove_columns=["text"],
+ )
+ eval_dataset = eval_dataset.map(
+ lambda examples: tokenizer(examples["text"], truncation=True, max_length=self._max_length),
+ batched=True,
+ remove_columns=["text"],
+ )
+
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ training_args = TrainingArguments(
+ output_dir=results_path,
+ logging_dir=logs_path,
+ logging_steps=log_frequency,
+ num_train_epochs=training_params["nepochs"],
+ per_device_train_batch_size=4,
+ gradient_accumulation_steps=4,
+ learning_rate=5e-5,
+ weight_decay=0.01,
+ warmup_steps=500,
+ save_steps=1000,
+ eval_steps=1000,
+ report_to="none",
+ )
+
+ mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
+ cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
+ trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ data_collator=data_collator,
+ callbacks=trainer_callbacks,
+ )
+
+ self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version)
+
+ logger.info("Performing unsupervised training...")
+ trainer.train()
+
+ if cancel_event_check_callback.training_cancelled:
+ raise TrainingCancelledException("Training was cancelled by the user")
+
+ if not skip_save_model:
+ model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE)
+ model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}"
+ retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name)
+ trained_model_directory = os.path.join(
+ os.path.dirname(retrained_model_pack_path),
+ get_model_data_package_base_name(retrained_model_pack_path),
+ )
+ if hasattr(model, "merge_and_unload"):
+ model = model.merge_and_unload()
+ model.save_pretrained(
+ trained_model_directory,
+ safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
+ )
+ tokenizer.save_pretrained(trained_model_directory)
+ create_model_data_package(trained_model_directory, retrained_model_pack_path)
+ else:
+ model.save_pretrained(
+ copied_model_directory, # type: ignore
+ safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
+ )
+ create_model_data_package(copied_model_directory, retrained_model_pack_path) # type: ignore
+
+ self._tracker_client.log_model_config(model.config.to_dict()) # type: ignore
+
+ model_uri = self._tracker_client.save_model(
+ retrained_model_pack_path,
+ self._model_name,
+ self._model_manager,
+ self._model_service.info().model_type.value,
+ )
+ logger.info(f"Retrained model saved: {model_uri}")
+ else:
+ logger.info("Skipped saving the retrained model")
+
+ if redeploy:
+ self.deploy_model(self._model_service, model, tokenizer)
+ else:
+ del model
+ del tokenizer
+ gc.collect()
+ logger.info("Skipped deployment of the retrained model")
+
+ logger.info("Unsupervised training finished")
+ self._tracker_client.end_with_success()
+
+ except TrainingCancelledException as e:
+ logger.exception(e)
+ logger.info("Unsupervised training was cancelled")
+ self._tracker_client.end_with_interruption()
+ except torch.OutOfMemoryError as e:
+ logger.exception("Unsupervised training failed on CUDA OOM")
+ try:
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ torch.cuda.synchronize()
+ except Exception:
+ pass
+ self._tracker_client.log_exceptions(e)
+ self._tracker_client.end_with_failure()
+ except Exception as e:
+ logger.exception("Unsupervised training failed")
+ self._tracker_client.log_exceptions(e)
+ self._tracker_client.end_with_failure()
+ finally:
+ data_file.close()
+ with self._training_lock:
+ self._training_in_progress = False
+ self._clean_up_training_cache()
+                self._housekeep_file(retrained_model_pack_path)
+ if trainer is not None:
+ del trainer
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ try:
+ logger.info("Evaluating the running model...")
+ model, tokenizer = self._model_service.model, self._model_service.tokenizer
+
+ if self._model_service.is_4bit_quantised:
+                    raise ManagedModelException("Cannot evaluate against a quantised model")
+
+ if non_default_device_is_available(self._config.DEVICE):
+ model.to(self._config.DEVICE)
+
+ if isinstance(data_file, tempfile.TemporaryDirectory):
+ raw_dataset = datasets.load_from_disk(data_file.name)
+ if DatasetSplit.TEST.value in raw_dataset.keys():
+ eval_texts = raw_dataset[DatasetSplit.TEST.value]["text"]
+ elif DatasetSplit.VALIDATION.value in raw_dataset.keys():
+ eval_texts = raw_dataset[DatasetSplit.VALIDATION.value]["text"]
+ else:
+ raise DatasetException("No test or validation split found in the input dataset file")
+
+ else:
+ with open(data_file.name, "r") as f:
+ eval_texts = [line.strip() for line in json.load(f)]
+
+ eval_dataset = datasets.Dataset.from_dict({"text": eval_texts})
+ eval_dataset = eval_dataset.map(
+ lambda examples: tokenizer(examples["text"], truncation=True, max_length=self._max_length),
+ batched=True,
+ remove_columns=["text"],
+ )
+
+ data_collator = DataCollatorForLanguageModeling(
+ tokenizer=tokenizer,
+ mlm=False,
+ )
+
+ training_args = TrainingArguments(
+ output_dir=results_path,
+ logging_dir=logs_path,
+ logging_steps=log_frequency,
+ per_device_eval_batch_size=4,
+ report_to="none",
+ do_train=False,
+ do_eval=True,
+ )
+
+ mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
+ cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
+ trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ eval_dataset=eval_dataset,
+ data_collator=data_collator,
+ callbacks=trainer_callbacks,
+ )
+
+ eval_metrics = trainer.evaluate()
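+                # Derive perplexity from the mean cross-entropy loss: perplexity = exp(eval_loss).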
+ if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics:
+ eval_metrics.update({"perplexity": math.exp(eval_metrics["eval_loss"])})
+ logger.info(f"Evaluation metrics: {eval_metrics}")
+ self._tracker_client.send_hf_metrics_logs(eval_metrics, 0)
+ self._tracker_client.end_with_success()
+ logger.info("Model evaluation finished")
+ except torch.OutOfMemoryError as e:
+ logger.exception("Evaluation failed on CUDA OOM")
+ try:
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ torch.cuda.synchronize()
+ except Exception:
+ pass
+ self._tracker_client.log_exceptions(e)
+ self._tracker_client.end_with_failure()
+ except Exception as e:
+ logger.exception("Evaluation failed")
+ self._tracker_client.log_exceptions(e)
+ self._tracker_client.end_with_failure()
+ finally:
+ data_file.close()
+ with self._training_lock:
+ self._training_in_progress = False
+ self._clean_up_training_cache()
+ if trainer is not None:
+ del trainer
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
@final
class MLflowLoggingCallback(TrainerCallback):
"""
diff --git a/tests/app/api/test_api.py b/tests/app/api/test_api.py
index 1d5f077..8e93585 100644
--- a/tests/app/api/test_api.py
+++ b/tests/app/api/test_api.py
@@ -14,19 +14,11 @@ def test_get_model_server():
model_service_dep = ModelServiceDep("medcat_snomed", config)
app = get_model_server(config, model_service_dep)
info = app.openapi()["info"]
- tags = app.openapi_tags
paths = [route.path for route in app.routes]
assert isinstance(info["title"], str)
assert isinstance(info["summary"], str)
assert isinstance(info["version"], str)
- assert {"name": "Metadata", "description": "Get the model card"} in tags
- assert {"name": "Annotations", "description": "Retrieve NER entities by running the model"} in tags
- assert {"name": "Redaction", "description": "Redact the extracted NER entities"} in tags
- assert {"name": "Rendering", "description": "Preview embeddable annotation snippet in HTML"} in tags
- assert {"name": "Training", "description": "Trigger model training on input annotations"} in tags
- assert {"name": "Evaluating", "description": "Evaluate the deployed model with trainer export"} in tags
- assert {"name": "Authentication", "description": "Authenticate registered users"} in tags
assert "/info" in paths
assert "/process" in paths
assert "/process_jsonl" in paths
@@ -61,13 +53,11 @@ def test_get_stream_server():
model_service_dep = ModelServiceDep("medcat_snomed", config)
app = get_stream_server(config, model_service_dep)
info = app.openapi()["info"]
- tags = app.openapi_tags
paths = [route.path for route in app.routes]
assert isinstance(info["title"], str)
assert isinstance(info["summary"], str)
assert isinstance(info["version"], str)
- assert {"name": "Streaming", "description": "Retrieve NER entities as a stream by running the model"} in tags
assert "/info" in paths
assert "/stream/process" in paths
assert "/stream/ws" in paths
@@ -84,14 +74,11 @@ def test_get_generative_server():
model_service_dep = ModelServiceDep("huggingface_llm_model", config)
app = get_generative_server(config, model_service_dep)
info = app.openapi()["info"]
- tags = app.openapi_tags
paths = [route.path for route in app.routes]
assert isinstance(info["title"], str)
assert isinstance(info["summary"], str)
assert isinstance(info["version"], str)
- assert {"name": "Metadata", "description": "Get the model card"} in tags
- assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags
assert "/info" in paths
assert "/generate" in paths
assert "/stream/generate" in paths
diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py
index 6fbb89d..0f134f9 100644
--- a/tests/app/model_services/test_huggingface_llm_model.py
+++ b/tests/app/model_services/test_huggingface_llm_model.py
@@ -148,101 +148,107 @@ async def test_generate_async(huggingface_llm_model):
assert result == "Yeah."
-def test_create_embeddings_single_text(huggingface_llm_model):
- """Test create_embeddings with single text input."""
+@patch("torch.nn.functional.normalize")
+@patch("torch.mean")
+@patch("torch.cat")
+@patch("torch.tensor")
+def test_create_embeddings_single_text(mock_tensor, mock_cat, mock_mean, mock_normalise, huggingface_llm_model):
+ def tensor_side_effect(*args, **kwargs):
+ result = MagicMock()
+ result.to.return_value = result
+ return result
+
huggingface_llm_model.init_model()
huggingface_llm_model.model = MagicMock()
+ huggingface_llm_model.model.config.max_position_embeddings = 10
huggingface_llm_model.tokenizer = MagicMock()
- mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()]
+ long_input_ids = list(range(25))
+ long_attention_mask = [1] * 25
+ huggingface_llm_model.tokenizer.return_value = {
+ "input_ids": long_input_ids,
+ "attention_mask": long_attention_mask
+ }
mock_outputs = MagicMock()
- mock_outputs.hidden_states = mock_hidden_states
- mock_last_hidden_state = MagicMock()
- mock_last_hidden_state.shape = [1, 3, 768]
- mock_hidden_states[-1] = mock_last_hidden_state
- mock_attention_mask = MagicMock()
- mock_attention_mask.shape = [1, 3]
- mock_attention_mask.sum.return_value = MagicMock()
- mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock()
- mock_inputs = MagicMock()
- mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock()
- huggingface_llm_model.tokenizer.return_value = mock_inputs
+ mock_hidden_state = MagicMock()
+ mock_outputs.hidden_states = [None, None, mock_hidden_state]
huggingface_llm_model.model.return_value = mock_outputs
- expected_result = [0.1, 0.2, 0.3]
- mock_embeddings_batch = MagicMock()
- mock_first_embedding = MagicMock()
- mock_cpu_tensor = MagicMock()
- mock_numpy_array = MagicMock()
- mock_numpy_array.tolist.return_value = expected_result
- mock_embeddings_batch.__getitem__.return_value = mock_first_embedding
- mock_first_embedding.cpu.return_value = mock_cpu_tensor
- mock_cpu_tensor.numpy.return_value = mock_numpy_array
- mock_masked_hidden_states = MagicMock()
- mock_sum_hidden_states = MagicMock()
- mock_num_tokens = MagicMock()
- mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states
- mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states
- mock_attention_mask.sum.return_value = mock_num_tokens
- mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch
-
- result = huggingface_llm_model.create_embeddings("Alright")
-
- huggingface_llm_model.model.eval.assert_called_once()
- huggingface_llm_model.tokenizer.assert_called_once_with(
- "Alright",
- add_special_tokens=False,
- return_tensors="pt",
- padding=True,
- truncation=True
- )
- huggingface_llm_model.model.assert_called_once_with(
- **mock_inputs,
- output_hidden_states=True
+ mock_chunk_embedding = MagicMock()
+ mock_final_embedding = MagicMock()
+ mock_normalised = MagicMock()
+ mock_concatenated = MagicMock()
+ mock_cat.return_value = mock_concatenated
+ mock_mean.return_value = mock_final_embedding
+ mock_normalise.return_value = mock_normalised
+ mock_normalised.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]]
+ mock_tensor.side_effect = tensor_side_effect
+ mock_masked = MagicMock()
+ mock_summed = MagicMock()
+ mock_hidden_state.__mul__.return_value = mock_masked
+ mock_masked.sum.return_value = mock_summed
+ mock_summed.__truediv__.return_value = mock_chunk_embedding
+
+ result = huggingface_llm_model.create_embeddings(
+ "This is a long text that should be chunked into multiple pieces"
)
-
- assert result is not None
+ assert huggingface_llm_model.model.call_count >= 3
+ mock_cat.assert_called_once()
+ mock_mean.assert_called_once()
+ assert result == [0.1, 0.2, 0.3]
+
+
+@patch("torch.nn.functional.normalize")
+@patch("torch.mean")
+@patch("torch.cat")
+@patch("torch.tensor")
+def test_create_embeddings_list_text(mock_tensor, mock_cat, mock_mean, mock_normalise, huggingface_llm_model):
+ def tokenizer_side_effect(text, **kwargs):
+ if isinstance(text, list):
+ return {
+ "input_ids": [list(range(10)), list(range(15))],
+ "attention_mask": [[1]*10, [1]*15]
+ }
+ else:
+ return {
+ "input_ids": list(range(len(text.split()))),
+ "attention_mask": [1] * len(text.split())
+ }
+
+ def tensor_side_effect(*args, **kwargs):
+ result = MagicMock()
+ result.to.return_value = result
+ return result
-def test_create_embeddings_list_text(huggingface_llm_model):
huggingface_llm_model.init_model()
huggingface_llm_model.model = MagicMock()
+ huggingface_llm_model.model.config.max_position_embeddings = 6
huggingface_llm_model.tokenizer = MagicMock()
- mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()]
+ huggingface_llm_model.tokenizer.side_effect = tokenizer_side_effect
mock_outputs = MagicMock()
- mock_outputs.hidden_states = mock_hidden_states
- mock_last_hidden_state = MagicMock()
- mock_last_hidden_state.shape = [2, 3, 768]
- mock_hidden_states[-1] = mock_last_hidden_state
- mock_attention_mask = MagicMock()
- mock_attention_mask.shape = [2, 3]
- mock_attention_mask.sum.return_value = MagicMock()
- mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock()
- mock_inputs = MagicMock()
- mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock()
- huggingface_llm_model.tokenizer.return_value = mock_inputs
- huggingface_llm_model.model.return_value = mock_outputs
- mock_embeddings_batch = MagicMock()
- mock_first_embedding = MagicMock()
- mock_cpu_tensor = MagicMock()
- mock_numpy_array = MagicMock()
- mock_numpy_array.tolist.return_value = [[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]]
- mock_embeddings_batch.__getitem__.return_value = mock_first_embedding
- mock_first_embedding.cpu.return_value = mock_cpu_tensor
- mock_cpu_tensor.numpy.return_value = mock_numpy_array
- mock_masked_hidden_states = MagicMock()
- mock_sum_hidden_states = MagicMock()
- mock_num_tokens = MagicMock()
- mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states
- mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states
- mock_attention_mask.sum.return_value = mock_num_tokens
- mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch
-
- result = huggingface_llm_model.create_embeddings(["Alright", "Alright"])
-
- huggingface_llm_model.tokenizer.assert_called_once_with(
- ["Alright", "Alright"],
- add_special_tokens=False,
- return_tensors="pt",
- padding=True,
- truncation=True,
- )
- assert result is not None
+ mock_hidden_state = MagicMock()
+ mock_outputs.hidden_states = [None, None, mock_hidden_state]
+ huggingface_llm_model.model.return_value = mock_outputs
+ mock_chunk_embedding = MagicMock()
+ mock_final_embedding = MagicMock()
+ mock_normalised = MagicMock()
+ mock_concatenated = MagicMock()
+ mock_cat.return_value = mock_concatenated
+ mock_mean.return_value = mock_final_embedding
+ mock_normalise.return_value = mock_normalised
+ mock_normalised.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]]
+ mock_tensor.side_effect = tensor_side_effect
+ mock_masked = MagicMock()
+ mock_summed = MagicMock()
+ mock_hidden_state.__mul__.return_value = mock_masked
+ mock_masked.sum.return_value = mock_summed
+ mock_summed.__truediv__.return_value = mock_chunk_embedding
+
+ result = huggingface_llm_model.create_embeddings([
+ "Alright?",
+ "This is a long text that should be chunked into multiple pieces",
+ ])
+
+ assert huggingface_llm_model.model.call_count >= 4
+ assert mock_cat.call_count == 2
+ assert mock_mean.call_count == 2
+ assert result == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
diff --git a/tests/app/model_services/test_huggingface_ner_model.py b/tests/app/model_services/test_huggingface_ner_model.py
index f617979..44779da 100644
--- a/tests/app/model_services/test_huggingface_ner_model.py
+++ b/tests/app/model_services/test_huggingface_ner_model.py
@@ -99,3 +99,20 @@ def test_train_supervised(huggingface_ner_model):
with tempfile.TemporaryFile("r+") as f:
huggingface_ner_model.train_supervised(f, 1, 1, "training_id", "input_file_name")
huggingface_ner_model._supervised_trainer.train.assert_called()
+
+
+def test_create_embeddings(huggingface_ner_model):
+ huggingface_ner_model.init_model()
+
+ text = "Spinal stenosis"
+ embedding = huggingface_ner_model.create_embeddings(text)
+ assert isinstance(embedding, list)
+ assert len(embedding) > 0
+ assert all(isinstance(x, float) for x in embedding)
+
+ texts = ["Spinal stenosis", "Diabetes"]
+ embeddings = huggingface_ner_model.create_embeddings(texts)
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == 2
+ assert all(isinstance(emb, list) for emb in embeddings)
+ assert all(len(emb) > 0 for emb in embeddings)
diff --git a/tests/app/model_services/test_medcat_model_deid.py b/tests/app/model_services/test_medcat_model_deid.py
index 59f6b1d..8c16d90 100644
--- a/tests/app/model_services/test_medcat_model_deid.py
+++ b/tests/app/model_services/test_medcat_model_deid.py
@@ -186,3 +186,27 @@ def test_train_supervised(medcat_deid_model):
with tempfile.TemporaryFile("r+") as f:
medcat_deid_model.train_supervised(f, 1, 1, "training_id", "input_file_name")
medcat_deid_model._supervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+ not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")),
+ reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_deid_model):
+    medcat_deid_model.init_model()
+
+    embedding = medcat_deid_model.create_embeddings("This is a post code NW1 2DA")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_deid_model.create_embeddings([
+ "This is a post code NW1 2DA",
+ "This is a post code NW1 2DB",
+ ])
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == 2
+ for emb in embeddings:
+ assert isinstance(emb, list)
+ assert len(emb) > 0
+ assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_icd10.py b/tests/app/model_services/test_medcat_model_icd10.py
index b1f6bcd..c019c45 100644
--- a/tests/app/model_services/test_medcat_model_icd10.py
+++ b/tests/app/model_services/test_medcat_model_icd10.py
@@ -101,7 +101,7 @@ def test_annotate(medcat_icd10_model):
medcat_icd10_model.init_model()
annotations = medcat_icd10_model.annotate("Spinal stenosis")
assert len(annotations) == 1
- assert type(annotations[0]["label_name"]) is str
+ assert type(annotations[0].label_name) is str
assert annotations[0].start == 0
assert annotations[0].end == 15
assert annotations[0].accuracy > 0
@@ -133,3 +133,24 @@ def test_train_unsupervised(medcat_icd10_model):
with tempfile.TemporaryFile("r+") as f:
medcat_icd10_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
medcat_icd10_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+ not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")),
+ reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_icd10_model):
+    medcat_icd10_model.init_model()
+
+    embedding = medcat_icd10_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_icd10_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == 2
+ for emb in embeddings:
+ assert isinstance(emb, list)
+ assert len(emb) > 0
+ assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_opcs4.py b/tests/app/model_services/test_medcat_model_opcs4.py
index 12b9d0d..5fcd932 100644
--- a/tests/app/model_services/test_medcat_model_opcs4.py
+++ b/tests/app/model_services/test_medcat_model_opcs4.py
@@ -101,7 +101,7 @@ def test_annotate(medcat_opcs4_model):
medcat_opcs4_model.init_model()
annotations = medcat_opcs4_model.annotate("Spinal tap")
assert len(annotations) == 1
- assert type(annotations[0]["label_name"]) is str
+ assert type(annotations[0].label_name) is str
assert annotations[0].start == 0
assert annotations[0].end == 10
assert annotations[0].accuracy > 0
@@ -133,3 +133,24 @@ def test_train_unsupervised(medcat_opcs4_model):
with tempfile.TemporaryFile("r+") as f:
medcat_opcs4_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
medcat_opcs4_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+ not os.path.exists(os.path.join(MODEL_PARENT_DIR, "opcs4_model.zip")),
+ reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_opcs4_model):
+    medcat_opcs4_model.init_model()
+
+    embedding = medcat_opcs4_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_opcs4_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == 2
+ for emb in embeddings:
+ assert isinstance(emb, list)
+ assert len(emb) > 0
+ assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_snomed.py b/tests/app/model_services/test_medcat_model_snomed.py
index 4928660..b4a4ae7 100644
--- a/tests/app/model_services/test_medcat_model_snomed.py
+++ b/tests/app/model_services/test_medcat_model_snomed.py
@@ -98,7 +98,7 @@ def test_annotate(medcat_snomed_model):
medcat_snomed_model.init_model()
annotations = medcat_snomed_model.annotate("Spinal stenosis")
assert len(annotations) == 1
- assert type(annotations[0]["label_name"]) is str
+ assert type(annotations[0].label_name) is str
assert annotations[0].start == 0
assert annotations[0].end == 15
assert annotations[0].accuracy > 0
@@ -130,3 +130,24 @@ def test_train_unsupervised(medcat_snomed_model):
with tempfile.TemporaryFile("r+") as f:
medcat_snomed_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
medcat_snomed_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+ not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")),
+ reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_snomed_model):
+    medcat_snomed_model.init_model()
+
+    embedding = medcat_snomed_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_snomed_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == 2
+ for emb in embeddings:
+ assert isinstance(emb, list)
+ assert len(emb) > 0
+ assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_umls.py b/tests/app/model_services/test_medcat_model_umls.py
index 2f9ff15..71987f6 100644
--- a/tests/app/model_services/test_medcat_model_umls.py
+++ b/tests/app/model_services/test_medcat_model_umls.py
@@ -61,7 +61,7 @@ def test_annotate(medcat_umls_model):
medcat_umls_model.init_model()
annotations = medcat_umls_model.annotate("Spinal stenosis")
assert len(annotations) == 1
- assert type(annotations[0]["label_name"]) is str
+ assert type(annotations[0].label_name) is str
assert annotations[0].start == 0
assert annotations[0].end == 15
assert annotations[0].accuracy > 0
@@ -93,3 +93,24 @@ def test_train_unsupervised(medcat_umls_model):
with tempfile.TemporaryFile("r+") as f:
medcat_umls_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
medcat_umls_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+ not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")),
+ reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_umls_model):
+ medcat_umls_model.init_model()
+
+ embedding = medcat_umls_model.create_embeddings("Spinal stenosis")
+ assert isinstance(embedding, list)
+ assert len(embedding) > 0
+ assert all(isinstance(x, float) for x in embedding)
+
+ embeddings = medcat_umls_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+ assert isinstance(embeddings, list)
+ assert len(embeddings) == 2
+ for emb in embeddings:
+ assert isinstance(emb, list)
+ assert len(emb) > 0
+ assert all(isinstance(x, float) for x in emb)