From a869a834cb5e69a5a1273b09b9f9a57f23b4548a Mon Sep 17 00:00:00 2001
From: Xi Bai
Date: Thu, 15 Jan 2026 14:46:52 +0000
Subject: [PATCH] feat: add embedding creation and chunking for MedCAT and HF NER models

---
 app/api/routers/invocation.py                 |  94 ++++-
 app/api/routers/stream.py                     |  41 ++-
 app/model_services/huggingface_llm_model.py   | 111 ++++--
 app/model_services/huggingface_ner_model.py   |  80 ++++-
 app/model_services/medcat_model.py            |  82 ++++-
 app/model_services/medcat_model_deid.py       |  96 ++++-
 app/trainers/huggingface_llm_trainer.py       | 336 +++++++++++++++++-
 tests/app/api/test_api.py                     |  13 -
 .../test_huggingface_llm_model.py             | 178 +++++-----
 .../test_huggingface_ner_model.py             |  17 +
 .../model_services/test_medcat_model_deid.py  |  24 ++
 .../model_services/test_medcat_model_icd10.py |  23 +-
 .../model_services/test_medcat_model_opcs4.py |  23 +-
 .../test_medcat_model_snomed.py               |  23 +-
 .../model_services/test_medcat_model_umls.py  |  23 +-
 15 files changed, 1027 insertions(+), 137 deletions(-)

diff --git a/app/api/routers/invocation.py b/app/api/routers/invocation.py
index 47c6828..3c6f5c1 100644
--- a/app/api/routers/invocation.py
+++ b/app/api/routers/invocation.py
@@ -7,12 +7,14 @@
 import hashlib
 import logging
 import pandas as pd
+from fastapi.encoders import jsonable_encoder
+
 import app.api.globals as cms_globals
 from typing import Dict, List, Union, Iterator, Any
 from collections import defaultdict
 from io import BytesIO
-from starlette.status import HTTP_400_BAD_REQUEST
+from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
 from typing_extensions import Annotated
 from fastapi import APIRouter, Depends, Body, UploadFile, File, Request, Query, Response
 from fastapi.responses import StreamingResponse, PlainTextResponse, JSONResponse
@@ -22,7 +24,7 @@
     TextWithAnnotations,
     TextWithPublicKey,
     TextStreamItem,
-    Tags,
+    Tags, OpenAIEmbeddingsRequest, OpenAIEmbeddingsResponse,
 )
 from app.model_services.base import AbstractModelService
 from app.utils import get_settings, load_pydantic_object_from_dict
@@ -43,6 +45,7 @@
 PATH_PROCESS_BULK_FILE = "/process_bulk_file"
 PATH_REDACT = "/redact"
 PATH_REDACT_WITH_ENCRYPTION = "/redact_with_encryption"
+PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"
 
 router = APIRouter()
 config = get_settings()
@@ -355,6 +358,93 @@ def get_redacted_text_with_encryption(
     return JSONResponse(content=content)
 
 
+@router.post(
+    PATH_OPENAI_EMBEDDINGS,
+    tags=[Tags.OpenAICompatible.name],
+    response_model=None,
+    dependencies=[Depends(cms_globals.props.current_active_user)],
+    description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint",
+)
+def embed_texts(
+    request: Request,
+    request_data: Annotated[OpenAIEmbeddingsRequest, Body(
+        description="Text(s) to be embedded", media_type="application/json"
+    )],
+    tracking_id: Union[str, None] = Depends(validate_tracking_id),
+    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
+) -> JSONResponse:
+    """
+    Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.
+
+    Args:
+        request (Request): The request object.
+        request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s).
+        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
+        model_service (AbstractModelService): The model service dependency.
+
+    Returns:
+        JSONResponse: A response containing the embeddings of the text(s).
+    """
+    tracking_id = tracking_id or str(uuid.uuid4())
+
+    if not hasattr(model_service, "create_embeddings"):
+        error_response = {
+            "error": {
+                "message": "Model does not support embeddings",
+                "type": "invalid_request_error",
+                "param": "model",
+                "code": "model_not_supported",
+            }
+        }
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+    input_text = request_data.input
+    model = model_service.model_name
+
+    if isinstance(input_text, str):
+        input_texts = [input_text]
+    else:
+        input_texts = input_text
+
+    try:
+        embeddings_data = []
+
+        for i, embedding in enumerate(model_service.create_embeddings(input_texts)):
+            embeddings_data.append({
+                "object": "embedding",
+                "embedding": embedding,
+                "index": i,
+            })
+
+        response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)
+
+        return JSONResponse(
+            content=jsonable_encoder(response),
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+    except Exception as e:
+        logger.error("Failed to create embeddings")
+        logger.exception(e)
+        error_response = {
+            "error": {
+                "message": f"Failed to create embeddings: {str(e)}",
+                "type": "server_error",
+                "code": "internal_error",
+            }
+        }
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+
 def _send_annotation_num_metric(annotation_num: int, handler: str) -> None:
     cms_doc_annotations.labels(handler=handler).observe(annotation_num)
 
diff --git a/app/api/routers/stream.py b/app/api/routers/stream.py
index edba5e8..c44ecf8 100644
--- a/app/api/routers/stream.py
+++ b/app/api/routers/stream.py
@@ -11,7 +11,7 @@
 from starlette.types import Receive, Scope, Send
 from starlette.background import BackgroundTask
 from fastapi import APIRouter, Depends, Request, Response, WebSocket, WebSocketException
-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel
 from app.domain import Tags, TextStreamItem
 from app.model_services.base import AbstractModelService
 from app.utils import get_settings
@@ -20,7 +20,6 @@
 
 PATH_STREAM_PROCESS = "/process"
 PATH_WS = "/ws"
-PATH_GENERATE= "/generate"
 
 router = APIRouter()
 config = get_settings()
@@ -57,6 +56,22 @@ async def get_entities_stream_from_jsonlines_stream(
     return _LocalStreamingResponse(annotation_stream, media_type="application/x-ndjson; charset=utf-8")
 
 
+@router.get(
+    PATH_WS,
+    tags=[Tags.Annotations.name],
+    dependencies=[Depends(cms_globals.props.current_active_user)],
+    description="WebSocket info endpoint for real-time NER entity extraction. Use ws://host:port/stream/ws to establish an actual WebSocket connection.",
+    include_in_schema=True,
+)
+async def get_inline_annotations_from_websocket_info() -> "_WebSocketInfo":
+    """
+    Information about the WebSocket endpoint for real-time NER entity extraction.
+
+    This endpoint provides documentation for the WebSocket connection available at the same path.
+    Connect to ws://host:port/stream/ws and send texts to retrieve annotated results.
+ """ + return _WebSocketInfo() + @router.websocket(PATH_WS) # @limiter.limit(config.PROCESS_BULK_RATE_LIMIT) # Not supported yet async def get_inline_annotations_from_websocket( @@ -189,6 +204,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.background() +class _WebSocketInfo(BaseModel): + message: str = "WebSocket endpoint for real-time NER entity extraction" + example: str = """
+        # Illustrative client snippet (host and port are placeholders):
+        import asyncio
+        import websockets
+
+        async def annotate() -> None:
+            async with websockets.connect("ws://localhost:8000/stream/ws") as ws:
+                await ws.send("Spinal stenosis")
+                print(await ws.recv())
+
+        asyncio.run(annotate())
+ +""" + protocol: str = "WebSocket" + + async def _annotation_async_gen(request: Request, model_service: AbstractModelService) -> AsyncGenerator: try: buffer = "" diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index bc28bfd..4eafe8b 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -15,7 +15,7 @@ from app import __version__ as app_version from app.exception import ConfigurationException from app.model_services.base import AbstractModelService -from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer +from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer, HuggingFaceLlmUnsupervisedTrainer from app.domain import ModelCard, ModelType, Annotation, Device from app.config import Settings from app.utils import ( @@ -62,6 +62,7 @@ def __init__( self._multi_label_threshold = 0.5 self._text_generator = ThreadPoolExecutor(max_workers=50) self.model_name = model_name or "HuggingFace LLM model" + self.is_4bit_quantised = False @property def model(self) -> PreTrainedModel: @@ -206,6 +207,8 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N self._model.to(get_settings().DEVICE) if self._enable_trainer: self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self) + self._unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(self) + self.is_4bit_quantised = load_in_4bit def info(self) -> ModelCard: """ @@ -396,29 +399,47 @@ def create_embeddings( self.model.eval() - inputs = self.tokenizer( - text, - add_special_tokens=False, - return_tensors="pt", - padding=True, - truncation=True, - ) - - inputs.to(self.model.device) - - with torch.no_grad(): - outputs = self.model(**inputs, output_hidden_states=True) + texts = [text] if isinstance(text, str) else text + all_embeddings = [] + + for txt in texts: + inputs = self.tokenizer(txt, add_special_tokens=False, truncation=False, padding=False) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + window_size = max(self.model.config.max_position_embeddings - 2, 1) + stride = window_size + chunk_embeddings = [] + + for start in range(0, len(input_ids), stride): + end = min(start + window_size, len(input_ids)) + chunk_inputs = { + "input_ids": torch.tensor( + [input_ids[start:end]], dtype=torch.long + ).to(self.model.device), + "attention_mask": torch.tensor( + [attention_mask[start:end]], dtype=torch.long + ).to(self.model.device), + } + + with torch.no_grad(): + outputs = self.model(**chunk_inputs, output_hidden_states=True) + + last_hidden_state = outputs.hidden_states[-1] + chunk_attention_mask = chunk_inputs["attention_mask"] + masked_hidden_states = last_hidden_state * chunk_attention_mask.unsqueeze(-1) + sum_hidden_states = masked_hidden_states.sum(dim=1) + num_tokens = chunk_attention_mask.sum(dim=1, keepdim=True) + chunk_embedding = sum_hidden_states / num_tokens + chunk_embeddings.append(chunk_embedding) + + if end >= len(input_ids): + break - last_hidden_state = outputs.hidden_states[-1] - attention_mask = inputs["attention_mask"] - masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1) - sum_hidden_states = masked_hidden_states.sum(dim=1) - num_tokens = attention_mask.sum(dim=1, keepdim=True) - embeddings = sum_hidden_states / num_tokens - l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1) + final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True) + 
l2_normalised = torch.nn.functional.normalize(final_embedding, p=2, dim=1) + all_embeddings.append(l2_normalised.cpu().numpy().tolist()[0]) - results = l2_normalised.cpu().numpy().tolist() - return results[0] if isinstance(text, str) else results + return all_embeddings[0] if isinstance(text, str) else all_embeddings def train_supervised( self, @@ -465,3 +486,49 @@ def train_supervised( synchronised, **hyperparams, ) + + def train_unsupervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> Tuple[bool, str, str]: + """ + Initiates unsupervised training on the model. + + Args: + data_file (TextIO): The file containing a JSON list of texts. + epochs (int): The number of training epochs. + log_frequency (int): The number of epochs after which training metrics will be logged. + training_id (str): A unique identifier for the training process. + input_file_name (str): The name of the input file to be logged. + raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None. + description (Optional[str]): The description of the training or change logs. Defaults to empty. + synchronised (bool): Whether to wait for the training to complete. + **hyperparams (Dict[str, Any]): Additional hyperparameters for training. + + Returns: + Tuple[bool, str, str]: A tuple with the first element indicating success or failure. + + Raises: + ConfigurationException: If the unsupervised trainer is not enabled. + """ + if self._unsupervised_trainer is None: + raise ConfigurationException("The unsupervised trainer is not enabled") + return self._unsupervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index 98e55f9..a6eeb8c 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -1,9 +1,10 @@ import os import logging +import torch import pandas as pd from functools import partial -from typing import Dict, List, Optional, Tuple, Any, TextIO +from typing import Dict, List, Optional, Tuple, Any, TextIO, Union from transformers import ( AutoModelForTokenClassification, AutoTokenizer, @@ -276,6 +277,83 @@ def annotate(self, text: str) -> List[Annotation]: def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: raise NotImplementedError("Batch annotation is not yet implemented for HuggingFace NER models") + def create_embeddings( + self, + text: Union[str, List[str]], + *args: Any, + **kwargs: Any + ) -> Union[List[float], List[List[float]]]: + """ + Creates embeddings for a given text or list of texts using the model's hidden states. + + Args: + text (Union[str, List[str]]): The text(s) to be embedded. + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + + Returns: + List[float], List[List[float]]: The embedding vector(s) for the text(s). + + Raises: + NotImplementedError: If the model doesn't support embeddings. 
+ """ + + self.model.eval() + + texts = [text] if isinstance(text, str) else text + all_embeddings = [] + + max_len = self.model.config.max_position_embeddings + + for txt in texts: + encoded = self.tokenizer( + txt, + add_special_tokens=True, + truncation=False, + return_attention_mask=True, + ) + + input_ids = encoded["input_ids"] + chunk_embeddings = [] + window_size = max_len - 2 + stride = window_size + + for start in range(0, len(input_ids), stride): + end = min(start + window_size, len(input_ids)) + + chunk = self.tokenizer.prepare_for_model( + input_ids[start:end], + add_special_tokens=True, + return_attention_mask=True, + truncation=True, + max_length=max_len, + padding="max_length", + ) + + chunk_inputs = { + "input_ids": torch.tensor([chunk["input_ids"]], device=self.model.device), + "attention_mask": torch.tensor([chunk["attention_mask"]], device=self.model.device), + } + + with torch.no_grad(): + outputs = self.model(**chunk_inputs, output_hidden_states=True) + + last_hidden_state = outputs.hidden_states[-1] + mask = chunk_inputs["attention_mask"].unsqueeze(-1) + summed = (last_hidden_state * mask).sum(dim=1) + counts = mask.sum(dim=1).clamp(min=1) + chunk_embedding = summed / counts + chunk_embeddings.append(chunk_embedding) + + if end >= len(input_ids): + break + + final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True) + final_embedding = torch.nn.functional.normalize(final_embedding, p=2, dim=1) + all_embeddings.append(final_embedding.cpu().numpy()[0].tolist()) + + return all_embeddings[0] if isinstance(text, str) else all_embeddings + def train_supervised( self, data_file: TextIO, diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py index 0ef00d1..800e68d 100644 --- a/app/model_services/medcat_model.py +++ b/app/model_services/medcat_model.py @@ -1,10 +1,15 @@ import os import logging +import torch +import numpy as np import pandas as pd from multiprocessing import cpu_count from typing import Dict, List, Optional, TextIO, Tuple, Any, Set, Union from medcat.cat import CAT +from medcat.components.linking.embedding_linker import Linker +from medcat.components.linking.vector_context_model import PerDocumentTokenCache +from medcat.components.types import CoreComponentType from medcat.data.entities import Entities, OnlyCUIEntities from app import __version__ as app_version from app.model_services.base import AbstractModelService @@ -20,7 +25,7 @@ get_model_data_package_base_name, load_pydantic_object_from_dict, ) -from app.exception import ConfigurationException +from app.exception import ConfigurationException, ManagedModelException logger = logging.getLogger("cms") @@ -198,7 +203,7 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: annotations_list = [] for _, doc in docs.items(): annotations_list.append([ - load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc) + load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc) # type: ignore ]) return annotations_list @@ -370,6 +375,79 @@ def get_records_from_doc(self, doc: Union[Dict, Entities, OnlyCUIEntities]) -> L records = df.to_dict("records") return records + def create_embeddings( + self, + text: Union[str, List[str]], + *args: Any, + model_name: Optional[str] = None, + max_length: Optional[int] = None, + **kwargs: Any + ) -> Union[List[float], List[List[float]]]: + """ + Creates embeddings for a given text or list of texts using MedCAT's embedding linker. 
+ + Args: + text (Union[str, List[str]]): The text(s) to be embedded. + model_name (Optional[str]): The name of the embedding model to use. + max_length (Optional[int]): Maximum sequence length for tokenization. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s). + """ + + assert self._model is not None, "Model is not initialised" + texts = [text] if isinstance(text, str) else text + linker = self._model.pipe.get_component(CoreComponentType.linking) + + if isinstance(linker, Linker): + embedding_model_name = getattr(linker.cnf_l, "embedding_model_name", None) + if embedding_model_name is None: + raise ManagedModelException("Embedding linker present but no embedding_model_name found in config.") + linker._load_transformers(embedding_model_name) + with torch.no_grad(): + emb_tensor = linker._embed(texts, linker.device) + embeddings = emb_tensor.cpu().numpy().tolist() + return embeddings[0] if isinstance(text, str) else embeddings + else: + all_embeddings = [] + ctx_model = getattr(linker, "context_model", None) + if ctx_model is None: + raise ManagedModelException( + "Linker does not expose context_model so cannot compute context-based embeddings." + ) + tokenizer = self._model.pipe.tokenizer + for txt in texts: + doc = tokenizer(txt) + if hasattr(tokenizer, "entity_from_tokens"): + entity = tokenizer.entity_from_tokens(list(doc)) + else: + raise ManagedModelException( + "Tokenizer does not support entity_from_tokens so cannot build entity for context model" + ) + + cache = PerDocumentTokenCache() + vectors = ctx_model.get_context_vectors(entity, doc, cache) + weights = getattr(linker.config.components.linking, "context_vector_weights", None) # type: ignore + if not weights: + weights = {k: 1.0 for k in vectors.keys()} + + combined = None + for size, vec in vectors.items(): + w = weights.get(size, 1.0) + if combined is None: + combined = w * vec + else: + combined = combined + w * vec + if combined is not None: + norm = np.linalg.norm(combined) + if norm > 0: + combined = combined / norm + all_embeddings.append(combined.tolist()) + + return all_embeddings[0] if isinstance(text, str) else all_embeddings + @staticmethod def _retrieve_meta_annotations(df: pd.DataFrame) -> pd.DataFrame: meta_annotations = [] diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py index 547f225..c401d7e 100644 --- a/app/model_services/medcat_model_deid.py +++ b/app/model_services/medcat_model_deid.py @@ -2,7 +2,7 @@ import inspect import threading import torch -from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable, cast +from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable, cast, Union from functools import partial from transformers import pipeline from medcat.cat import CAT @@ -91,7 +91,7 @@ def annotate(self, text: str) -> List[Annotation]: for _, entity in doc["entities"].items(): entity["type_ids"] = ["PII"] - records = self.get_records_from_doc({"entities": doc["entities"]}) + records = self.get_records_from_doc({"entities": doc["entities"]}) # type: ignore return [load_pydantic_object_from_dict(Annotation, record) for record in records] def annotate_with_local_chunking(self, text: str) -> List[Annotation]: @@ -179,11 +179,101 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: entity = cast(Dict[str, Any], entity) entity["type_ids"] = ["PII"] 
annotations_list.append([ - load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(entities) + load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(entities) # type: ignore ]) return annotations_list + def create_embeddings( + self, + text: Union[str, List[str]], + *args: Any, + **kwargs: Any + ) -> Union[List[float], List[List[float]]]: + """ + Creates embeddings for a given text or list of texts using the model's hidden states. + + Args: + text (Union[str, List[str]]): The text(s) to be embedded. + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + + Returns: + List[float], List[List[float]]: The embedding vector(s) for the text(s). + + Raises: + NotImplementedError: If the model doesn't support embeddings. + """ + + assert self.model is not None, "Model is not initialised" + ner = self.model.pipe.get_component(CoreComponentType.ner)._component # type: ignore + ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr( + ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False + ) + ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr( + ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None + ) + ner.tokenizer.hf_tokenizer.split_special_tokens = getattr( + ner.tokenizer.hf_tokenizer, "split_special_tokens", False + ) + tokenizer = ner.tokenizer.hf_tokenizer + model = ner.model + model.eval() + + texts = [text] if isinstance(text, str) else text + all_embeddings = [] + + max_len = model.config.max_position_embeddings + + for txt in texts: + encoded = tokenizer( + txt, + add_special_tokens=True, + truncation=False, + return_attention_mask=True, + ) + + input_ids = encoded["input_ids"] + chunk_embeddings = [] + window_size = max_len - 2 + stride = window_size + + for start in range(0, len(input_ids), stride): + end = min(start + window_size, len(input_ids)) + + chunk = tokenizer.prepare_for_model( + input_ids[start:end], + add_special_tokens=True, + return_attention_mask=True, + truncation=True, + max_length=max_len, + padding="max_length", + ) + + chunk_inputs = { + "input_ids": torch.tensor([chunk["input_ids"]], device=model.device), + "attention_mask": torch.tensor([chunk["attention_mask"]], device=model.device), + } + + with torch.no_grad(): + outputs = model(**chunk_inputs, output_hidden_states=True) + + last_hidden_state = outputs.hidden_states[-1] + mask = chunk_inputs["attention_mask"].unsqueeze(-1) + summed = (last_hidden_state * mask).sum(dim=1) + counts = mask.sum(dim=1).clamp(min=1) + chunk_embedding = summed / counts + chunk_embeddings.append(chunk_embedding) + + if end >= len(input_ids): + break + + final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True) + final_embedding = torch.nn.functional.normalize(final_embedding, p=2, dim=1) + all_embeddings.append(final_embedding.cpu().numpy()[0].tolist()) + + return all_embeddings[0] if isinstance(text, str) else all_embeddings + def init_model(self, *args: Any, **kwargs: Any) -> None: """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration. 
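For reviewers: every create_embeddings variant above follows the same recipe — tokenise the full text without truncation, slide a non-overlapping window bounded by the encoder's max_position_embeddings, mean-pool each window's last hidden state over attended tokens, average the window vectors, and L2-normalise the result. A minimal self-contained sketch of that recipe, assuming a generic HuggingFace encoder (the "bert-base-uncased" checkpoint and the embed_long_text helper are illustrative placeholders, not part of this patch):

# Sketch of the chunked mean-pooling shared by the create_embeddings methods above.
# Assumes a generic HF encoder; "bert-base-uncased" is a placeholder checkpoint.
import torch
from transformers import AutoModel, AutoTokenizer


def embed_long_text(text: str, model_name: str = "bert-base-uncased") -> list:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()

    input_ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
    max_len = model.config.max_position_embeddings
    window_size = max(max_len - 2, 1)  # leave room for the special tokens re-added per chunk

    chunk_embeddings = []
    for start in range(0, len(input_ids), window_size):
        end = min(start + window_size, len(input_ids))
        chunk = tokenizer.prepare_for_model(
            input_ids[start:end],
            add_special_tokens=True,
            return_attention_mask=True,
            truncation=True,
            max_length=max_len,
        )
        chunk_inputs = {
            "input_ids": torch.tensor([chunk["input_ids"]]),
            "attention_mask": torch.tensor([chunk["attention_mask"]]),
        }
        with torch.no_grad():
            outputs = model(**chunk_inputs, output_hidden_states=True)

        # Mean-pool the last hidden state over non-padding positions.
        mask = chunk_inputs["attention_mask"].unsqueeze(-1)
        summed = (outputs.hidden_states[-1] * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        chunk_embeddings.append(summed / counts)
        if end >= len(input_ids):
            break

    # Average the chunk vectors, then L2-normalise the document embedding.
    doc = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True)
    return torch.nn.functional.normalize(doc, p=2, dim=1)[0].tolist()

The non-overlapping stride mirrors the `stride = window_size` choice in the diff; an overlapping stride would smooth chunk boundaries at the cost of extra forward passes.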
diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index 5350422..89404fa 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -1,10 +1,12 @@ import os import logging import math +import tempfile import torch import gc import datasets import re +import random import threading import json import inspect @@ -18,6 +20,8 @@ TrainerCallback, TrainerState, TrainerControl, + DataCollatorForLanguageModeling, + Trainer, ) from peft import LoraConfig, get_peft_model # type: ignore from app.management.model_manager import ModelManager @@ -32,13 +36,22 @@ get_default_system_prompt, get_model_data_package_base_name, ) -from app.trainers.base import SupervisedTrainer -from app.domain import ModelType, TrainerBackend, LlmRole, LlmTrainerType, LlmDatasetType, PromptMessage +from app.trainers.base import SupervisedTrainer, UnsupervisedTrainer +from app.domain import ( + ModelType, + TrainerBackend, + LlmRole, + LlmTrainerType, + LlmDatasetType, + PromptMessage, + DatasetSplit, +) from app.exception import ( TrainingCancelledException, DatasetException, ConfigurationException, ExtraDependencyRequiredException, + ManagedModelException, ) if TYPE_CHECKING: from app.model_services.huggingface_llm_model import HuggingFaceLlmModel @@ -719,6 +732,325 @@ def _evaluate_with_rewards( return reward_avgs +@final +class HuggingFaceLlmUnsupervisedTrainer(UnsupervisedTrainer, _HuggingFaceLlmTrainerCommon): + """ + An unsupervised trainer class for HuggingFace LLM models. + + Args: + model_service (HuggingFaceLlmModel): An instance of the HuggingFace LLM model service. + """ + + def __init__(self, model_service: "HuggingFaceLlmModel") -> None: + UnsupervisedTrainer.__init__(self, model_service._config, model_service.model_name) + self._model_service = model_service + self._model_name = model_service.model_name + self._model_pack_path = model_service._model_pack_path + self._retrained_models_dir = os.path.join( + model_service._model_parent_dir, + "retrained", + self._model_name.replace(" ", "_"), + ) + self._model_manager = ModelManager(type(model_service), model_service._config) + self._max_length = model_service.model.config.max_position_embeddings + os.makedirs(self._retrained_models_dir, exist_ok=True) + + def run( + self, + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: + """ + Runs the unsupervised training loop for HuggingFace LLM models. + + Args: + training_params (Dict): A dictionary containing parameters for the training. + data_file (TextIO): The file-like object containing the training data. + log_frequency (int): The frequency at which logs should be recorded. + run_id (str): The run ID of the training job. + description (Optional[str]): The optional description of the training. 
+ """ + eval_mode = training_params["nepochs"] == 0 + trained_model_pack_path = None + redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true" + skip_save_model = self._config.SKIP_SAVE_MODEL == "true" + results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results")) + logs_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "logs")) + reset_random_seed() + trainer = None + + if not eval_mode: + try: + copied_model_directory = None + if self._model_service.is_4bit_quantised: + logger.info("Use the LoRA adaptor for the quantised model...") + lora_config = LoraConfig( + task_type="CAUSAL_LM", + r=8, + lora_alpha=32, + lora_dropout=0.1, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + ) + model = get_peft_model(self._model_service.model, lora_config) + tokenizer = self._model_service.tokenizer + else: + logger.info("Loading a new model copy for training...") + copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id) + model, tokenizer = self._model_service.load_model(copied_model_pack_path) + copied_model_directory = os.path.join( + os.path.dirname(copied_model_pack_path), + get_model_data_package_base_name(copied_model_pack_path), + ) + + if non_default_device_is_available(self._config.DEVICE): + model.to(self._config.DEVICE) + + test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] + if isinstance(data_file, tempfile.TemporaryDirectory): + raw_dataset = datasets.load_from_disk(data_file.name) + if DatasetSplit.VALIDATION.value in raw_dataset.keys(): + train_texts = raw_dataset[DatasetSplit.TRAIN.value]["text"] + eval_texts = raw_dataset[DatasetSplit.VALIDATION.value]["text"] + elif DatasetSplit.TEST.value in raw_dataset.keys(): + train_texts = raw_dataset[DatasetSplit.TRAIN.value]["text"] + eval_texts = raw_dataset[DatasetSplit.TEST.value]["text"] + else: + lines = raw_dataset[DatasetSplit.TRAIN.value]["text"] + random.shuffle(lines) + train_texts = [line.strip() for line in lines[:int(len(lines) * (1 - test_size))]] + eval_texts = [line.strip() for line in lines[int(len(lines) * (1 - test_size)):]] + else: + with open(data_file.name, "r") as f: + lines = json.load(f) + random.shuffle(lines) + train_texts = [line.strip() for line in lines[:int(len(lines) * (1 - test_size))]] + eval_texts = [line.strip() for line in lines[int(len(lines) * (1 - test_size)):]] + + train_dataset = datasets.Dataset.from_dict({"text": train_texts}) + eval_dataset = datasets.Dataset.from_dict({"text": eval_texts}) + + train_dataset = train_dataset.map( + lambda examples: tokenizer(examples["text"], truncation=True, max_length=self._max_length), + batched=True, + remove_columns=["text"], + ) + eval_dataset = eval_dataset.map( + lambda examples: tokenizer(examples["text"], truncation=True, max_length=self._max_length), + batched=True, + remove_columns=["text"], + ) + + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + training_args = TrainingArguments( + output_dir=results_path, + logging_dir=logs_path, + logging_steps=log_frequency, + num_train_epochs=training_params["nepochs"], + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + learning_rate=5e-5, + weight_decay=0.01, + warmup_steps=500, + save_steps=1000, + eval_steps=1000, + report_to="none", + ) + + mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) + cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) + 
+                trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
+
+                trainer = Trainer(
+                    model=model,
+                    args=training_args,
+                    train_dataset=train_dataset,
+                    eval_dataset=eval_dataset,
+                    data_collator=data_collator,
+                    callbacks=trainer_callbacks,
+                )
+
+                self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version)
+
+                logger.info("Performing unsupervised training...")
+                trainer.train()
+
+                if cancel_event_check_callback.training_cancelled:
+                    raise TrainingCancelledException("Training was cancelled by the user")
+
+                if not skip_save_model:
+                    model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE)
+                    model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}"
+                    retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name)
+                    trained_model_pack_path = retrained_model_pack_path
+                    trained_model_directory = os.path.join(
+                        os.path.dirname(retrained_model_pack_path),
+                        get_model_data_package_base_name(retrained_model_pack_path),
+                    )
+                    if hasattr(model, "merge_and_unload"):
+                        model = model.merge_and_unload()
+                        model.save_pretrained(
+                            trained_model_directory,
+                            safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
+                        )
+                        tokenizer.save_pretrained(trained_model_directory)
+                        create_model_data_package(trained_model_directory, retrained_model_pack_path)
+                    else:
+                        model.save_pretrained(
+                            copied_model_directory,  # type: ignore
+                            safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
+                        )
+                        create_model_data_package(copied_model_directory, retrained_model_pack_path)  # type: ignore
+
+                    self._tracker_client.log_model_config(model.config.to_dict())  # type: ignore
+
+                    model_uri = self._tracker_client.save_model(
+                        retrained_model_pack_path,
+                        self._model_name,
+                        self._model_manager,
+                        self._model_service.info().model_type.value,
+                    )
+                    logger.info(f"Retrained model saved: {model_uri}")
+                else:
+                    logger.info("Skipped saving the retrained model")
+
+                if redeploy:
+                    self.deploy_model(self._model_service, model, tokenizer)
+                else:
+                    del model
+                    del tokenizer
+                    gc.collect()
+                    logger.info("Skipped deployment of the retrained model")
+
+                logger.info("Unsupervised training finished")
+                self._tracker_client.end_with_success()
+
+            except TrainingCancelledException as e:
+                logger.exception(e)
+                logger.info("Unsupervised training was cancelled")
+                self._tracker_client.end_with_interruption()
+            except torch.OutOfMemoryError as e:
+                logger.exception("Unsupervised training failed on CUDA OOM")
+                try:
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                        torch.cuda.ipc_collect()
+                        torch.cuda.synchronize()
+                except Exception:
+                    pass
+                self._tracker_client.log_exceptions(e)
+                self._tracker_client.end_with_failure()
+            except Exception as e:
+                logger.exception("Unsupervised training failed")
+                self._tracker_client.log_exceptions(e)
+                self._tracker_client.end_with_failure()
+            finally:
+                data_file.close()
+                with self._training_lock:
+                    self._training_in_progress = False
+                self._clean_up_training_cache()
+                self._housekeep_file(trained_model_pack_path)
+                if trainer is not None:
+                    del trainer
+                gc.collect()
+                torch.cuda.empty_cache()
+        else:
+            try:
+                logger.info("Evaluating the running model...")
+                model, tokenizer = self._model_service.model, self._model_service.tokenizer
+
+                if self._model_service.is_4bit_quantised:
+                    logger.error("Cannot evaluate against a quantised model")
+                    raise ManagedModelException("Cannot evaluate against a quantised model")
+
+                if non_default_device_is_available(self._config.DEVICE):
+                    model.to(self._config.DEVICE)
+
+                if isinstance(data_file, tempfile.TemporaryDirectory):
+                    raw_dataset = datasets.load_from_disk(data_file.name)
+                    if DatasetSplit.TEST.value in raw_dataset.keys():
+                        eval_texts = raw_dataset[DatasetSplit.TEST.value]["text"]
+                    elif DatasetSplit.VALIDATION.value in raw_dataset.keys():
+                        eval_texts = raw_dataset[DatasetSplit.VALIDATION.value]["text"]
+                    else:
+                        raise DatasetException("No test or validation split found in the input dataset file")
+
+                else:
+                    with open(data_file.name, "r") as f:
+                        eval_texts = [line.strip() for line in json.load(f)]
+
+                eval_dataset = datasets.Dataset.from_dict({"text": eval_texts})
+                eval_dataset = eval_dataset.map(
+                    lambda examples: tokenizer(examples["text"], truncation=True, max_length=self._max_length),
+                    batched=True,
+                    remove_columns=["text"],
+                )
+
+                data_collator = DataCollatorForLanguageModeling(
+                    tokenizer=tokenizer,
+                    mlm=False,
+                )
+
+                training_args = TrainingArguments(
+                    output_dir=results_path,
+                    logging_dir=logs_path,
+                    logging_steps=log_frequency,
+                    per_device_eval_batch_size=4,
+                    report_to="none",
+                    do_train=False,
+                    do_eval=True,
+                )
+
+                mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
+                cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
+                trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback]
+
+                trainer = Trainer(
+                    model=model,
+                    args=training_args,
+                    eval_dataset=eval_dataset,
+                    data_collator=data_collator,
+                    callbacks=trainer_callbacks,
+                )
+
+                eval_metrics = trainer.evaluate()
+                if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics:
+                    eval_metrics.update({"perplexity": math.exp(eval_metrics["eval_loss"])})
+                logger.info(f"Evaluation metrics: {eval_metrics}")
+                self._tracker_client.send_hf_metrics_logs(eval_metrics, 0)
+                self._tracker_client.end_with_success()
+                logger.info("Model evaluation finished")
+            except torch.OutOfMemoryError as e:
+                logger.exception("Evaluation failed on CUDA OOM")
+                try:
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                        torch.cuda.ipc_collect()
+                        torch.cuda.synchronize()
+                except Exception:
+                    pass
+                self._tracker_client.log_exceptions(e)
+                self._tracker_client.end_with_failure()
+            except Exception as e:
+                logger.exception("Evaluation failed")
+                self._tracker_client.log_exceptions(e)
+                self._tracker_client.end_with_failure()
+            finally:
+                data_file.close()
+                with self._training_lock:
+                    self._training_in_progress = False
+                self._clean_up_training_cache()
+                if trainer is not None:
+                    del trainer
+                gc.collect()
+                torch.cuda.empty_cache()
+
+
 @final
 class MLflowLoggingCallback(TrainerCallback):
     """
diff --git a/tests/app/api/test_api.py b/tests/app/api/test_api.py
index 1d5f077..8e93585 100644
--- a/tests/app/api/test_api.py
+++ b/tests/app/api/test_api.py
@@ -14,19 +14,11 @@ def test_get_model_server():
     model_service_dep = ModelServiceDep("medcat_snomed", config)
     app = get_model_server(config, model_service_dep)
     info = app.openapi()["info"]
-    tags = app.openapi_tags
     paths = [route.path for route in app.routes]
 
     assert isinstance(info["title"], str)
     assert isinstance(info["summary"], str)
     assert isinstance(info["version"], str)
-    assert {"name": "Metadata", "description": "Get the model card"} in tags
-    assert {"name": "Annotations", "description": "Retrieve NER entities by running the model"} in tags
-    assert {"name": "Redaction", "description": "Redact the extracted NER entities"} in tags
-    assert {"name": "Rendering", "description": "Preview embeddable annotation snippet in HTML"} in tags
-    assert {"name": "Training", "description": "Trigger model training on input annotations"} in tags
-    assert {"name": "Evaluating", "description": "Evaluate the deployed model with trainer export"} in tags
-    assert {"name": "Authentication", "description": "Authenticate registered users"} in tags
     assert "/info" in paths
     assert "/process" in paths
     assert "/process_jsonl" in paths
@@ -61,13 +53,11 @@ def test_get_stream_server():
     model_service_dep = ModelServiceDep("medcat_snomed", config)
     app = get_stream_server(config, model_service_dep)
     info = app.openapi()["info"]
-    tags = app.openapi_tags
     paths = [route.path for route in app.routes]
 
     assert isinstance(info["title"], str)
     assert isinstance(info["summary"], str)
     assert isinstance(info["version"], str)
-    assert {"name": "Streaming", "description": "Retrieve NER entities as a stream by running the model"} in tags
     assert "/info" in paths
     assert "/stream/process" in paths
     assert "/stream/ws" in paths
@@ -84,14 +74,11 @@ def test_get_generative_server():
     model_service_dep = ModelServiceDep("huggingface_llm_model", config)
     app = get_generative_server(config, model_service_dep)
     info = app.openapi()["info"]
-    tags = app.openapi_tags
     paths = [route.path for route in app.routes]
 
     assert isinstance(info["title"], str)
     assert isinstance(info["summary"], str)
     assert isinstance(info["version"], str)
-    assert {"name": "Metadata", "description": "Get the model card"} in tags
-    assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags
     assert "/info" in paths
     assert "/generate" in paths
     assert "/stream/generate" in paths
diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py
index 6fbb89d..0f134f9 100644
--- a/tests/app/model_services/test_huggingface_llm_model.py
+++ b/tests/app/model_services/test_huggingface_llm_model.py
@@ -148,101 +148,107 @@ async def test_generate_async(huggingface_llm_model):
     assert result == "Yeah."
-def test_create_embeddings_single_text(huggingface_llm_model):
-    """Test create_embeddings with single text input."""
+@patch("torch.nn.functional.normalize")
+@patch("torch.mean")
+@patch("torch.cat")
+@patch("torch.tensor")
+def test_create_embeddings_single_text(mock_tensor, mock_cat, mock_mean, mock_normalise, huggingface_llm_model):
+    def tensor_side_effect(*args, **kwargs):
+        result = MagicMock()
+        result.to.return_value = result
+        return result
+
     huggingface_llm_model.init_model()
     huggingface_llm_model.model = MagicMock()
+    huggingface_llm_model.model.config.max_position_embeddings = 10
     huggingface_llm_model.tokenizer = MagicMock()
-    mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()]
+    long_input_ids = list(range(25))
+    long_attention_mask = [1] * 25
+    huggingface_llm_model.tokenizer.return_value = {
+        "input_ids": long_input_ids,
+        "attention_mask": long_attention_mask
+    }
     mock_outputs = MagicMock()
-    mock_outputs.hidden_states = mock_hidden_states
-    mock_last_hidden_state = MagicMock()
-    mock_last_hidden_state.shape = [1, 3, 768]
-    mock_hidden_states[-1] = mock_last_hidden_state
-    mock_attention_mask = MagicMock()
-    mock_attention_mask.shape = [1, 3]
-    mock_attention_mask.sum.return_value = MagicMock()
-    mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock()
-    mock_inputs = MagicMock()
-    mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock()
-    huggingface_llm_model.tokenizer.return_value = mock_inputs
+    mock_hidden_state = MagicMock()
+    mock_outputs.hidden_states = [None, None, mock_hidden_state]
     huggingface_llm_model.model.return_value = mock_outputs
-    expected_result = [0.1, 0.2, 0.3]
-    mock_embeddings_batch = MagicMock()
-    mock_first_embedding = MagicMock()
-    mock_cpu_tensor = MagicMock()
-    mock_numpy_array = MagicMock()
-    mock_numpy_array.tolist.return_value = expected_result
-    mock_embeddings_batch.__getitem__.return_value = mock_first_embedding
-    mock_first_embedding.cpu.return_value = mock_cpu_tensor
-    mock_cpu_tensor.numpy.return_value = mock_numpy_array
-    mock_masked_hidden_states = MagicMock()
-    mock_sum_hidden_states = MagicMock()
-    mock_num_tokens = MagicMock()
-    mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states
-    mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states
-    mock_attention_mask.sum.return_value = mock_num_tokens
-    mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch
-
-    result = huggingface_llm_model.create_embeddings("Alright")
-
-    huggingface_llm_model.model.eval.assert_called_once()
-    huggingface_llm_model.tokenizer.assert_called_once_with(
-        "Alright",
-        add_special_tokens=False,
-        return_tensors="pt",
-        padding=True,
-        truncation=True
-    )
-    huggingface_llm_model.model.assert_called_once_with(
-        **mock_inputs,
-        output_hidden_states=True
+    mock_chunk_embedding = MagicMock()
+    mock_final_embedding = MagicMock()
+    mock_normalised = MagicMock()
+    mock_concatenated = MagicMock()
+    mock_cat.return_value = mock_concatenated
+    mock_mean.return_value = mock_final_embedding
+    mock_normalise.return_value = mock_normalised
+    mock_normalised.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]]
+    mock_tensor.side_effect = tensor_side_effect
+    mock_masked = MagicMock()
+    mock_summed = MagicMock()
+    mock_hidden_state.__mul__.return_value = mock_masked
+    mock_masked.sum.return_value = mock_summed
+    mock_summed.__truediv__.return_value = mock_chunk_embedding
+
+    result = huggingface_llm_model.create_embeddings(
+        "This is a long text that should be chunked into multiple pieces"
     )
-
-    assert result is not None
+    assert huggingface_llm_model.model.call_count >= 3
+    mock_cat.assert_called_once()
+    mock_mean.assert_called_once()
+    assert result == [0.1, 0.2, 0.3]
+
+
+@patch("torch.nn.functional.normalize")
+@patch("torch.mean")
+@patch("torch.cat")
+@patch("torch.tensor")
+def test_create_embeddings_list_text(mock_tensor, mock_cat, mock_mean, mock_normalise, huggingface_llm_model):
+    def tokenizer_side_effect(text, **kwargs):
+        if isinstance(text, list):
+            return {
+                "input_ids": [list(range(10)), list(range(15))],
+                "attention_mask": [[1]*10, [1]*15]
+            }
+        else:
+            return {
+                "input_ids": list(range(len(text.split()))),
+                "attention_mask": [1] * len(text.split())
+            }
+
+    def tensor_side_effect(*args, **kwargs):
+        result = MagicMock()
+        result.to.return_value = result
+        return result
 
-def test_create_embeddings_list_text(huggingface_llm_model):
     huggingface_llm_model.init_model()
     huggingface_llm_model.model = MagicMock()
+    huggingface_llm_model.model.config.max_position_embeddings = 6
     huggingface_llm_model.tokenizer = MagicMock()
-    mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()]
+    huggingface_llm_model.tokenizer.side_effect = tokenizer_side_effect
     mock_outputs = MagicMock()
-    mock_outputs.hidden_states = mock_hidden_states
-    mock_last_hidden_state = MagicMock()
-    mock_last_hidden_state.shape = [2, 3, 768]
-    mock_hidden_states[-1] = mock_last_hidden_state
-    mock_attention_mask = MagicMock()
-    mock_attention_mask.shape = [2, 3]
-    mock_attention_mask.sum.return_value = MagicMock()
-    mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock()
-    mock_inputs = MagicMock()
-    mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock()
-    huggingface_llm_model.tokenizer.return_value = mock_inputs
-    huggingface_llm_model.model.return_value = mock_outputs
-    mock_embeddings_batch = MagicMock()
-    mock_first_embedding = MagicMock()
-    mock_cpu_tensor = MagicMock()
-    mock_numpy_array = MagicMock()
-    mock_numpy_array.tolist.return_value = [[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]]
-    mock_embeddings_batch.__getitem__.return_value = mock_first_embedding
-    mock_first_embedding.cpu.return_value = mock_cpu_tensor
-    mock_cpu_tensor.numpy.return_value = mock_numpy_array
-    mock_masked_hidden_states = MagicMock()
-    mock_sum_hidden_states = MagicMock()
-    mock_num_tokens = MagicMock()
-    mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states
-    mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states
-    mock_attention_mask.sum.return_value = mock_num_tokens
-    mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch
-
-    result = huggingface_llm_model.create_embeddings(["Alright", "Alright"])
-
-    huggingface_llm_model.tokenizer.assert_called_once_with(
-        ["Alright", "Alright"],
-        add_special_tokens=False,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-    )
-    assert result is not None
+    mock_hidden_state = MagicMock()
+    mock_outputs.hidden_states = [None, None, mock_hidden_state]
+    huggingface_llm_model.model.return_value = mock_outputs
+    mock_chunk_embedding = MagicMock()
+    mock_final_embedding = MagicMock()
+    mock_normalised = MagicMock()
+    mock_concatenated = MagicMock()
+    mock_cat.return_value = mock_concatenated
+    mock_mean.return_value = mock_final_embedding
+    mock_normalise.return_value = mock_normalised
+    mock_normalised.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]]
+    mock_tensor.side_effect = tensor_side_effect
+    mock_masked = MagicMock()
+    mock_summed = MagicMock()
+    mock_hidden_state.__mul__.return_value = mock_masked
+    mock_masked.sum.return_value = mock_summed
+    mock_summed.__truediv__.return_value = mock_chunk_embedding
+
+    result = huggingface_llm_model.create_embeddings([
+        "Alright?",
+        "This is a long text that should be chunked into multiple pieces",
+    ])
+
+    assert huggingface_llm_model.model.call_count >= 4
+    assert mock_cat.call_count == 2
+    assert mock_mean.call_count == 2
+    assert result == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
diff --git a/tests/app/model_services/test_huggingface_ner_model.py b/tests/app/model_services/test_huggingface_ner_model.py
index f617979..44779da 100644
--- a/tests/app/model_services/test_huggingface_ner_model.py
+++ b/tests/app/model_services/test_huggingface_ner_model.py
@@ -99,3 +99,20 @@ def test_train_supervised(huggingface_ner_model):
     with tempfile.TemporaryFile("r+") as f:
         huggingface_ner_model.train_supervised(f, 1, 1, "training_id", "input_file_name")
         huggingface_ner_model._supervised_trainer.train.assert_called()
+
+
+def test_create_embeddings(huggingface_ner_model):
+    huggingface_ner_model.init_model()
+
+    text = "Spinal stenosis"
+    embedding = huggingface_ner_model.create_embeddings(text)
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    texts = ["Spinal stenosis", "Diabetes"]
+    embeddings = huggingface_ner_model.create_embeddings(texts)
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    assert all(isinstance(emb, list) for emb in embeddings)
+    assert all(len(emb) > 0 for emb in embeddings)
diff --git a/tests/app/model_services/test_medcat_model_deid.py b/tests/app/model_services/test_medcat_model_deid.py
index 59f6b1d..8c16d90 100644
--- a/tests/app/model_services/test_medcat_model_deid.py
+++ b/tests/app/model_services/test_medcat_model_deid.py
@@ -186,3 +186,27 @@ def test_train_supervised(medcat_deid_model):
     with tempfile.TemporaryFile("r+") as f:
         medcat_deid_model.train_supervised(f, 1, 1, "training_id", "input_file_name")
         medcat_deid_model._supervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+    not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")),
+    reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_deid_model):
+    medcat_deid_model.init_model()
+
+    embedding = medcat_deid_model.create_embeddings("This is a post code NW1 2DA")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_deid_model.create_embeddings([
+        "This is a post code NW1 2DA",
+        "This is a post code NW1 2DB",
+    ])
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for emb in embeddings:
+        assert isinstance(emb, list)
+        assert len(emb) > 0
+        assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_icd10.py b/tests/app/model_services/test_medcat_model_icd10.py
index b1f6bcd..c019c45 100644
--- a/tests/app/model_services/test_medcat_model_icd10.py
+++ b/tests/app/model_services/test_medcat_model_icd10.py
@@ -101,7 +101,7 @@ def test_annotate(medcat_icd10_model):
     medcat_icd10_model.init_model()
     annotations = medcat_icd10_model.annotate("Spinal stenosis")
     assert len(annotations) == 1
-    assert type(annotations[0]["label_name"]) is str
+    assert type(annotations[0].label_name) is str
     assert annotations[0].start == 0
     assert annotations[0].end == 15
     assert annotations[0].accuracy > 0
@@ -133,3 +133,24 @@ def test_train_unsupervised(medcat_icd10_model):
     with tempfile.TemporaryFile("r+") as f:
         medcat_icd10_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
         medcat_icd10_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+    not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")),
+    reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_icd10_model):
+    medcat_icd10_model.init_model()
+
+    embedding = medcat_icd10_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_icd10_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for emb in embeddings:
+        assert isinstance(emb, list)
+        assert len(emb) > 0
+        assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_opcs4.py b/tests/app/model_services/test_medcat_model_opcs4.py
index 12b9d0d..5fcd932 100644
--- a/tests/app/model_services/test_medcat_model_opcs4.py
+++ b/tests/app/model_services/test_medcat_model_opcs4.py
@@ -101,7 +101,7 @@ def test_annotate(medcat_opcs4_model):
     medcat_opcs4_model.init_model()
     annotations = medcat_opcs4_model.annotate("Spinal tap")
     assert len(annotations) == 1
-    assert type(annotations[0]["label_name"]) is str
+    assert type(annotations[0].label_name) is str
     assert annotations[0].start == 0
     assert annotations[0].end == 10
     assert annotations[0].accuracy > 0
@@ -133,3 +133,24 @@ def test_train_unsupervised(medcat_opcs4_model):
     with tempfile.TemporaryFile("r+") as f:
         medcat_opcs4_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
         medcat_opcs4_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+    not os.path.exists(os.path.join(MODEL_PARENT_DIR, "opcs4_model.zip")),
+    reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_opcs4_model):
+    medcat_opcs4_model.init_model()
+
+    embedding = medcat_opcs4_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_opcs4_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for emb in embeddings:
+        assert isinstance(emb, list)
+        assert len(emb) > 0
+        assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_snomed.py b/tests/app/model_services/test_medcat_model_snomed.py
index 4928660..b4a4ae7 100644
--- a/tests/app/model_services/test_medcat_model_snomed.py
+++ b/tests/app/model_services/test_medcat_model_snomed.py
@@ -98,7 +98,7 @@ def test_annotate(medcat_snomed_model):
     medcat_snomed_model.init_model()
     annotations = medcat_snomed_model.annotate("Spinal stenosis")
     assert len(annotations) == 1
-    assert type(annotations[0]["label_name"]) is str
+    assert type(annotations[0].label_name) is str
     assert annotations[0].start == 0
     assert annotations[0].end == 15
     assert annotations[0].accuracy > 0
@@ -130,3 +130,24 @@ def test_train_unsupervised(medcat_snomed_model):
     with tempfile.TemporaryFile("r+") as f:
         medcat_snomed_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
         medcat_snomed_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+    not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")),
+    reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_snomed_model):
+    medcat_snomed_model.init_model()
+
+    embedding = medcat_snomed_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_snomed_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for emb in embeddings:
+        assert isinstance(emb, list)
+        assert len(emb) > 0
+        assert all(isinstance(x, float) for x in emb)
diff --git a/tests/app/model_services/test_medcat_model_umls.py b/tests/app/model_services/test_medcat_model_umls.py
index 2f9ff15..71987f6 100644
--- a/tests/app/model_services/test_medcat_model_umls.py
+++ b/tests/app/model_services/test_medcat_model_umls.py
@@ -61,7 +61,7 @@ def test_annotate(medcat_umls_model):
     medcat_umls_model.init_model()
     annotations = medcat_umls_model.annotate("Spinal stenosis")
     assert len(annotations) == 1
-    assert type(annotations[0]["label_name"]) is str
+    assert type(annotations[0].label_name) is str
    assert annotations[0].start == 0
     assert annotations[0].end == 15
     assert annotations[0].accuracy > 0
@@ -93,3 +93,24 @@ def test_train_unsupervised(medcat_umls_model):
     with tempfile.TemporaryFile("r+") as f:
         medcat_umls_model.train_unsupervised(f, 1, 1, "training_id", "input_file_name")
         medcat_umls_model._unsupervised_trainer.train.assert_called()
+
+
+@pytest.mark.skipif(
+    not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")),
+    reason="requires the model file to be present in the resources folder",
+)
+def test_create_embeddings(medcat_umls_model):
+    medcat_umls_model.init_model()
+
+    embedding = medcat_umls_model.create_embeddings("Spinal stenosis")
+    assert isinstance(embedding, list)
+    assert len(embedding) > 0
+    assert all(isinstance(x, float) for x in embedding)
+
+    embeddings = medcat_umls_model.create_embeddings(["Spinal stenosis", "Diabetes"])
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for emb in embeddings:
+        assert isinstance(emb, list)
+        assert len(emb) > 0
+        assert all(isinstance(x, float) for x in emb)
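For reviewers: a hypothetical client-side smoke test of the new OpenAI-compatible route; the host, port, model name, and bearer token below are deployment-specific placeholders rather than values defined by this patch:

# Illustrative call to the /v1/embeddings route added above.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",  # placeholder host/port
    json={"model": "medcat_snomed", "input": ["Spinal stenosis", "Diabetes"]},
    headers={"Authorization": "Bearer <access-token>"},  # placeholder credentials
)
response.raise_for_status()
body = response.json()

assert body["object"] == "list"
assert len(body["data"]) == 2                    # one vector per input text
print(body["model"])                             # echoes the deployed model's name
print(body["data"][0]["embedding"][:5])          # first few dimensions of the first vector
print(response.headers["x-cms-tracking-id"])     # per-request tracking id

Note that the route always reports the deployed model's name, so a mismatched "model" value in the request body is ignored rather than rejected.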