diff --git a/pyproject.toml b/pyproject.toml
index f06c6d5..d0b3982 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.33"
+version = "1.5.34"
 authors = ["Together AI "]
 description = "Python client for Together's Cloud Platform! Note: SDK 2.0 is now available at https://github.com/togethercomputer/together-py"
 readme = "README.md"
diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index a1377c8..6aebbd5 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import json
-import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
@@ -14,18 +13,11 @@
 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, generate_progress_bar
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneEventType,
-    FinetuneTrainingLimits,
-    FullTrainingType,
-    LoRATrainingType,
-)
+from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 from together.utils import (
     finetune_price_to_dollars,
     format_timestamp,
     log_warn,
-    log_warn_once,
     parse_timestamp,
 )
@@ -258,6 +250,7 @@ def create(
     lora_dropout: float,
     lora_alpha: float,
     lora_trainable_modules: str,
+    train_vision: bool,
     suffix: str,
     wandb_api_key: str,
     wandb_base_url: str,
@@ -299,6 +292,7 @@ def create(
         lora_dropout=lora_dropout,
         lora_alpha=lora_alpha,
         lora_trainable_modules=lora_trainable_modules,
+        train_vision=train_vision,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
         wandb_base_url=wandb_base_url,
@@ -368,6 +362,10 @@ def create(
             "You have specified a number of evaluation loops but no validation file."
         )
 
+    if model_limits.supports_vision:
+        # Don't show price estimation for multimodal models yet
+        confirm = True
+
     finetune_price_estimation_result = client.fine_tuning.estimate_price(
         training_file=training_file,
         validation_file=validation_file,
diff --git a/src/together/constants.py b/src/together/constants.py
index 5e0c912..36132f6 100644
--- a/src/together/constants.py
+++ b/src/together/constants.py
@@ -1,5 +1,6 @@
 import enum
 
+
 # Session constants
 TIMEOUT_SECS = 600
 MAX_SESSION_LIFETIME_SECS = 180
@@ -40,6 +41,11 @@
 # the number of bytes in a gigabyte, used to convert bytes to GB for readable comparison
 NUM_BYTES_IN_GB = 2**30
 
+# Multimodal limits
+MAX_IMAGES_PER_EXAMPLE = 10
+MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10MB
+# Max length = Header length + base64 factor (4/3) * image bytes
+MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3
+
 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
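Editor's note on the new constants: the `MAX_BASE64_IMAGE_LENGTH` arithmetic is easy to sanity-check offline. Below is a minimal sketch (ours, not part of the diff) that re-derives the bound against a synthetic 10 MB payload. One subtlety: base64 padding rounds the output up to a multiple of 4, so the real encoded length can exceed the `4 * n // 3` floor by up to three characters.

```python
import base64

# Values copied from src/together/constants.py in this diff.
MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10MB
MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3

# base64 emits 4 output characters per 3 input bytes, rounded up to a
# multiple of 4 by padding, so the encoded length can exceed the
# 4 * n // 3 floor by up to 3 characters.
payload = b"\x00" * MAX_IMAGE_BYTES  # synthetic stand-in for a 10MB image
encoded = "data:image/jpeg;base64," + base64.b64encode(payload).decode()

print(len(encoded))             # 13981039
print(MAX_BASE64_IMAGE_LENGTH)  # 13981036
assert len(encoded) <= MAX_BASE64_IMAGE_LENGTH + 3
```

In practice this makes the strict `>` comparison in `_check_message_content` marginally tighter than a literal 10 MB: an image of exactly `MAX_IMAGE_BYTES` encodes three characters past the limit and is rejected.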
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index 7cd1eb0..2b21f59 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -2,7 +2,7 @@
 import re
 from pathlib import Path
-from typing import List, Dict, Literal
+from typing import Dict, List, Literal
 
 from rich import print as rprint
@@ -18,10 +18,11 @@
     FinetuneList,
     FinetuneListEvents,
     FinetuneLRScheduler,
-    FinetuneRequest,
-    FinetuneResponse,
+    FinetuneMultimodalParams,
     FinetunePriceEstimationRequest,
     FinetunePriceEstimationResponse,
+    FinetuneRequest,
+    FinetuneResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
     LinearLRScheduler,
@@ -73,6 +74,7 @@ def create_finetune_request(
     lora_dropout: float | None = 0,
     lora_alpha: float | None = None,
     lora_trainable_modules: str | None = "all-linear",
+    train_vision: bool = False,
     suffix: str | None = None,
     wandb_api_key: str | None = None,
     wandb_base_url: str | None = None,
@@ -252,6 +254,15 @@ def create_finetune_request(
         simpo_gamma=simpo_gamma,
     )
 
+    if model_limits.supports_vision:
+        multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
+    elif train_vision:
+        raise ValueError(
+            f"Vision encoder training is not supported for the non-multimodal model `{model}`"
+        )
+    else:
+        multimodal_params = None
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -272,6 +283,7 @@ def create_finetune_request(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         training_method=training_method_cls,
+        multimodal_params=multimodal_params,
         from_checkpoint=from_checkpoint,
         from_hf_model=from_hf_model,
         hf_model_revision=hf_model_revision,
@@ -342,6 +354,7 @@ def create(
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -387,6 +400,7 @@ def create(
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+            train_vision (bool, optional): Whether to train the vision encoder in multimodal models. Defaults to False.
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -464,6 +478,7 @@ def create(
             lora_dropout=lora_dropout,
             lora_alpha=lora_alpha,
             lora_trainable_modules=lora_trainable_modules,
+            train_vision=train_vision,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
             wandb_base_url=wandb_base_url,
@@ -906,6 +921,7 @@ async def create(
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -951,6 +967,7 @@ async def create(
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+            train_vision (bool, optional): Whether to train the vision encoder in multimodal models. Defaults to False.
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -1028,6 +1045,7 @@ async def create(
             lora_dropout=lora_dropout,
             lora_alpha=lora_alpha,
             lora_trainable_modules=lora_trainable_modules,
+            train_vision=train_vision,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
             wandb_base_url=wandb_base_url,
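This is the user-visible surface of the change. A hedged usage sketch follows; the model name and file ID are placeholders, not real identifiers. For a base model whose `FinetuneTrainingLimits` reports `supports_vision=True`, `train_vision=True` opts the vision encoder into training; for any other model the same flag now raises `ValueError`.

```python
from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Placeholder model/file IDs for illustration only.
job = client.fine_tuning.create(
    model="example-org/example-vision-model",  # must report supports_vision=True
    training_file="file-xxxxxxxx",             # an uploaded multimodal JSONL file
    train_vision=True,                         # new flag introduced in this diff
)
print(job.id, job.status)

# Passing train_vision=True for a text-only model raises:
#   ValueError: Vision encoder training is not supported for the
#   non-multimodal model `...`
```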
diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py
index 61c054a..351f2a1 100644
--- a/src/together/types/__init__.py
+++ b/src/together/types/__init__.py
@@ -7,17 +7,18 @@
     AudioSpeechStreamChunk,
     AudioSpeechStreamEvent,
     AudioSpeechStreamResponse,
+    AudioTimestampGranularities,
     AudioTranscriptionRequest,
-    AudioTranslationRequest,
     AudioTranscriptionResponse,
+    AudioTranscriptionResponseFormat,
     AudioTranscriptionVerboseResponse,
+    AudioTranslationRequest,
     AudioTranslationResponse,
     AudioTranslationVerboseResponse,
-    AudioTranscriptionResponseFormat,
-    AudioTimestampGranularities,
     ModelVoices,
     VoiceListResponse,
 )
+from together.types.batch import BatchEndpoint, BatchJob, BatchJobStatus
 from together.types.chat_completions import (
     ChatCompletionChunk,
     ChatCompletionRequest,
@@ -31,6 +32,19 @@
 )
 from together.types.embeddings import EmbeddingRequest, EmbeddingResponse
 from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint
+from together.types.evaluation import (
+    ClassifyParameters,
+    CompareParameters,
+    EvaluationCreateResponse,
+    EvaluationJob,
+    EvaluationRequest,
+    EvaluationStatus,
+    EvaluationStatusResponse,
+    EvaluationType,
+    JudgeModelConfig,
+    ModelRequest,
+    ScoreParameters,
+)
 from together.types.files import (
     FileDeleteResponse,
     FileList,
@@ -41,49 +55,32 @@
     FileType,
 )
 from together.types.finetune import (
-    TrainingMethodDPO,
-    TrainingMethodSFT,
-    FinetuneCheckpoint,
     CosineLRScheduler,
     CosineLRSchedulerArgs,
+    FinetuneCheckpoint,
+    FinetuneDeleteResponse,
     FinetuneDownloadResult,
-    LinearLRScheduler,
-    LinearLRSchedulerArgs,
-    FinetuneLRScheduler,
     FinetuneList,
     FinetuneListEvents,
-    FinetuneRequest,
-    FinetuneResponse,
+    FinetuneLRScheduler,
+    FinetuneMultimodalParams,
     FinetunePriceEstimationRequest,
     FinetunePriceEstimationResponse,
-    FinetuneDeleteResponse,
+    FinetuneRequest,
+    FinetuneResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
+    LinearLRScheduler,
+    LinearLRSchedulerArgs,
     LoRATrainingType,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     TrainingType,
 )
 from together.types.images import ImageRequest, ImageResponse
 from together.types.models import ModelObject, ModelUploadRequest, ModelUploadResponse
 from together.types.rerank import RerankRequest, RerankResponse
-from together.types.batch import BatchJob, BatchJobStatus, BatchEndpoint
-from together.types.evaluation import (
-    EvaluationType,
-    EvaluationStatus,
-    JudgeModelConfig,
-    ModelRequest,
-    ClassifyParameters,
-    ScoreParameters,
-    CompareParameters,
-    EvaluationRequest,
-    EvaluationCreateResponse,
-    EvaluationJob,
-    EvaluationStatusResponse,
-)
-from together.types.videos import (
-    CreateVideoBody,
-    CreateVideoResponse,
-    VideoJob,
-)
+from together.types.videos import CreateVideoBody, CreateVideoResponse, VideoJob
 
 
 __all__ = [
@@ -131,6 +128,7 @@
     "RerankRequest",
     "RerankResponse",
     "FinetuneTrainingLimits",
+    "FinetuneMultimodalParams",
     "AudioSpeechRequest",
     "AudioResponseFormat",
     "AudioLanguage",
+ """ + # Text-only message content + if isinstance(message_content, str): + return False, 0 + + # Multimodal message content + if isinstance(message_content, list): + num_images = 0 + for item in message_content: + if not isinstance(item, dict): + raise InvalidFileFormatError( + "The dataset is malformed, the `content` field must be a list of dicts.", + line_number=idx + 1, + error_source="key_value", + ) + if "type" not in item: + raise InvalidFileFormatError( + "The dataset is malformed, the `content` field must be a list of dicts with a `type` field.", + line_number=idx + 1, + error_source="key_value", + ) + + if item["type"] == "text": + if "text" not in item or not isinstance(item["text"], str): + raise InvalidFileFormatError( + "The dataset is malformed, the `text` field must be present in the `content` item field and be" + f" a string. Got '{item.get('text')!r}' instead.", + line_number=idx + 1, + error_source="key_value", + ) + elif item["type"] == "image_url": + if role != "user": + raise InvalidFileFormatError( + "The dataset is malformed, only user messages can contain images.", + line_number=idx + 1, + error_source="key_value", + ) + + if "image_url" not in item or not isinstance(item["image_url"], dict): + raise InvalidFileFormatError( + "The dataset is malformed, the `image_url` field must be present in the `content` field and " + f"be a dictionary. Got {item.get('image_url')!r} instead.", + line_number=idx + 1, + error_source="key_value", + ) + + image_data = item["image_url"].get("url") + if not image_data or not isinstance(image_data, str): + raise InvalidFileFormatError( + "The dataset is malformed, the `url` field must be present in the `image_url` field and be " + f"a string. Got {image_data!r} instead.", + line_number=idx + 1, + error_source="key_value", + ) + + if not any( + image_data.startswith(f"data:image/{fmt};base64,") + for fmt in ["jpeg", "png", "webp"] + ): + raise InvalidFileFormatError( + "The dataset is malformed, the `url` field must be either a JPEG, PNG or WEBP base64-encoded " + "image in 'data:image/;base64,' format. " + f"Got '{image_data[:100]}...' instead.", + line_number=idx + 1, + ) + + if len(image_data) > MAX_BASE64_IMAGE_LENGTH: + raise InvalidFileFormatError( + "The dataset is malformed, the `url` field must contain base64-encoded image " + f"that is less than 10MB, found ~{len(image_data) * 3 // 4} bytes.", + line_number=idx + 1, + error_source="key_value", + ) + + num_images += 1 + else: + raise InvalidFileFormatError( + "The dataset is malformed, the `type` field must be either 'text' or 'image_url'. " + f"Got {item['type']!r}.", + line_number=idx + 1, + error_source="key_value", + ) + + if num_images > MAX_IMAGES_PER_EXAMPLE: + raise InvalidFileFormatError( + f"The dataset is malformed, the `content` field must contain at most " + f"{MAX_IMAGES_PER_EXAMPLE} images, found {num_images}.", + line_number=idx + 1, + error_source="key_value", + ) + + # We still consider text-only messages in such format as multimodal, even if they don't have any images + # included - so we can process datasets with rather sparse images (i.e. not in each sample) consistently. + return True, num_images + + raise InvalidFileFormatError( + f"Invalid content type on line {idx + 1} of the input file. 
diff --git a/src/together/utils/files.py b/src/together/utils/files.py
index 3734753..8f361fa 100644
--- a/src/together/utils/files.py
+++ b/src/together/utils/files.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
+import csv
 import json
 import os
-import csv
 from pathlib import Path
 from traceback import format_exc
 from typing import Any, Dict, List
@@ -10,18 +10,30 @@
 from tqdm import tqdm
 
 from together.constants import (
+    JSONL_REQUIRED_COLUMNS_MAP,
+    MAX_BASE64_IMAGE_LENGTH,
     MAX_FILE_SIZE_GB,
+    MAX_IMAGES_PER_EXAMPLE,
     MIN_SAMPLES,
     NUM_BYTES_IN_GB,
     PARQUET_EXPECTED_COLUMNS,
-    JSONL_REQUIRED_COLUMNS_MAP,
-    REQUIRED_COLUMNS_MESSAGE,
     POSSIBLE_ROLES_CONVERSATION,
+    REQUIRED_COLUMNS_MESSAGE,
     DatasetFormat,
 )
 from together.types import FilePurpose
 
 
+# MessageContent is either a plain string or an OpenAI-style list of content
+# items: dicts with a 'type' of 'text' or 'image_url' and a matching 'text'
+# or 'image_url.url' payload.
+# Example: "Hello" or [
+#     {"type": "text", "text": "Hello"},
+#     {"type": "image_url", "image_url": {
+#         "url": "data:image/jpeg;base64,..."
+#     }}
+# ]
+MessageContent = str | list[dict[str, Any]]
+
+
 class InvalidFileFormatError(ValueError):
     """Exception raised for invalid file formats during file checks."""
@@ -70,7 +82,7 @@ def check_file(
 
     if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
         report_dict["message"] = (
-            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
+            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB, 3)} GB."
         )
         report_dict["is_check_passed"] = False
     elif file_size == 0:
@@ -103,7 +115,9 @@ def check_file(
     return report_dict
 
 
-def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
+def _check_conversation_type(
+    messages: List[Dict[str, str | int | MessageContent]], idx: int
+) -> None:
     """Check that the conversation has correct type.
 
     Args:
@@ -145,12 +159,6 @@ def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
                 line_number=idx + 1,
                 error_source="key_value",
             )
-        if not isinstance(message[column], str):
-            raise InvalidFileFormatError(
-                message=f"Column `{column}` is not a string on line {idx + 1}. Found {type(message[column])}",
-                line_number=idx + 1,
-                error_source="text_field",
-            )
 
 
 def _check_conversation_roles(
@@ -175,7 +183,9 @@
     )
 
 
-def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
+def _check_message_weight(
+    message: Dict[str, str | int | MessageContent], idx: int
+) -> int | None:
     """Check that the message has a weight with the correct type and value.
 
     Args:
@@ -199,11 +209,14 @@
                 line_number=idx + 1,
                 error_source="key_value",
             )
+        return weight
+
+    return None
 
 
 def _check_message_role(
-    message: Dict[str, str | bool], previous_role: str | None, idx: int
-) -> str | bool:
+    message: Dict[str, str | int | MessageContent], previous_role: str | None, idx: int
+) -> str:
     """Check that the message has correct roles.
 
     Args:
@@ -217,6 +230,14 @@ def _check_message_role(
     Raises:
         InvalidFileFormatError: If the message role is invalid.
     """
+    if not isinstance(message["role"], str):
+        raise InvalidFileFormatError(
+            message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
+            f"Role must be a string. Found {type(message['role'])}",
+            line_number=idx + 1,
+            error_source="key_value",
+        )
+
     if message["role"] not in POSSIBLE_ROLES_CONVERSATION:
         raise InvalidFileFormatError(
             message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
@@ -234,8 +255,133 @@ def _check_message_role(
     return message["role"]
 
 
+def _check_message_content(
+    message_content: str | int | MessageContent, role: str, idx: int
+) -> tuple[bool, int]:
+    """Check that the message content has the correct type.
+
+    Message content can be either a) a string or b) an OpenAI-style multimodal list of content items.
+    Example:
+        a) "Hello", or
+        b) [
+            {"type": "text", "text": "Hello"},
+            {"type": "image_url", "image_url": {
+                "url": "data:image/jpeg;base64,..."
+            }}
+        ]
+
+    Args:
+        message_content: The message content to check.
+        role: The role of the message the content belongs to.
+        idx: Line number in the file.
+
+    Returns:
+        tuple[bool, int]: Whether the message content is multimodal, and the number of images it contains.
+    """
+    # Text-only message content
+    if isinstance(message_content, str):
+        return False, 0
+
+    # Multimodal message content
+    if isinstance(message_content, list):
+        num_images = 0
+        for item in message_content:
+            if not isinstance(item, dict):
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `content` field must be a list of dicts.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+            if "type" not in item:
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `content` field must be a list of dicts with a `type` field.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+
+            if item["type"] == "text":
+                if "text" not in item or not isinstance(item["text"], str):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `text` field must be present in the `content` item field and be"
+                        f" a string. Got '{item.get('text')!r}' instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+            elif item["type"] == "image_url":
+                if role != "user":
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, only user messages can contain images.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                if "image_url" not in item or not isinstance(item["image_url"], dict):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `image_url` field must be present in the `content` field and "
+                        f"be a dictionary. Got {item.get('image_url')!r} instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                image_data = item["image_url"].get("url")
+                if not image_data or not isinstance(image_data, str):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must be present in the `image_url` field and be "
+                        f"a string. Got {image_data!r} instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                if not any(
+                    image_data.startswith(f"data:image/{fmt};base64,")
+                    for fmt in ["jpeg", "png", "webp"]
+                ):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must be either a JPEG, PNG or WEBP base64-encoded "
+                        "image in 'data:image/<format>;base64,' format. "
+                        f"Got '{image_data[:100]}...' instead.",
+                        line_number=idx + 1,
+                    )
+
+                if len(image_data) > MAX_BASE64_IMAGE_LENGTH:
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must contain a base64-encoded image "
+                        f"that is less than 10MB, found ~{len(image_data) * 3 // 4} bytes.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                num_images += 1
+            else:
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `type` field must be either 'text' or 'image_url'. "
+                    f"Got {item['type']!r}.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+
+        if num_images > MAX_IMAGES_PER_EXAMPLE:
+            raise InvalidFileFormatError(
+                f"The dataset is malformed, the `content` field must contain at most "
+                f"{MAX_IMAGES_PER_EXAMPLE} images, found {num_images}.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+
+        # We still consider text-only messages in this format as multimodal, even if they don't include any
+        # images - so datasets with sparse images (i.e. not in every sample) are processed consistently.
+        return True, num_images
+
+    raise InvalidFileFormatError(
+        f"Invalid content type on line {idx + 1} of the input file. Expected string or multimodal list of dicts, "
+        f"found {type(message_content)}",
+        line_number=idx + 1,
+        error_source="key_value",
+    )
+
+
 def validate_messages(
-    messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True
+    messages: List[Dict[str, str | int | MessageContent]],
+    idx: int,
+    require_assistant_role: bool = True,
 ) -> None:
     """Validate the messages column.
 
     Args:
@@ -249,15 +395,45 @@ def validate_messages(
     """
     _check_conversation_type(messages, idx)
 
-    has_weights = any("weight" in message for message in messages)
-
     previous_role = None
     assistant_role_exists = False
+    messages_are_multimodal: bool | None = None
+    total_number_of_images = 0
+
     for message in messages:
-        if has_weights:
-            _check_message_weight(message, idx)
+        message_weight = _check_message_weight(message, idx)
         previous_role = _check_message_role(message, previous_role, idx)
         assistant_role_exists |= previous_role == "assistant"
+        is_multimodal, number_of_images = _check_message_content(
+            message["content"], role=previous_role, idx=idx
+        )
+
+        # Multimodal validation
+        if number_of_images > 0 and message_weight is not None and message_weight != 0:
+            raise InvalidFileFormatError(
+                "Messages with images cannot have non-zero weights.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+
+        if messages_are_multimodal is None:
+            # Detect the format of the messages in the conversation.
+            messages_are_multimodal = is_multimodal
+        elif messages_are_multimodal != is_multimodal:
+            # Due to the format limitation, we cannot mix multimodal and text-only messages in the same sample.
+            raise InvalidFileFormatError(
+                "Messages in the conversation must be either all in multimodal or all in text-only format.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+
+        total_number_of_images += number_of_images
+
+    if total_number_of_images > MAX_IMAGES_PER_EXAMPLE:
+        raise InvalidFileFormatError(
+            f"The dataset is malformed, the `messages` must contain at most {MAX_IMAGES_PER_EXAMPLE} images. "
+            f"Found {total_number_of_images} images.",
+            line_number=idx + 1,
+            error_source="key_value",
+        )
 
     _check_conversation_roles(require_assistant_role, assistant_role_exists, idx)
@@ -347,12 +523,7 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
             error_source="key_value",
         )
 
-    if not isinstance(example[key][0]["content"], str):
-        raise InvalidFileFormatError(
-            message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
+    _check_message_content(example[key][0]["content"], role="assistant", idx=idx)
 
 
 def _check_utf8(file: Path) -> Dict[str, Any]:
@@ -454,8 +625,7 @@ def _check_csv(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
         report_dict["load_csv"] = False
         if idx < 0:
             report_dict["message"] = (
-                "Unable to decode file. "
-                "File may be empty or in an unsupported format. "
+                "Unable to decode file. File may be empty or in an unsupported format. "
             )
         else:
             report_dict["message"] = (
@@ -542,13 +712,10 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
                 )
             else:
                 for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
-                    if not isinstance(json_line[column], str):
-                        raise InvalidFileFormatError(
-                            message=f'Invalid value type for "{column}" key on line {idx + 1}. '
-                            f"Expected string. Found {type(json_line[column])}.",
-                            line_number=idx + 1,
-                            error_source="key_value",
-                        )
+                    role = "assistant" if column in {"completion"} else "user"
+                    _check_message_content(
+                        json_line[column], role=role, idx=idx
+                    )
 
                 if dataset_format is None:
                     dataset_format = current_format
@@ -578,8 +745,7 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
         report_dict["load_json"] = False
         if idx < 0:
             report_dict["message"] = (
-                "Unable to decode file. "
-                "File may be empty or in an unsupported format. "
+                "Unable to decode file. File may be empty or in an unsupported format. "
            )
         else:
            report_dict["message"] = (
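To make the accepted format concrete, here is a sketch (ours, not from the diff) that builds one multimodal conversation, encodes a local image into the `data:image/...;base64,` form `_check_message_content` expects, and runs the shipped validator over the resulting JSONL. The image path is a placeholder; note that only user messages may carry images.

```python
import base64
import json
from pathlib import Path

from together.utils.files import check_file


def to_data_url(image_path: str) -> str:
    # Helper of ours, not part of the SDK: wrap a local JPEG in the
    # data-URL form the validator expects.
    encoded = base64.b64encode(Path(image_path).read_bytes()).decode()
    return f"data:image/jpeg;base64,{encoded}"


sample = {
    "messages": [
        {
            "role": "user",  # only user messages may contain images
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": to_data_url("cat.jpg")}},
            ],
        },
        # A bare string here would trip the mixed-format check: all messages
        # in one sample must be multimodal (list form) or all text-only.
        {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
    ]
}

out = Path("train.jsonl")
out.write_text(json.dumps(sample) + "\n")
print(check_file(out))  # report dict: is_check_passed, num_samples, ...
```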
diff --git a/tests/unit/test_files_checks.py b/tests/unit/test_files_checks.py
index 728452c..92c23c3 100644
--- a/tests/unit/test_files_checks.py
+++ b/tests/unit/test_files_checks.py
@@ -1,10 +1,8 @@
-import json
-import pytest
 import csv
+import json
 from pathlib import Path
 
-from together.constants import MIN_SAMPLES
-from together.utils.files import check_file, FilePurpose
+from together.utils.files import FilePurpose, check_file
 
 
 def test_check_jsonl_valid_general(tmp_path: Path):
@@ -43,6 +41,39 @@ def test_check_jsonl_valid_instruction(tmp_path: Path):
     assert report["has_min_samples"]
 
 
+def test_check_jsonl_valid_instruction_multimodal(tmp_path: Path):
+    file = tmp_path / "valid_instruction_multimodal.jsonl"
+    content = [
+        {
+            "prompt": [
+                {
+                    "type": "text",
+                    "text": "What's the difference between these two images?",
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "data:image/jpeg;base64,..."},
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "data:image/jpeg;base64,..."},
+                },
+            ],
+            "completion": "The first image is a cat, the second image is a dog.",
+        },
+    ]
+
+    with file.open("w") as f:
+        f.write("\n".join(json.dumps(item) for item in content))
+
+    report = check_file(file)
+
+    assert report["is_check_passed"]
+    assert report["utf8"]
+    assert report["num_samples"] == len(content)
+    assert report["has_min_samples"]
+
+
 def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path):
     # Create a valid JSONL file with conversational format and 1 user-assistant turn pair
     file = tmp_path / "valid_conversational_single_turn.jsonl"
@@ -122,6 +153,48 @@ def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
     assert report["has_min_samples"]
 
 
+def test_check_jsonl_valid_conversational_multimodal_single_turn(tmp_path: Path):
+    file = tmp_path / "valid_conversational_multimodal_single_turn.jsonl"
+    content = [
+        {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "What's the difference between these two images?",
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": "data:image/jpeg;base64,..."},
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": "data:image/jpeg;base64,..."},
+                        },
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": "Hi there!"}],
+                },
+            ]
+        },
+    ]
+
+    with file.open("w") as f:
+        f.write("\n".join(json.dumps(item) for item in content))
+
+    report = check_file(file)
+
+    assert report["is_check_passed"]
+    assert report["utf8"]
+    assert report["num_samples"] == len(content)
+    assert report["has_min_samples"]
+
+
 def test_check_jsonl_empty_file(tmp_path: Path):
     # Create an empty JSONL file
     file = tmp_path / "empty.jsonl"
@@ -414,6 +487,37 @@ def test_check_jsonl_invalid_weight(tmp_path: Path):
     assert "Weight must be either 0 or 1" in report["message"]
 
 
+def test_check_jsonl_invalid_multimodal_content(tmp_path: Path):
+    file = tmp_path / "invalid_multimodal_content.jsonl"
+    content = [
+        {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Hello"},
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": "not-a-base64-data-url"},
+                        },
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": "Hi there!"}],
+                },
+            ]
+        }
+    ]
+
+    with file.open("w") as f:
+        f.write("\n".join(json.dumps(item) for item in content))
+
+    report = check_file(file)
+    assert not report["is_check_passed"]
+    assert "field must be either a JPEG, PNG or WEBP" in report["message"]
+
+
 def test_check_csv_valid_general(tmp_path: Path):
     # Create a valid CSV file
     file = tmp_path / "valid.csv"
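One rule the new tests don't exercise directly: a message that carries images must have weight 0 or no weight at all (`validate_messages` raises otherwise). A hedged sketch of the failing case; the file path is arbitrary and the base64 payload is elided:

```python
import json
from pathlib import Path

from together.utils.files import check_file

# A user message with an image AND weight=1 should be rejected by
# validate_messages with "Messages with images cannot have non-zero weights."
bad = {
    "messages": [
        {
            "role": "user",
            "weight": 1,
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,..."},
                },
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]},
    ]
}

path = Path("bad_weights.jsonl")
path.write_text(json.dumps(bad) + "\n")
report = check_file(path)
assert not report["is_check_passed"]
assert "cannot have non-zero weights" in report["message"]
```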
/ "invalid_multimodal_content.jsonl" + content = [ + { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": {"url": ""}, + }, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "Hi there!"}], + }, + ] + } + ] + + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + assert not report["is_check_passed"] + assert "field must be either a JPEG, PNG or WEBP" in report["message"] + + def test_check_csv_valid_general(tmp_path: Path): # Create a valid CSV file file = tmp_path / "valid.csv" diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py index 6020a0c..73ab149 100644 --- a/tests/unit/test_finetune_resources.py +++ b/tests/unit/test_finetune_resources.py @@ -1,5 +1,6 @@ +from unittest.mock import MagicMock, Mock + import pytest -from unittest.mock import MagicMock, Mock, patch from together.client import Together from together.resources.finetune import create_finetune_request