94 changes: 92 additions & 2 deletions app/api/routers/invocation.py
@@ -7,12 +7,14 @@
import hashlib
import logging
import pandas as pd
from fastapi.encoders import jsonable_encoder

import app.api.globals as cms_globals

from typing import Dict, List, Union, Iterator, Any
from collections import defaultdict
from io import BytesIO
from starlette.status import HTTP_400_BAD_REQUEST
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from typing_extensions import Annotated
from fastapi import APIRouter, Depends, Body, UploadFile, File, Request, Query, Response
from fastapi.responses import StreamingResponse, PlainTextResponse, JSONResponse
@@ -22,7 +24,7 @@
TextWithAnnotations,
TextWithPublicKey,
TextStreamItem,
Tags,
Tags, OpenAIEmbeddingsRequest, OpenAIEmbeddingsResponse,
)
from app.model_services.base import AbstractModelService
from app.utils import get_settings, load_pydantic_object_from_dict
@@ -43,6 +45,7 @@
PATH_PROCESS_BULK_FILE = "/process_bulk_file"
PATH_REDACT = "/redact"
PATH_REDACT_WITH_ENCRYPTION = "/redact_with_encryption"
PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"

router = APIRouter()
config = get_settings()
@@ -355,6 +358,93 @@ def get_redacted_text_with_encryption(
return JSONResponse(content=content)


@router.post(
PATH_OPENAI_EMBEDDINGS,
tags=[Tags.OpenAICompatible.name],
response_model=None,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint",
)
def embed_texts(
request: Request,
request_data: Annotated[OpenAIEmbeddingsRequest, Body(
description="Text(s) to be embedded", media_type="application/json"
)],
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> JSONResponse:
"""
Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.

Args:
request (Request): The request object.
request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s).
tracking_id (Union[str, None]): An optional tracking ID of the requested task.
model_service (AbstractModelService): The model service dependency.

Returns:
JSONResponse: A response containing the embeddings of the text(s).
"""
tracking_id = tracking_id or str(uuid.uuid4())

if not hasattr(model_service, "create_embeddings"):
error_response = {
"error": {
"message": "Model does not support embeddings",
"type": "invalid_request_error",
"param": "model",
"code": "model_not_supported",
}
}
return JSONResponse(
content=error_response,
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
headers={"x-cms-tracking-id": tracking_id},
)

input_text = request_data.input
model = model_service.model_name  # both branches of the conditional resolve to the served model's name

if isinstance(input_text, str):
input_texts = [input_text]
else:
input_texts = input_text

try:
embeddings_data = []

for i, embedding in enumerate(model_service.create_embeddings(input_texts)):
embeddings_data.append({
"object": "embedding",
"embedding": embedding,
"index": i,
})

response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)

return JSONResponse(
content=jsonable_encoder(response),
headers={"x-cms-tracking-id": tracking_id},
)

except Exception as e:
logger.error("Failed to create embeddings")
logger.exception(e)
error_response = {
"error": {
"message": f"Failed to create embeddings: {str(e)}",
"type": "server_error",
"code": "internal_error",
}
}
return JSONResponse(
content=error_response,
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
headers={"x-cms-tracking-id": tracking_id},
)
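
As a usage sketch (assuming the service listens on http://localhost:8000 with authentication disabled; host, port, and payload values are illustrative), the new endpoint can be exercised like this:

import requests

payload = {"model": "HuggingFace LLM model", "input": ["first text", "second text"]}
resp = requests.post("http://localhost:8000/v1/embeddings", json=payload)
resp.raise_for_status()
# The body mirrors OpenAI's shape: {"object": "list", "data": [...], "model": "..."}
for item in resp.json()["data"]:
    print(item["index"], len(item["embedding"]))
print(resp.headers["x-cms-tracking-id"])  # per-request tracking ID set by the router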



def _send_annotation_num_metric(annotation_num: int, handler: str) -> None:
cms_doc_annotations.labels(handler=handler).observe(annotation_num)

41 changes: 39 additions & 2 deletions app/api/routers/stream.py
@@ -11,7 +11,7 @@
from starlette.types import Receive, Scope, Send
from starlette.background import BackgroundTask
from fastapi import APIRouter, Depends, Request, Response, WebSocket, WebSocketException
from pydantic import ValidationError
from pydantic import ValidationError, BaseModel
from app.domain import Tags, TextStreamItem
from app.model_services.base import AbstractModelService
from app.utils import get_settings
@@ -20,7 +20,6 @@

PATH_STREAM_PROCESS = "/process"
PATH_WS = "/ws"
PATH_GENERATE= "/generate"

router = APIRouter()
config = get_settings()
@@ -57,6 +56,22 @@ async def get_entities_stream_from_jsonlines_stream(
return _LocalStreamingResponse(annotation_stream, media_type="application/x-ndjson; charset=utf-8")


@router.get(
PATH_WS,
tags=[Tags.Annotations.name],
dependencies=[Depends(cms_globals.props.current_active_user)],
description="WebSocket info endpoint for real-time NER entity extraction. Use ws://host:port/stream/ws to establish an actual WebSocket connection.",
include_in_schema=True,
)
async def get_inline_annotations_from_websocket_info() -> "_WebSocketInfo":
"""
Information about the WebSocket endpoint for real-time NER entity extraction.

This endpoint provides documentation for the WebSocket connection available at the same path.
Connect to ws://host:port/stream/ws and send texts to retrieve annotated results.
"""
return _WebSocketInfo()

@router.websocket(PATH_WS)
# @limiter.limit(config.PROCESS_BULK_RATE_LIMIT) # Not supported yet
async def get_inline_annotations_from_websocket(
@@ -189,6 +204,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.background()


class _WebSocketInfo(BaseModel):
message: str = "WebSocket endpoint for real-time NER entity extraction"
example: str = """<form action="" onsubmit="send_doc(event)">
<input type="text" id="cms-input" autocomplete="off"/>
<button>Send</button>
</form>
<ul id="cms-output"></ul>
<script>
var ws = new WebSocket("ws://localhost:8000/stream/ws");
ws.onmessage = function(event) {
document.getElementById("cms-output").appendChild(
Object.assign(document.createElement('li'), { textContent: event.data })
);
};
function send_doc(event) {
ws.send(document.getElementById("cms-input").value);
event.preventDefault();
};
</script>"""
protocol: str = "WebSocket"
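
For comparison with the HTML example above, a minimal Python client sketch for the same endpoint, assuming the third-party websockets package and a server on localhost:8000:

import asyncio
import websockets  # third-party client library, assumed to be installed

async def annotate(text: str) -> str:
    # Connect to the streaming NER endpoint, send one document, and
    # return the annotated result the server pushes back.
    async with websockets.connect("ws://localhost:8000/stream/ws") as ws:
        await ws.send(text)
        return await ws.recv()

print(asyncio.run(annotate("Patient denies chest pain.")))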


async def _annotation_async_gen(request: Request, model_service: AbstractModelService) -> AsyncGenerator:
try:
buffer = ""
111 changes: 89 additions & 22 deletions app/model_services/huggingface_llm_model.py
@@ -15,7 +15,7 @@
from app import __version__ as app_version
from app.exception import ConfigurationException
from app.model_services.base import AbstractModelService
from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer
from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer, HuggingFaceLlmUnsupervisedTrainer
from app.domain import ModelCard, ModelType, Annotation, Device
from app.config import Settings
from app.utils import (
@@ -62,6 +62,7 @@ def __init__(
self._multi_label_threshold = 0.5
self._text_generator = ThreadPoolExecutor(max_workers=50)
self.model_name = model_name or "HuggingFace LLM model"
self.is_4bit_quantised = False

@property
def model(self) -> PreTrainedModel:
@@ -206,6 +207,8 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N
self._model.to(get_settings().DEVICE)
if self._enable_trainer:
self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self)
self._unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(self)
self.is_4bit_quantised = load_in_4bit

def info(self) -> ModelCard:
"""
@@ -396,29 +399,47 @@ def create_embeddings(

self.model.eval()

inputs = self.tokenizer(
text,
add_special_tokens=False,
return_tensors="pt",
padding=True,
truncation=True,
)

inputs.to(self.model.device)

with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
texts = [text] if isinstance(text, str) else text
all_embeddings = []

for txt in texts:
inputs = self.tokenizer(txt, add_special_tokens=False, truncation=False, padding=False)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
window_size = max(self.model.config.max_position_embeddings - 2, 1)
stride = window_size
chunk_embeddings = []

for start in range(0, len(input_ids), stride):
end = min(start + window_size, len(input_ids))
chunk_inputs = {
"input_ids": torch.tensor(
[input_ids[start:end]], dtype=torch.long
).to(self.model.device),
"attention_mask": torch.tensor(
[attention_mask[start:end]], dtype=torch.long
).to(self.model.device),
}

with torch.no_grad():
outputs = self.model(**chunk_inputs, output_hidden_states=True)

last_hidden_state = outputs.hidden_states[-1]
chunk_attention_mask = chunk_inputs["attention_mask"]
masked_hidden_states = last_hidden_state * chunk_attention_mask.unsqueeze(-1)
sum_hidden_states = masked_hidden_states.sum(dim=1)
num_tokens = chunk_attention_mask.sum(dim=1, keepdim=True)
chunk_embedding = sum_hidden_states / num_tokens
chunk_embeddings.append(chunk_embedding)

if end >= len(input_ids):
break

last_hidden_state = outputs.hidden_states[-1]
attention_mask = inputs["attention_mask"]
masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1)
sum_hidden_states = masked_hidden_states.sum(dim=1)
num_tokens = attention_mask.sum(dim=1, keepdim=True)
embeddings = sum_hidden_states / num_tokens
l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1)
final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True)
l2_normalised = torch.nn.functional.normalize(final_embedding, p=2, dim=1)
all_embeddings.append(l2_normalised.cpu().numpy().tolist()[0])

results = l2_normalised.cpu().numpy().tolist()
return results[0] if isinstance(text, str) else results
return all_embeddings[0] if isinstance(text, str) else all_embeddings
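
The chunking scheme above walks the token sequence in non-overlapping windows of max_position_embeddings - 2 tokens (the stride equals the window size), mean-pools each chunk, and averages the chunk vectors into one L2-normalised embedding. A standalone sketch of the window arithmetic, with toy sizes:

def window_bounds(n_tokens: int, max_position_embeddings: int):
    # Mirror the loop in create_embeddings: non-overlapping windows whose
    # size is the model's context length minus two.
    window = max(max_position_embeddings - 2, 1)
    for start in range(0, n_tokens, window):
        yield start, min(start + window, n_tokens)

# A 10-token input under a toy 4-token context splits into 2-token chunks:
print(list(window_bounds(10, 4)))  # [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]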

def train_supervised(
self,
@@ -465,3 +486,49 @@ def train_supervised(
synchronised,
**hyperparams,
)

def train_unsupervised(
self,
data_file: TextIO,
epochs: int,
log_frequency: int,
training_id: str,
input_file_name: str,
raw_data_files: Optional[List[TextIO]] = None,
description: Optional[str] = None,
synchronised: bool = False,
**hyperparams: Dict[str, Any],
) -> Tuple[bool, str, str]:
"""
Initiates unsupervised training on the model.

Args:
data_file (TextIO): The file containing a JSON list of texts.
epochs (int): The number of training epochs.
log_frequency (int): The number of epochs after which training metrics will be logged.
training_id (str): A unique identifier for the training process.
input_file_name (str): The name of the input file to be logged.
raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None.
description (Optional[str]): The description of the training or change logs. Defaults to None.
synchronised (bool): Whether to wait for the training to complete.
**hyperparams (Dict[str, Any]): Additional hyperparameters for training.

Returns:
Tuple[bool, str, str]: A tuple with the first element indicating success or failure.

Raises:
ConfigurationException: If the unsupervised trainer is not enabled.
"""
if self._unsupervised_trainer is None:
raise ConfigurationException("The unsupervised trainer is not enabled")
return self._unsupervised_trainer.train(
data_file,
epochs,
log_frequency,
training_id,
input_file_name,
raw_data_files,
description,
synchronised,
**hyperparams,
)
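
A hypothetical invocation sketch; the file name, identifiers, and hyperparameter values are made up, and service stands for an initialised instance of this model service with the trainer enabled:

with open("texts.json") as data_file:  # a JSON list of texts, per the docstring
    result = service.train_unsupervised(
        data_file=data_file,
        epochs=1,
        log_frequency=1,
        training_id="demo-run",
        input_file_name="texts.json",
    )
accepted, _, _ = result  # first element indicates success or failure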