diff --git a/customization/bedrock-finetuning/understanding/dataset_validation/README.md b/customization/bedrock-finetuning/understanding/dataset_validation/README.md index 64464b45..bfbd8f4e 100644 --- a/customization/bedrock-finetuning/understanding/dataset_validation/README.md +++ b/customization/bedrock-finetuning/understanding/dataset_validation/README.md @@ -16,17 +16,34 @@ python3 nova_ft_dataset_validator.py -i -m -t -m -t 1: raise ValueError("Only one video is allowed per sample") @@ -317,6 +558,7 @@ class ConverseDatasetSample(BaseModel): schemaVersion: Optional[str] = None system: Optional[List[SystemMessage]] = None + toolConfig: Optional[ToolConfig] = None messages: List[Message] @field_validator("messages") @@ -325,6 +567,13 @@ def validate_data_sample_rules(cls, messages): check_roles_order(messages) return messages + @model_validator(mode="after") + def validate_tool_use_rules(cls, values): + """Validates tool use rules across the conversation.""" + if values.toolConfig is not None: + validate_tool_use_in_conversation(values.messages, values.toolConfig) + return values + MessageOrCandidate = Annotated[ Union[ @@ -358,6 +607,154 @@ def validate_data_sample_rules(cls, messages: List[MessageOrCandidate]): return messages +# RFT (Reinforcement Fine-Tuning) Models +class RFTFunctionParameters(BaseModel): + """Represents parameters for an RFT function.""" + + type: str + properties: dict + required: Optional[List[str]] = None + + @field_validator("type") + def validate_type(cls, param_type): + if param_type != "object": + raise ValueError("Invalid parameters type, must be 'object'") + return param_type + + @field_validator("properties") + def validate_properties(cls, properties): + if not isinstance(properties, dict): + raise ValueError("Invalid properties, must be a dictionary") + return properties + + +class RFTFunction(BaseModel): + """Represents an RFT function specification.""" + + name: str + description: str + parameters: RFTFunctionParameters + + @field_validator("name") + def validate_name(cls, name): + if not name or not name.strip(): + raise ValueError("Invalid function name, cannot be empty") + return name + + @field_validator("description") + def validate_description(cls, description): + if not description or not description.strip(): + raise ValueError("Invalid function description, cannot be empty") + return description + + +class RFTTool(BaseModel): + """Represents an RFT tool.""" + + type: str + function: RFTFunction + + @field_validator("type") + def validate_type(cls, tool_type): + if tool_type != "function": + raise ValueError("Invalid tool type, must be 'function'") + return tool_type + + +class RFTMessage(BaseModel): + """Represents a simple RFT message with optional role and content per RFT specification.""" + + role: Optional[str] = None + content: Optional[str] = None + + @field_validator("role") + def validate_role(cls, role): + # role is optional, but if provided must be valid + if role is not None: + valid_roles = ["system", "user", "assistant"] + if role.lower() not in valid_roles: + raise ValueError(f"Invalid role, must be one of {valid_roles}") + return role + + @field_validator("content") + def validate_content(cls, content): + # content is optional, but if provided must not be empty + if content is not None: + if not content.strip(): + raise ValueError("Invalid content, if provided cannot be empty") + validate_invalid_tokens(content) + return content + + +class RFTDatasetSample(BaseModel): + """Represents an RFT dataset sample with required messages and tools, optional id and reference answer. + + Field requirements per RFT specification: + - id: Optional - Unique identifier for tracking + - messages: Required - Array of message objects + - messages[].role: Optional - "system", "user", or "assistant" + - messages[].content: Optional - Text content of the message + - tools: Required - Tool specifications available to the model + - reference_answer: Optional - Expected output (string or object) + """ + + id: Optional[str] = None + messages: List[RFTMessage] + tools: List[RFTTool] + reference_answer: Optional[Union[str, dict]] = None + + @field_validator("id") + def validate_id(cls, sample_id): + # id is optional, but if provided must not be empty + if sample_id is not None and (not sample_id or not sample_id.strip()): + raise ValueError("Invalid id, if provided cannot be empty") + return sample_id + + @field_validator("messages") + def validate_messages(cls, messages): + if not messages: + raise ValueError("Invalid messages, cannot be empty") + + # Check that messages have valid role sequence if roles are provided + has_system = any(msg.role.lower() == "system" for msg in messages if msg.role) + if has_system: + # If there's a system message, it should be first + first_role = messages[0].role.lower() if messages[0].role else None + if first_role != "system": + raise ValueError("Invalid messages, system message must be first if present") + + # Check that there's at least one user message + if not any(msg.role and msg.role.lower() == "user" for msg in messages): + raise ValueError("Invalid messages, must have at least one user message") + + return messages + + @field_validator("reference_answer") + def validate_reference_answer(cls, reference_answer): + # reference_answer is optional, but if provided must not be empty + if reference_answer is not None: + if isinstance(reference_answer, str): + if not reference_answer.strip(): + raise ValueError("Invalid reference_answer, if provided as string cannot be empty") + elif isinstance(reference_answer, dict): + if not reference_answer: + raise ValueError("Invalid reference_answer, if provided as dict cannot be empty") + else: + raise ValueError("Invalid reference_answer, must be a string or dictionary") + return reference_answer + + @field_validator("tools") + def validate_tools(cls, tools): + # tools is required and cannot be empty + if not tools: + raise ValueError("Invalid tools, tools field is required and cannot be empty list") + # Check for duplicate tool names + tool_names = [tool.function.name for tool in tools] + if len(tool_names) != len(set(tool_names)): + raise ValueError("Invalid tools, duplicate tool names found") + return tools + + def validate_converse_dataset(args): """Validates the entire conversation dataset against Nova format requirements.""" try: @@ -368,7 +765,7 @@ def validate_converse_dataset(args): # Only validate data record bounds for Bedrock platform if args.platform.lower() == "bedrock": print(f"Platform: {args.platform} - Validating data record bounds") - validate_data_record_bounds(num_samples, args.model_name) + validate_data_record_bounds(num_samples, args.model_name, args.task_type) else: print(f"Platform: {args.platform} - Skipping data record bounds validation") except Exception as e: @@ -381,9 +778,19 @@ def validate_converse_dataset(args): print(f"Validating samples for model: {args.model_name}") task_type = TaskType(str(args.task_type).upper()) print(f"Using task: {task_type}") + + # RFT is only supported on lite-2.0 + if task_type is TaskType.RFT and args.model_name != "lite-2.0": + raise NovaClientError( + f"RFT task type is only supported on lite-2.0 model. " + f"Current model: {args.model_name}. Please use -m lite-2.0 for RFT tasks." + ) + for i, sample in enumerate(samples): try: - if task_type is TaskType.DPO: + if task_type is TaskType.RFT: + RFTDatasetSample.model_validate(sample) + elif task_type is TaskType.DPO: ConverseDatasetSampleWithCandidates.model_validate( sample, context={"model_name": args.model_name} ) @@ -450,91 +857,3 @@ def check_roles_order(messages): for i, message in enumerate(messages): if i % 2 == 0 and message.role != ConverseRoles.USER: raise ValueError( - f"Invalid messages, expected {ConverseRoles.USER} role but found {message.role}" - ) - elif i % 2 == 1 and message.role != ConverseRoles.ASSISTANT: - raise ValueError( - f"Invalid messages, expected {ConverseRoles.ASSISTANT} role but found {message.role}" - ) - - # When turns are odd - if messages[-1].role != ConverseRoles.ASSISTANT: - raise ValueError(f"Invalid messages, last turn should have {ConverseRoles.ASSISTANT} role") - - -def is_valid_path(file_path): - """Validates that file path contains only alphanumeric characters, underscores, hyphens, slashes, and dots.""" - pattern = r"^[\w\-/\.]+$" - if not re.match(pattern, file_path): - raise ValueError( - "Invalid characters in 'uri'. Only alphanumeric, underscores, hyphens, slashes, and dots are allowed" - ) - - -def get_data_record_bounds(model_name: str): - """Returns the minimum and maximum number of samples allowed for a given model.""" - return MODEL_TO_NUM_SAMPLES_MAP[model_name] - - -def validate_data_record_bounds(num_samples: int, model_name: str): - """Validates that the number of samples is within allowed bounds for the model.""" - data_record_bounds = get_data_record_bounds(model_name) - if num_samples < data_record_bounds[0] or num_samples > data_record_bounds[1]: - raise NovaClientError( - ( - f"Numer of samples {num_samples} out of bounds between {data_record_bounds[0]} and {data_record_bounds[1]} " - f"for {model_name}" - ) - ) - - -def validate_role_name(role: str): - if role.lower() not in CONVERSE_ROLES_WITHOUT_SYSTEM: - raise ValueError( - f"Invalid value for role, valid values are {CONVERSE_ROLES_WITHOUT_SYSTEM}" - ) - return role - - -if __name__ == "__main__": - description = """ - This script is for validating Nova converse format. - Takes input a jsonl file with samples in the Nova converse format: - https://docs.aws.amazon.com/nova/latest/userguide/customize-fine-tune-prepare.html - """ - parser = argparse.ArgumentParser( - description=description, formatter_class=argparse.RawTextHelpFormatter - ) - parser.add_argument( - "-i", - "--input_file", - type=str, - required=True, - help="The input jsonl file in Nova converse format", - ) - parser.add_argument( - "-m", - "--model_name", - type=str, - choices=["micro", "lite", "pro"], - required=True, - help="Choose a model from: micro, lite, pro", - ) - parser.add_argument( - "-t", - "--task_type", - type=str, - choices=["sft", "dpo", "SFT", "DPO"], - required=True, - help="Choose a task type: sft, dpo", - ) - parser.add_argument( - "-p", - "--platform", - type=str, - choices=["bedrock", "sagemaker"], - default="bedrock", - help="Choose a platform: bedrock, sagemaker (default: bedrock)", - ) - args = parser.parse_args() - validate_converse_dataset(args)