From d27b7c76e659521d57b12e1734376aa293e70ab5 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 17:52:06 +0000 Subject: [PATCH 01/17] a very simple example based on rubric grading --- tinker_cookbook/recipes/rubric/debug_env.py | 44 +++ tinker_cookbook/recipes/rubric/env.py | 321 ++++++++++++++++++ .../recipes/rubric/generate_data.py | 41 +++ tinker_cookbook/recipes/rubric/train.py | 148 ++++++++ 4 files changed, 554 insertions(+) create mode 100644 tinker_cookbook/recipes/rubric/debug_env.py create mode 100644 tinker_cookbook/recipes/rubric/env.py create mode 100644 tinker_cookbook/recipes/rubric/generate_data.py create mode 100644 tinker_cookbook/recipes/rubric/train.py diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py new file mode 100644 index 00000000..99a9955e --- /dev/null +++ b/tinker_cookbook/recipes/rubric/debug_env.py @@ -0,0 +1,44 @@ +from tinker_cookbook import model_info +from tinker_cookbook.recipes.rubric.env import RubricGradedEnv, RubricBasedDatapoint, Rubric +from tinker_cookbook.completers import TinkerMessageCompleter, TinkerTokenCompleter +from tinker_cookbook.renderers import get_renderer +from tinker_cookbook.tokenizer_utils import get_tokenizer +import tinker +from tinker_cookbook.rl.rollouts import do_single_rollout +import asyncio + +async def main(): + datapoint = RubricBasedDatapoint( + convo=[ + {"role": "user", "content": "What is 4 + 5?"}, + {"role": "assistant", "content": "9"}, + {"role": "user", "content": "What is 125 + 311?"}, + ], + rubric_items=[Rubric(rubric_str="Does the chatbot correctly gets the answer 436?"), Rubric(rubric_str="Does the chatbot provide an explanation for the answer?")] + ) + policy_name = "meta-llama/Llama-3.1-8B-Instruct" + grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + service_client = tinker.ServiceClient() + policy = TinkerTokenCompleter( + sampling_client=service_client.create_sampling_client(base_model=policy_name), + max_tokens=64, + ) + policy_renderer = get_renderer(model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name)) + grader = TinkerMessageCompleter( + sampling_client=service_client.create_sampling_client(base_model=grader_name), + renderer=get_renderer(model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name)), + max_tokens=64, + ) + + env = RubricGradedEnv( + renderer=policy_renderer, + datapoint=datapoint, + grader_llm=grader, + debug=True, + ) + + await do_single_rollout(policy, env) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py new file mode 100644 index 00000000..f982687b --- /dev/null +++ b/tinker_cookbook/recipes/rubric/env.py @@ -0,0 +1,321 @@ +from tinker_cookbook.rl.types import ( + Action, + Env, + StepResult, + EnvGroupBuilder, + RLDataset, + RLDatasetBuilder, +) +from tinker_cookbook.renderers import Message, Renderer, Role +from typing import TypeAlias +from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter +from tinker.types import ModelInput +from dataclasses import dataclass +from typing import Sequence +import re +import json +import chz +import tinker +from tinker_cookbook.tokenizer_utils import get_tokenizer +from tinker_cookbook.renderers import get_renderer +import logging +import asyncio + +import chz +from tinker_cookbook import model_info + +logger = logging.getLogger(__name__) + +Conversation: TypeAlias 
= list[Message] + +# ANSI color codes +BLUE = "\033[94m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +MAGENTA = "\033[95m" +RESET = "\033[0m" + + +@dataclass +class Rubric: + """ + A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response. + """ + rubric_str: str + extraction_regex: str = r"(.*)" + grader_output_format_instruction: str = "Please output your score between 0 and 1 wrapped in ... " + + def __convert_role(self, role: Role) -> str: + return "Human" if role in ("user", "system") else "Chatbot" + + def _flatten_convo(self, convo: Conversation) -> str: + """ + Convert the whole conversation (user's turns + assistant's turns) into a single string. E.g. + \n\nHuman: .... + \n\nChatbot: ... + \n\nHuman: ... + \n\nChatbot: ... + """ + return "\n\n".join([f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo]) + + def get_grader_prompt(self, convo: Conversation) -> Conversation: + """ + Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric. + """ + + prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric." + + prompt += f"Here is the conversation: \n\n{self._flatten_convo(convo)} \n\n\n\nHere is the rubric: \n{self.rubric_str}\n\n" + prompt += f"Please grade the conversation based on the rubric. {self.grader_output_format_instruction}" + return [ + { + "role": "user", + "content": prompt, + } + ] + + def extract_score(self, response: str) -> float: + match = re.search(self.extraction_regex, response, re.DOTALL) + if match is not None: + try: + return float(match.group(1)) + except ValueError: + print(f"Warning: Failed to extract score from grader response: {response}") + return 0.0 + else: + print(f"Warning: Failed to extract score from grader response: {response}") + return 0.0 + + def to_dict(self) -> dict[str, str]: + return { + "rubric_str": self.rubric_str, + "extraction_regex": self.extraction_regex, + "grader_output_format_instruction": self.grader_output_format_instruction, + } + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @staticmethod + def from_dict(d: dict[str, str]) -> "Rubric": + return Rubric( + rubric_str=d["rubric_str"], + extraction_regex=d["extraction_regex"], + grader_output_format_instruction=d["grader_output_format_instruction"], + ) + + @staticmethod + def from_json(json_str: str) -> "Rubric": + return Rubric.from_dict(json.loads(json_str)) + + + +@dataclass(frozen=True) +class RubricBasedDatapoint: + """ + A rubric-based datapoint contains a conversation and a rubric. + In this task, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric. 
+ """ + convo: Conversation + rubric_items: Sequence[Rubric] + + def to_json(self) -> str: + return json.dumps({ + "convo": self.convo, + "rubric_items": [rubric.to_json() for rubric in self.rubric_items], + }) + + @staticmethod + def from_json(json_str: str) -> "RubricBasedDatapoint": + d = json.loads(json_str) + return RubricBasedDatapoint( + convo=d["convo"], + rubric_items=[Rubric.from_json(rubric) for rubric in d["rubric_items"]], + ) + + +class RubricGradedEnv(Env): + + def __init__( + self, + renderer: Renderer, + datapoint: RubricBasedDatapoint, + grader_llm: MessageCompleter, + debug: bool = False, + ): + """ + Initialize the RubricGradedEnv. In this environment, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric. + """ + self.renderer = renderer + self.datapoint = datapoint + self.grader_llm = grader_llm + self.debug = debug + + @property + def rubric_items(self) -> Sequence[Rubric]: + return self.datapoint.rubric_items + + @property + def convo(self) -> Conversation: + return self.datapoint.convo + + @property + def stop_condition(self) -> StopCondition: + return self.renderer.get_stop_sequences() + + async def initial_observation(self) -> tuple[ModelInput, StopCondition]: + return self.renderer.build_generation_prompt(self.convo), self.stop_condition + + async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> float: + # this is the conversation for the grader + # effectively it's just one user turn + grader_prompt = rubric.get_grader_prompt(convo) + + # obtain the response from the grader and convert it to a score + grader_response = await self.grader_llm(grader_prompt) + grader_response_content = grader_response["content"] + assert isinstance(grader_response_content, str), "Grader response content must be a string" + score = rubric.extract_score(grader_response_content) + if self.debug: + print(f"{YELLOW}{'='*80}") + print(f"DEBUG: First Turn of Grader Prompt") + print(f"{'='*80}{RESET}") + print(f"{YELLOW}{grader_prompt[0]['content']}{RESET}\n") + + print(f"{MAGENTA}{'='*80}") + print(f"DEBUG: Score") + print(f"{'='*80}{RESET}") + print(f"{MAGENTA}Score: {score}{RESET}\n") + return score + + async def step(self, action: Action) -> StepResult: + # obtain the policy action message + (policy_action_message, _parse_success) = self.renderer.parse_response(action) + + if self.debug: + + print(f"\n{BLUE}{'='*80}") + print(f"DEBUG: Original Conversation (self.convo)") + print(f"{'='*80}{RESET}") + print(f"{BLUE}{json.dumps(self.convo, indent=2)}{RESET}\n") + + print(f"{GREEN}{'='*80}") + print(f"DEBUG: Policy Action Message") + print(f"{'='*80}{RESET}") + print(f"{GREEN}{json.dumps(policy_action_message, indent=2)}{RESET}\n") + # this shows the full back-and-forth conversation to the grader + convo = self.convo + [policy_action_message] + + scores = await asyncio.gather(*[self._grade_with_rubric(convo, rubric_item) for rubric_item in self.rubric_items]) + avg_score = sum(scores) / len(scores) + + return StepResult( + reward=avg_score, + episode_done=True, + next_observation=self.renderer.build_generation_prompt(convo), + next_stop_condition=self.stop_condition, + ) + + +@dataclass(frozen=True) +class RubricGradedEnvGroupBuilder(EnvGroupBuilder): + renderer: Renderer + datapoint: RubricBasedDatapoint + grader_llm: MessageCompleter + group_size: int + + async def make_envs(self) -> Sequence[RubricGradedEnv]: + return [ + RubricGradedEnv( + renderer=self.renderer, + 
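+                # Each env in the group is built from the same datapoint and grader,
+                # so the group's rollouts differ only in the sampled response.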
datapoint=self.datapoint, + grader_llm=self.grader_llm, + ) for _ in range(self.group_size) + ] + + +@dataclass(frozen=True) +class RubricGradedDataset(RLDataset): + renderer: Renderer + batch_size: int + group_size: int + datapoints: Sequence[RubricBasedDatapoint] + grader_llm: MessageCompleter + + def get_batch(self, index: int) -> Sequence[RubricGradedEnvGroupBuilder]: + batch = [ + RubricGradedEnvGroupBuilder( + renderer=self.renderer, + datapoint=self.datapoints[index * self.batch_size + i], + grader_llm=self.grader_llm, + group_size=self.group_size, + ) + for i in range(self.batch_size) + ] + return batch + + def __len__(self) -> int: + return len(self.datapoints) // self.batch_size + + +@chz.chz +class RubricGradedDatasetBuilder(RLDatasetBuilder): + renderer_name: str + model_name_for_tokenizer: str + batch_size: int + train_group_size: int + test_group_size: int = 1 + + train_jsonl_path: str + test_jsonl_path: str | None = None + + base_url: str | None = None + grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507" + + def _get_datapoints_from_jsonl(self, jsonl_path: str | None) -> Sequence[RubricBasedDatapoint] | None: + if jsonl_path is None: + return None + datapoints = [] + with open(jsonl_path, "r") as f: + for line in f: + datapoint = RubricBasedDatapoint.from_json(line) + datapoints.append(datapoint) + return datapoints + + def _get_grader_llm(self) -> MessageCompleter: + tokenizer = get_tokenizer(self.grader_llm_name) + renderer_name = model_info.get_recommended_renderer_name(self.grader_llm_name) + renderer = get_renderer(name=renderer_name, tokenizer=tokenizer) + service_client = tinker.ServiceClient(base_url=self.base_url) + sampling_client = service_client.create_sampling_client(base_model=self.grader_llm_name) + return TinkerMessageCompleter( + sampling_client=sampling_client, + renderer=renderer, + max_tokens=2048 + ) + + async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | None]: + train_datapoints = self._get_datapoints_from_jsonl(self.train_jsonl_path) + test_datapoints = self._get_datapoints_from_jsonl(self.test_jsonl_path) + + renderer = get_renderer(name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer)) + + assert train_datapoints is not None, "Train datapoints are required" + train_dataset = RubricGradedDataset( + renderer=renderer, + batch_size=self.batch_size, + group_size=self.train_group_size, + datapoints=train_datapoints, + grader_llm=self._get_grader_llm(), + ) + if test_datapoints is None: + return train_dataset, None + else: + test_dataset = RubricGradedDataset( + renderer=renderer, + batch_size=len(test_datapoints), + group_size=self.test_group_size, + datapoints=test_datapoints, + grader_llm=self._get_grader_llm(), + ) + return train_dataset, test_dataset \ No newline at end of file diff --git a/tinker_cookbook/recipes/rubric/generate_data.py b/tinker_cookbook/recipes/rubric/generate_data.py new file mode 100644 index 00000000..0eba5030 --- /dev/null +++ b/tinker_cookbook/recipes/rubric/generate_data.py @@ -0,0 +1,41 @@ +from tinker_cookbook.recipes.rubric.env import RubricBasedDatapoint, Rubric +import random +import os + +def generate_one(rng: random.Random) -> RubricBasedDatapoint: + x, y = rng.randint(0, 1000), rng.randint(0, 1000) + return RubricBasedDatapoint( + convo=[ + {"role": "user", "content": "What is 4 + 5?"}, + {"role": "assistant", "content": "9"}, + {"role": "user", "content": f"What is {x} + {y}?"}, + ], rubric_items=[Rubric(rubric_str=f"Does the chatbot correctly gets the answer 
{x + y}?")] + ) + + +def generate_dataset(num_train: int, num_test: int, seed: int, write_dir: str = "tinker_cookbook/example_data/") -> tuple[str, str]: + random.seed(seed) + rng = random.Random(seed) + total_datapoints = num_train + num_test + datapoints = [generate_one(rng) for _ in range(total_datapoints)] + + train_datapoints = datapoints[:num_train] + train_jsonl_path = os.path.join(write_dir, "example_rubric_train.jsonl") + with open(train_jsonl_path, "w") as f: + for datapoint in train_datapoints: + f.write(datapoint.to_json() + "\n") + print(f"Generated {len(train_datapoints)} train datapoints in {train_jsonl_path}") + + test_datapoints = datapoints[num_train:] + test_jsonl_path = os.path.join(write_dir, "example_rubric_test.jsonl") + with open(test_jsonl_path, "w") as f: + for datapoint in test_datapoints: + f.write(datapoint.to_json() + "\n") + print(f"Generated {len(test_datapoints)} test datapoints in {test_jsonl_path}") + + return train_jsonl_path, test_jsonl_path + +if __name__ == "__main__": + train_jsonl_path, test_jsonl_path = generate_dataset(num_train=10000, num_test=1000, seed=42) + print(f"Generated train dataset in {train_jsonl_path}") + print(f"Generated test dataset in {test_jsonl_path}") \ No newline at end of file diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py new file mode 100644 index 00000000..2b6f5714 --- /dev/null +++ b/tinker_cookbook/recipes/rubric/train.py @@ -0,0 +1,148 @@ +import chz +import asyncio +from datetime import datetime + +import chz +from tinker_cookbook import cli_utils, model_info +from tinker_cookbook.rl.train import AsyncConfig, Config, main +from tinker_cookbook.rl.types import RLDatasetBuilder +from tinker.types import LossFnType +from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder + +@chz.chz +class CLIConfig: + """Simple command-line configuration for RL training.""" + + # Model configuration + model_name: str = "meta-llama/Llama-3.1-8B-Instruct" + lora_rank: int = 32 + renderer_name: str | None = None + load_checkpoint_path: str | None = None + + seed: int = 0 # Random seed for data shuffling + + # Training hyperparameters + train_group_size: int = 4 + test_group_size: int = 1 + groups_per_batch: int = 100 + learning_rate: float = 1e-5 + max_tokens: int = 5 + temperature: float = 1.0 + kl_penalty_coef: float = 0.0 + grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507" + train_jsonl_path: str = "tinker_cookbook/example_data/example_rubric_train.jsonl" + test_jsonl_path: str = "tinker_cookbook/example_data/example_rubric_test.jsonl" + + # Number of optimizer steps per training iteration. + # Useful for very large batch sizes. 
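+    # E.g., num_substeps=2 takes two optimizer steps per iteration, each on a
+    # slice of the batch (assumed semantics; see tinker_cookbook.rl.train).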
+ num_substeps: int = 1 + + # Logging configuration + log_path: str | None = None + wandb_project: str | None = None + wandb_name: str | None = None + compute_post_kl: bool = False + + # Evals + eval_every: int = 20 + + # Checkpointing + save_every: int = 20 + + # Service configuration + base_url: str | None = None + + behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask" + + max_steps_off_policy: int | None = None + loss_fn: LossFnType = "importance_sampling" + + +def get_dataset_builder( + batch_size: int, + policy_model_name: str, + renderer_name: str, + grader_llm_name: str, + train_group_size: int, + train_jsonl_path: str, + test_jsonl_path: str | None = None, + test_group_size: int = 1, +) -> RLDatasetBuilder: + return RubricGradedDatasetBuilder( + batch_size=batch_size, + model_name_for_tokenizer=policy_model_name, + renderer_name=renderer_name, + grader_llm_name=grader_llm_name, + train_jsonl_path=train_jsonl_path, + test_jsonl_path=test_jsonl_path, + train_group_size=train_group_size, + test_group_size=test_group_size, + ) + + + +async def cli_main(cli_config: CLIConfig): + """Convert CLI config to full config and run training.""" + + # Get tokenizer for stop sequences + renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name( + cli_config.model_name + ) + model_name = cli_config.model_name.replace("/", "-") + run_name = f"{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}" + # create log path if it doesn't exist + if cli_config.log_path is not None: + log_path = cli_config.log_path + else: + log_path = f"/tmp/tinker-examples/math_rl/{run_name}" + + if cli_config.wandb_name is not None: + wandb_name = cli_config.wandb_name + else: + wandb_name = run_name + + # Create full config + config = Config( + learning_rate=cli_config.learning_rate, + dataset_builder=get_dataset_builder( + batch_size=cli_config.groups_per_batch, + policy_model_name=cli_config.model_name, + renderer_name=renderer_name, + grader_llm_name=cli_config.grader_llm_name, + train_group_size=cli_config.train_group_size, + train_jsonl_path=cli_config.train_jsonl_path, + test_jsonl_path=cli_config.test_jsonl_path, + test_group_size=cli_config.test_group_size, + ), + model_name=cli_config.model_name, + lora_rank=cli_config.lora_rank, + max_tokens=cli_config.max_tokens, + temperature=cli_config.temperature, + wandb_project=cli_config.wandb_project, + wandb_name=wandb_name, + log_path=log_path, + base_url=cli_config.base_url, + load_checkpoint_path=cli_config.load_checkpoint_path, + compute_post_kl=cli_config.compute_post_kl, + kl_penalty_coef=cli_config.kl_penalty_coef, + num_substeps=cli_config.num_substeps, + eval_every=cli_config.eval_every, + save_every=cli_config.save_every, + async_config=AsyncConfig( + max_steps_off_policy=cli_config.max_steps_off_policy, + groups_per_batch=cli_config.groups_per_batch, + ) + if cli_config.max_steps_off_policy is not None + else None, + loss_fn=cli_config.loss_fn, + ) + + cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists) + + # Run training + await main(config) + + +if __name__ == "__main__": + cli_config = chz.entrypoint(CLIConfig) + asyncio.run(cli_main(cli_config)) \ No newline at end of file From aa0c6de3c6f3c159d13bbc87baef44003330e11c Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 18:16:50 +0000 Subject: 
[PATCH 02/17] b --- tinker_cookbook/recipes/rubric/debug_env.py | 2 +- tinker_cookbook/recipes/rubric/env.py | 11 ++++------- tinker_cookbook/recipes/rubric/train.py | 2 -- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py index 99a9955e..1b48c082 100644 --- a/tinker_cookbook/recipes/rubric/debug_env.py +++ b/tinker_cookbook/recipes/rubric/debug_env.py @@ -14,7 +14,7 @@ async def main(): {"role": "assistant", "content": "9"}, {"role": "user", "content": "What is 125 + 311?"}, ], - rubric_items=[Rubric(rubric_str="Does the chatbot correctly gets the answer 436?"), Rubric(rubric_str="Does the chatbot provide an explanation for the answer?")] + rubric_items=[Rubric(rubric_str="Does the chatbot correctly get the answer 436?"), Rubric(rubric_str="Does the chatbot provide an answer without saying anything else?")] ) policy_name = "meta-llama/Llama-3.1-8B-Instruct" grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py index f982687b..fec5021c 100644 --- a/tinker_cookbook/recipes/rubric/env.py +++ b/tinker_cookbook/recipes/rubric/env.py @@ -18,13 +18,9 @@ import tinker from tinker_cookbook.tokenizer_utils import get_tokenizer from tinker_cookbook.renderers import get_renderer -import logging import asyncio - -import chz from tinker_cookbook import model_info -logger = logging.getLogger(__name__) Conversation: TypeAlias = list[Message] @@ -122,7 +118,7 @@ class RubricBasedDatapoint: def to_json(self) -> str: return json.dumps({ "convo": self.convo, - "rubric_items": [rubric.to_json() for rubric in self.rubric_items], + "rubric_items": [rubric.to_dict() for rubric in self.rubric_items], }) @staticmethod @@ -130,7 +126,7 @@ def from_json(json_str: str) -> "RubricBasedDatapoint": d = json.loads(json_str) return RubricBasedDatapoint( convo=d["convo"], - rubric_items=[Rubric.from_json(rubric) for rubric in d["rubric_items"]], + rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]], ) @@ -185,7 +181,8 @@ async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> float print(f"{MAGENTA}{'='*80}") print(f"DEBUG: Score") print(f"{'='*80}{RESET}") - print(f"{MAGENTA}Score: {score}{RESET}\n") + print(f"{MAGENTA}Grader Response: {grader_response_content}{RESET}\n") + print(f"{MAGENTA}Extracted Score: {score}{RESET}\n") return score async def step(self, action: Action) -> StepResult: diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py index 2b6f5714..6a91c566 100644 --- a/tinker_cookbook/recipes/rubric/train.py +++ b/tinker_cookbook/recipes/rubric/train.py @@ -1,8 +1,6 @@ import chz import asyncio from datetime import datetime - -import chz from tinker_cookbook import cli_utils, model_info from tinker_cookbook.rl.train import AsyncConfig, Config, main from tinker_cookbook.rl.types import RLDatasetBuilder From 40e53745801bed7a2be61771d8878ad72853bac5 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 18:18:36 +0000 Subject: [PATCH 03/17] b --- tinker_cookbook/recipes/rubric/debug_env.py | 16 +++- tinker_cookbook/recipes/rubric/env.py | 86 +++++++++++-------- .../recipes/rubric/generate_data.py | 11 ++- tinker_cookbook/recipes/rubric/train.py | 4 +- 4 files changed, 70 insertions(+), 47 deletions(-) diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py index 1b48c082..9068dfce 100644 --- 
a/tinker_cookbook/recipes/rubric/debug_env.py +++ b/tinker_cookbook/recipes/rubric/debug_env.py @@ -7,6 +7,7 @@ from tinker_cookbook.rl.rollouts import do_single_rollout import asyncio + async def main(): datapoint = RubricBasedDatapoint( convo=[ @@ -14,7 +15,10 @@ async def main(): {"role": "assistant", "content": "9"}, {"role": "user", "content": "What is 125 + 311?"}, ], - rubric_items=[Rubric(rubric_str="Does the chatbot correctly get the answer 436?"), Rubric(rubric_str="Does the chatbot provide an answer without saying anything else?")] + rubric_items=[ + Rubric(rubric_str="Does the chatbot correctly get the answer 436?"), + Rubric(rubric_str="Does the chatbot provide an answer without saying anything else?"), + ], ) policy_name = "meta-llama/Llama-3.1-8B-Instruct" grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" @@ -23,10 +27,14 @@ async def main(): sampling_client=service_client.create_sampling_client(base_model=policy_name), max_tokens=64, ) - policy_renderer = get_renderer(model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name)) + policy_renderer = get_renderer( + model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name) + ) grader = TinkerMessageCompleter( sampling_client=service_client.create_sampling_client(base_model=grader_name), - renderer=get_renderer(model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name)), + renderer=get_renderer( + model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name) + ), max_tokens=64, ) @@ -41,4 +49,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py index fec5021c..360700a4 100644 --- a/tinker_cookbook/recipes/rubric/env.py +++ b/tinker_cookbook/recipes/rubric/env.py @@ -37,9 +37,12 @@ class Rubric: """ A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response. """ + rubric_str: str extraction_regex: str = r"(.*)" - grader_output_format_instruction: str = "Please output your score between 0 and 1 wrapped in ... " + grader_output_format_instruction: str = ( + "Please output your score between 0 and 1 wrapped in ... " + ) def __convert_role(self, role: Role) -> str: return "Human" if role in ("user", "system") else "Chatbot" @@ -52,11 +55,13 @@ def _flatten_convo(self, convo: Conversation) -> str: \n\nHuman: ... \n\nChatbot: ... """ - return "\n\n".join([f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo]) + return "\n\n".join( + [f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo] + ) def get_grader_prompt(self, convo: Conversation) -> Conversation: """ - Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric. + Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric. """ prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric." 
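[Note on the two methods reformatted above: the grader prompt construction and the score extraction compose into a simple round trip that runs without any model call. A minimal sketch (imports follow the module layout at this point in the series; the grader replies are canned stand-ins, not real model output):

```python
from tinker_cookbook.recipes.rubric.env import Rubric

rubric = Rubric(rubric_str="Does the chatbot correctly get the answer 436?")
convo = [
    {"role": "user", "content": "What is 125 + 311?"},
    {"role": "assistant", "content": "436"},
]

# get_grader_prompt flattens the conversation into Human:/Chatbot: turns and
# appends the rubric plus the output-format instruction, as one user message.
grader_prompt = rubric.get_grader_prompt(convo)
print(grader_prompt[0]["content"])

# extract_score applies extraction_regex to the grader's reply; replies that
# don't parse as a float fall back to 0.0 with a warning instead of raising.
print(rubric.extract_score("0.9"))            # -> 0.9
print(rubric.extract_score("no score here"))  # -> 0.0 (warning printed)
```]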
@@ -99,27 +104,29 @@ def from_dict(d: dict[str, str]) -> "Rubric": extraction_regex=d["extraction_regex"], grader_output_format_instruction=d["grader_output_format_instruction"], ) - + @staticmethod def from_json(json_str: str) -> "Rubric": return Rubric.from_dict(json.loads(json_str)) - @dataclass(frozen=True) class RubricBasedDatapoint: """ - A rubric-based datapoint contains a conversation and a rubric. + A rubric-based datapoint contains a conversation and a rubric. In this task, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric. """ + convo: Conversation rubric_items: Sequence[Rubric] def to_json(self) -> str: - return json.dumps({ - "convo": self.convo, - "rubric_items": [rubric.to_dict() for rubric in self.rubric_items], - }) + return json.dumps( + { + "convo": self.convo, + "rubric_items": [rubric.to_dict() for rubric in self.rubric_items], + } + ) @staticmethod def from_json(json_str: str) -> "RubricBasedDatapoint": @@ -131,7 +138,6 @@ def from_json(json_str: str) -> "RubricBasedDatapoint": class RubricGradedEnv(Env): - def __init__( self, renderer: Renderer, @@ -150,7 +156,7 @@ def __init__( @property def rubric_items(self) -> Sequence[Rubric]: return self.datapoint.rubric_items - + @property def convo(self) -> Conversation: return self.datapoint.convo @@ -158,7 +164,7 @@ def convo(self) -> Conversation: @property def stop_condition(self) -> StopCondition: return self.renderer.get_stop_sequences() - + async def initial_observation(self) -> tuple[ModelInput, StopCondition]: return self.renderer.build_generation_prompt(self.convo), self.stop_condition @@ -173,37 +179,38 @@ async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> float assert isinstance(grader_response_content, str), "Grader response content must be a string" score = rubric.extract_score(grader_response_content) if self.debug: - print(f"{YELLOW}{'='*80}") - print(f"DEBUG: First Turn of Grader Prompt") - print(f"{'='*80}{RESET}") + print(f"{YELLOW}{'=' * 80}") + print("DEBUG: First Turn of Grader Prompt") + print(f"{'=' * 80}{RESET}") print(f"{YELLOW}{grader_prompt[0]['content']}{RESET}\n") - print(f"{MAGENTA}{'='*80}") - print(f"DEBUG: Score") - print(f"{'='*80}{RESET}") + print(f"{MAGENTA}{'=' * 80}") + print("DEBUG: Score") + print(f"{'=' * 80}{RESET}") print(f"{MAGENTA}Grader Response: {grader_response_content}{RESET}\n") print(f"{MAGENTA}Extracted Score: {score}{RESET}\n") return score - + async def step(self, action: Action) -> StepResult: # obtain the policy action message (policy_action_message, _parse_success) = self.renderer.parse_response(action) if self.debug: - - print(f"\n{BLUE}{'='*80}") - print(f"DEBUG: Original Conversation (self.convo)") - print(f"{'='*80}{RESET}") + print(f"\n{BLUE}{'=' * 80}") + print("DEBUG: Original Conversation (self.convo)") + print(f"{'=' * 80}{RESET}") print(f"{BLUE}{json.dumps(self.convo, indent=2)}{RESET}\n") - print(f"{GREEN}{'='*80}") - print(f"DEBUG: Policy Action Message") - print(f"{'='*80}{RESET}") + print(f"{GREEN}{'=' * 80}") + print("DEBUG: Policy Action Message") + print(f"{'=' * 80}{RESET}") print(f"{GREEN}{json.dumps(policy_action_message, indent=2)}{RESET}\n") - # this shows the full back-and-forth conversation to the grader + # this shows the full back-and-forth conversation to the grader convo = self.convo + [policy_action_message] - scores = await asyncio.gather(*[self._grade_with_rubric(convo, rubric_item) for rubric_item in self.rubric_items]) + scores = 
await asyncio.gather( + *[self._grade_with_rubric(convo, rubric_item) for rubric_item in self.rubric_items] + ) avg_score = sum(scores) / len(scores) return StepResult( @@ -227,7 +234,8 @@ async def make_envs(self) -> Sequence[RubricGradedEnv]: renderer=self.renderer, datapoint=self.datapoint, grader_llm=self.grader_llm, - ) for _ in range(self.group_size) + ) + for _ in range(self.group_size) ] @@ -269,7 +277,9 @@ class RubricGradedDatasetBuilder(RLDatasetBuilder): base_url: str | None = None grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507" - def _get_datapoints_from_jsonl(self, jsonl_path: str | None) -> Sequence[RubricBasedDatapoint] | None: + def _get_datapoints_from_jsonl( + self, jsonl_path: str | None + ) -> Sequence[RubricBasedDatapoint] | None: if jsonl_path is None: return None datapoints = [] @@ -278,7 +288,7 @@ def _get_datapoints_from_jsonl(self, jsonl_path: str | None) -> Sequence[RubricB datapoint = RubricBasedDatapoint.from_json(line) datapoints.append(datapoint) return datapoints - + def _get_grader_llm(self) -> MessageCompleter: tokenizer = get_tokenizer(self.grader_llm_name) renderer_name = model_info.get_recommended_renderer_name(self.grader_llm_name) @@ -286,16 +296,16 @@ def _get_grader_llm(self) -> MessageCompleter: service_client = tinker.ServiceClient(base_url=self.base_url) sampling_client = service_client.create_sampling_client(base_model=self.grader_llm_name) return TinkerMessageCompleter( - sampling_client=sampling_client, - renderer=renderer, - max_tokens=2048 + sampling_client=sampling_client, renderer=renderer, max_tokens=2048 ) - + async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | None]: train_datapoints = self._get_datapoints_from_jsonl(self.train_jsonl_path) test_datapoints = self._get_datapoints_from_jsonl(self.test_jsonl_path) - renderer = get_renderer(name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer)) + renderer = get_renderer( + name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer) + ) assert train_datapoints is not None, "Train datapoints are required" train_dataset = RubricGradedDataset( @@ -315,4 +325,4 @@ async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | Non datapoints=test_datapoints, grader_llm=self._get_grader_llm(), ) - return train_dataset, test_dataset \ No newline at end of file + return train_dataset, test_dataset diff --git a/tinker_cookbook/recipes/rubric/generate_data.py b/tinker_cookbook/recipes/rubric/generate_data.py index 0eba5030..52ccbdc0 100644 --- a/tinker_cookbook/recipes/rubric/generate_data.py +++ b/tinker_cookbook/recipes/rubric/generate_data.py @@ -2,6 +2,7 @@ import random import os + def generate_one(rng: random.Random) -> RubricBasedDatapoint: x, y = rng.randint(0, 1000), rng.randint(0, 1000) return RubricBasedDatapoint( @@ -9,11 +10,14 @@ def generate_one(rng: random.Random) -> RubricBasedDatapoint: {"role": "user", "content": "What is 4 + 5?"}, {"role": "assistant", "content": "9"}, {"role": "user", "content": f"What is {x} + {y}?"}, - ], rubric_items=[Rubric(rubric_str=f"Does the chatbot correctly gets the answer {x + y}?")] + ], + rubric_items=[Rubric(rubric_str=f"Does the chatbot correctly gets the answer {x + y}?")], ) -def generate_dataset(num_train: int, num_test: int, seed: int, write_dir: str = "tinker_cookbook/example_data/") -> tuple[str, str]: +def generate_dataset( + num_train: int, num_test: int, seed: int, write_dir: str = "tinker_cookbook/example_data/" +) -> tuple[str, str]: 
random.seed(seed) rng = random.Random(seed) total_datapoints = num_train + num_test @@ -35,7 +39,8 @@ def generate_dataset(num_train: int, num_test: int, seed: int, write_dir: str = return train_jsonl_path, test_jsonl_path + if __name__ == "__main__": train_jsonl_path, test_jsonl_path = generate_dataset(num_train=10000, num_test=1000, seed=42) print(f"Generated train dataset in {train_jsonl_path}") - print(f"Generated test dataset in {test_jsonl_path}") \ No newline at end of file + print(f"Generated test dataset in {test_jsonl_path}") diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py index 6a91c566..f5a2039b 100644 --- a/tinker_cookbook/recipes/rubric/train.py +++ b/tinker_cookbook/recipes/rubric/train.py @@ -7,6 +7,7 @@ from tinker.types import LossFnType from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder + @chz.chz class CLIConfig: """Simple command-line configuration for RL training.""" @@ -78,7 +79,6 @@ def get_dataset_builder( ) - async def cli_main(cli_config: CLIConfig): """Convert CLI config to full config and run training.""" @@ -143,4 +143,4 @@ async def cli_main(cli_config: CLIConfig): if __name__ == "__main__": cli_config = chz.entrypoint(CLIConfig) - asyncio.run(cli_main(cli_config)) \ No newline at end of file + asyncio.run(cli_main(cli_config)) From c51822e9761e9ed097485fcad694b235954291a0 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 20:20:47 +0000 Subject: [PATCH 04/17] b --- tinker_cookbook/recipes/rubric/data.py | 147 +++++++++++++++++++++++++ tinker_cookbook/recipes/rubric/env.py | 132 ++-------------------- 2 files changed, 154 insertions(+), 125 deletions(-) create mode 100644 tinker_cookbook/recipes/rubric/data.py diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py new file mode 100644 index 00000000..734974e5 --- /dev/null +++ b/tinker_cookbook/recipes/rubric/data.py @@ -0,0 +1,147 @@ +from tinker_cookbook.rl.types import ( + Action, + Env, + StepResult, + EnvGroupBuilder, + RLDataset, + RLDatasetBuilder, +) +from tinker_cookbook.renderers import Message, Renderer, Role +from typing import TypeAlias +from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter +from tinker.types import ModelInput +from dataclasses import dataclass +from typing import Sequence +import re +import json +import chz +import tinker +from tinker_cookbook.tokenizer_utils import get_tokenizer +from tinker_cookbook.renderers import get_renderer +import asyncio +from tinker_cookbook import model_info + +Conversation: TypeAlias = list[Message] + +@dataclass +class Rubric: + """ + A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response. + """ + + rubric_str: str + extraction_regex: str = r"(.*)" + grader_output_format_instruction: str = ( + "Please output your score between 0 and 1 wrapped in ... " + ) + + def __convert_role(self, role: Role) -> str: + return "Human" if role in ("user", "system") else "Chatbot" + + def _flatten_convo(self, convo: Conversation) -> str: + """ + Convert the whole conversation (user's turns + assistant's turns) into a single string. E.g. + \n\nHuman: .... + \n\nChatbot: ... + \n\nHuman: ... + \n\nChatbot: ... 
+ """ + return "\n\n".join( + [f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo] + ) + + def get_grader_prompt(self, convo: Conversation) -> Conversation: + """ + Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric. + """ + + prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric." + + prompt += f"Here is the conversation: \n\n{self._flatten_convo(convo)} \n\n\n\nHere is the rubric: \n{self.rubric_str}\n\n" + prompt += f"Please grade the conversation based on the rubric. {self.grader_output_format_instruction}" + return [ + { + "role": "user", + "content": prompt, + } + ] + + def extract_score(self, response: str) -> float: + match = re.search(self.extraction_regex, response, re.DOTALL) + if match is not None: + try: + return float(match.group(1)) + except ValueError: + print(f"Warning: Failed to extract score from grader response: {response}") + return 0.0 + else: + print(f"Warning: Failed to extract score from grader response: {response}") + return 0.0 + + def to_dict(self) -> dict[str, str]: + return { + "rubric_str": self.rubric_str, + "extraction_regex": self.extraction_regex, + "grader_output_format_instruction": self.grader_output_format_instruction, + } + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @staticmethod + def from_dict(d: dict[str, str]) -> "Rubric": + return Rubric( + rubric_str=d["rubric_str"], + extraction_regex=d["extraction_regex"], + grader_output_format_instruction=d["grader_output_format_instruction"], + ) + + @staticmethod + def from_json(json_str: str) -> "Rubric": + return Rubric.from_dict(json.loads(json_str)) + + +@dataclass(frozen=True) +class RubricBasedDatapoint: + """ + A rubric-based datapoint contains a conversation and a rubric. + In this task, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric. 
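+
+    A serialized datapoint, as produced by to_json() (illustrative values):
+    {"convo": [{"role": "user", "content": "What is 2 + 2?"}],
+     "rubric_items": [{"rubric_str": "...", "extraction_regex": "...",
+                       "grader_output_format_instruction": "..."}]}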
+ """ + + convo: Conversation + rubric_items: Sequence[Rubric] + + def to_json(self) -> str: + return json.dumps( + { + "convo": self.convo, + "rubric_items": [rubric.to_dict() for rubric in self.rubric_items], + } + ) + + @staticmethod + def from_json(json_str: str) -> "RubricBasedDatapoint": + d = json.loads(json_str) + return RubricBasedDatapoint( + convo=d["convo"], + rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]], + ) + + + +class RubricDatapointListBuilder: + + def __call__(self) -> Sequence[RubricBasedDatapoint]: + raise NotImplementedError("Subclass must implement this method") + +@chz.chz +class RubricDatapointListBuilderFromJsonl(RubricDatapointListBuilder): + jsonl_path: str + + def __call__(self) -> Sequence[RubricBasedDatapoint]: + datapoints = [] + with open(self.jsonl_path, "r") as f: + for line in f: + data = json.loads(line) + datapoints.append(RubricBasedDatapoint.from_json(data)) + return datapoints \ No newline at end of file diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py index 360700a4..96d371ff 100644 --- a/tinker_cookbook/recipes/rubric/env.py +++ b/tinker_cookbook/recipes/rubric/env.py @@ -20,9 +20,7 @@ from tinker_cookbook.renderers import get_renderer import asyncio from tinker_cookbook import model_info - - -Conversation: TypeAlias = list[Message] +from tinker_cookbook.recipes.rubric.data import RubricBasedDatapoint, Rubric, Conversation, RubricDatapointListBuilder # ANSI color codes BLUE = "\033[94m" @@ -31,112 +29,6 @@ MAGENTA = "\033[95m" RESET = "\033[0m" - -@dataclass -class Rubric: - """ - A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response. - """ - - rubric_str: str - extraction_regex: str = r"(.*)" - grader_output_format_instruction: str = ( - "Please output your score between 0 and 1 wrapped in ... " - ) - - def __convert_role(self, role: Role) -> str: - return "Human" if role in ("user", "system") else "Chatbot" - - def _flatten_convo(self, convo: Conversation) -> str: - """ - Convert the whole conversation (user's turns + assistant's turns) into a single string. E.g. - \n\nHuman: .... - \n\nChatbot: ... - \n\nHuman: ... - \n\nChatbot: ... - """ - return "\n\n".join( - [f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo] - ) - - def get_grader_prompt(self, convo: Conversation) -> Conversation: - """ - Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric. - """ - - prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric." - - prompt += f"Here is the conversation: \n\n{self._flatten_convo(convo)} \n\n\n\nHere is the rubric: \n{self.rubric_str}\n\n" - prompt += f"Please grade the conversation based on the rubric. 
{self.grader_output_format_instruction}" - return [ - { - "role": "user", - "content": prompt, - } - ] - - def extract_score(self, response: str) -> float: - match = re.search(self.extraction_regex, response, re.DOTALL) - if match is not None: - try: - return float(match.group(1)) - except ValueError: - print(f"Warning: Failed to extract score from grader response: {response}") - return 0.0 - else: - print(f"Warning: Failed to extract score from grader response: {response}") - return 0.0 - - def to_dict(self) -> dict[str, str]: - return { - "rubric_str": self.rubric_str, - "extraction_regex": self.extraction_regex, - "grader_output_format_instruction": self.grader_output_format_instruction, - } - - def to_json(self) -> str: - return json.dumps(self.to_dict()) - - @staticmethod - def from_dict(d: dict[str, str]) -> "Rubric": - return Rubric( - rubric_str=d["rubric_str"], - extraction_regex=d["extraction_regex"], - grader_output_format_instruction=d["grader_output_format_instruction"], - ) - - @staticmethod - def from_json(json_str: str) -> "Rubric": - return Rubric.from_dict(json.loads(json_str)) - - -@dataclass(frozen=True) -class RubricBasedDatapoint: - """ - A rubric-based datapoint contains a conversation and a rubric. - In this task, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric. - """ - - convo: Conversation - rubric_items: Sequence[Rubric] - - def to_json(self) -> str: - return json.dumps( - { - "convo": self.convo, - "rubric_items": [rubric.to_dict() for rubric in self.rubric_items], - } - ) - - @staticmethod - def from_json(json_str: str) -> "RubricBasedDatapoint": - d = json.loads(json_str) - return RubricBasedDatapoint( - convo=d["convo"], - rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]], - ) - - class RubricGradedEnv(Env): def __init__( self, @@ -271,24 +163,12 @@ class RubricGradedDatasetBuilder(RLDatasetBuilder): train_group_size: int test_group_size: int = 1 - train_jsonl_path: str - test_jsonl_path: str | None = None + train_datapoint_list_builder: RubricDatapointListBuilder + test_datapoint_list_builder: RubricDatapointListBuilder | None = None base_url: str | None = None grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507" - def _get_datapoints_from_jsonl( - self, jsonl_path: str | None - ) -> Sequence[RubricBasedDatapoint] | None: - if jsonl_path is None: - return None - datapoints = [] - with open(jsonl_path, "r") as f: - for line in f: - datapoint = RubricBasedDatapoint.from_json(line) - datapoints.append(datapoint) - return datapoints - def _get_grader_llm(self) -> MessageCompleter: tokenizer = get_tokenizer(self.grader_llm_name) renderer_name = model_info.get_recommended_renderer_name(self.grader_llm_name) @@ -300,8 +180,10 @@ def _get_grader_llm(self) -> MessageCompleter: ) async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | None]: - train_datapoints = self._get_datapoints_from_jsonl(self.train_jsonl_path) - test_datapoints = self._get_datapoints_from_jsonl(self.test_jsonl_path) + train_datapoints = self.train_datapoint_list_builder() + test_datapoints = None + if self.test_datapoint_list_builder is not None: + test_datapoints = self.test_datapoint_list_builder() renderer = get_renderer( name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer) From cf812c7204ce77e28929eaf1d2d51186327cf39a Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 20:21:32 +0000 Subject: [PATCH 05/17] b --- 
tinker_cookbook/recipes/rubric/data.py | 22 +++++----------------- tinker_cookbook/recipes/rubric/env.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index 734974e5..7e0ba725 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -1,28 +1,17 @@ from tinker_cookbook.rl.types import ( - Action, - Env, - StepResult, - EnvGroupBuilder, - RLDataset, - RLDatasetBuilder, + Message, + Role, ) -from tinker_cookbook.renderers import Message, Renderer, Role from typing import TypeAlias -from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter -from tinker.types import ModelInput from dataclasses import dataclass from typing import Sequence import re import json import chz -import tinker -from tinker_cookbook.tokenizer_utils import get_tokenizer -from tinker_cookbook.renderers import get_renderer -import asyncio -from tinker_cookbook import model_info Conversation: TypeAlias = list[Message] + @dataclass class Rubric: """ @@ -128,12 +117,11 @@ def from_json(json_str: str) -> "RubricBasedDatapoint": ) - class RubricDatapointListBuilder: - def __call__(self) -> Sequence[RubricBasedDatapoint]: raise NotImplementedError("Subclass must implement this method") + @chz.chz class RubricDatapointListBuilderFromJsonl(RubricDatapointListBuilder): jsonl_path: str @@ -144,4 +132,4 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]: for line in f: data = json.loads(line) datapoints.append(RubricBasedDatapoint.from_json(data)) - return datapoints \ No newline at end of file + return datapoints diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py index 96d371ff..218f052e 100644 --- a/tinker_cookbook/recipes/rubric/env.py +++ b/tinker_cookbook/recipes/rubric/env.py @@ -6,13 +6,11 @@ RLDataset, RLDatasetBuilder, ) -from tinker_cookbook.renderers import Message, Renderer, Role -from typing import TypeAlias +from tinker_cookbook.renderers import Renderer from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter from tinker.types import ModelInput from dataclasses import dataclass from typing import Sequence -import re import json import chz import tinker @@ -20,7 +18,12 @@ from tinker_cookbook.renderers import get_renderer import asyncio from tinker_cookbook import model_info -from tinker_cookbook.recipes.rubric.data import RubricBasedDatapoint, Rubric, Conversation, RubricDatapointListBuilder +from tinker_cookbook.recipes.rubric.data import ( + RubricBasedDatapoint, + Rubric, + Conversation, + RubricDatapointListBuilder, +) # ANSI color codes BLUE = "\033[94m" @@ -29,6 +32,7 @@ MAGENTA = "\033[95m" RESET = "\033[0m" + class RubricGradedEnv(Env): def __init__( self, From 163d81636abd4179c4498f4d3b7e851ce6a221a4 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 20:21:55 +0000 Subject: [PATCH 06/17] b --- tinker_cookbook/recipes/rubric/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index 7e0ba725..5f6f5cde 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -1,4 +1,4 @@ -from tinker_cookbook.rl.types import ( +from tinker_cookbook.renderers import ( Message, Role, ) From f370ab8d2152c5053d07c965b7e9fd4b7bc234d5 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 
23:36:35 +0000 Subject: [PATCH 07/17] b --- tinker_cookbook/recipes/rubric/data.py | 45 +++++- tinker_cookbook/recipes/rubric/debug_env.py | 26 +++- .../recipes/rubric/generate_data.py | 2 +- .../recipes/rubric/prometheus_experimental.py | 140 ++++++++++++++++++ tinker_cookbook/recipes/rubric/train.py | 7 +- 5 files changed, 212 insertions(+), 8 deletions(-) create mode 100644 tinker_cookbook/recipes/rubric/prometheus_experimental.py diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index 5f6f5cde..be556f4e 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -116,7 +116,7 @@ def from_json(json_str: str) -> "RubricBasedDatapoint": rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]], ) - +@chz.chz class RubricDatapointListBuilder: def __call__(self) -> Sequence[RubricBasedDatapoint]: raise NotImplementedError("Subclass must implement this method") @@ -133,3 +133,46 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]: data = json.loads(line) datapoints.append(RubricBasedDatapoint.from_json(data)) return datapoints + + +@chz.chz +class PrometheusDatapointListBuilder(RubricDatapointListBuilder): + + data_path: str = "prometheus-eval/Feedback-Collection" + + def __call__(self) -> Sequence[RubricBasedDatapoint]: + from datasets import load_dataset + train_dataset = load_dataset(self.data_path)["train"] + return [self.build_rubric_datapoint(item) for item in train_dataset] + + + def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint: + + convo = [ + {'role': 'user', 'content': item['orig_instruction']}, + ] + + rubric_text = f"Your job is to evalaute the following: {item['orig_criteria']}. Your response should be a score between 1 to 5.\n" + rubric_text += f"Here is the calibration for each score:\n" + for i in range(1, 6): + rubric_text += f"{i}.0: {item[f'orig_score{i}_description']}\n" + + rubric_text += f"\nHere is a reference response that achieved a score of 5: {item['orig_reference_answer']}\n" + + + rubric = Rubric( + rubric_str=rubric_text, + extraction_regex=r"(.*)", + grader_output_format_instruction="Please output your score between 1 and 5 wrapped in ... 
", + ) + + return RubricBasedDatapoint( + convo=convo, + rubric_items=[rubric], + ) + + + + + + diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py index 9068dfce..dbedf963 100644 --- a/tinker_cookbook/recipes/rubric/debug_env.py +++ b/tinker_cookbook/recipes/rubric/debug_env.py @@ -7,8 +7,7 @@ from tinker_cookbook.rl.rollouts import do_single_rollout import asyncio - -async def main(): +def get_addition_datapoint() -> RubricBasedDatapoint: datapoint = RubricBasedDatapoint( convo=[ {"role": "user", "content": "What is 4 + 5?"}, @@ -20,6 +19,17 @@ async def main(): Rubric(rubric_str="Does the chatbot provide an answer without saying anything else?"), ], ) + + return datapoint + +def get_prometheus_datapoint() -> RubricBasedDatapoint: + from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder + datapoint = PrometheusDatapointListBuilder()() + datapoint = datapoint[0] + return datapoint + +async def main(datapoint: RubricBasedDatapoint): + policy_name = "meta-llama/Llama-3.1-8B-Instruct" grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" service_client = tinker.ServiceClient() @@ -49,4 +59,14 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + + dataset = "addition" + + if dataset == "addition": + datapoint = get_addition_datapoint() + asyncio.run(main(datapoint)) + elif dataset == "prometheus": + datapoint = get_prometheus_datapoint() + asyncio.run(main(datapoint)) + else: + raise ValueError(f"Unknown dataset: {dataset}") diff --git a/tinker_cookbook/recipes/rubric/generate_data.py b/tinker_cookbook/recipes/rubric/generate_data.py index 52ccbdc0..3ee9587b 100644 --- a/tinker_cookbook/recipes/rubric/generate_data.py +++ b/tinker_cookbook/recipes/rubric/generate_data.py @@ -1,4 +1,4 @@ -from tinker_cookbook.recipes.rubric.env import RubricBasedDatapoint, Rubric +from tinker_cookbook.recipes.rubric.data import RubricBasedDatapoint, Rubric import random import os diff --git a/tinker_cookbook/recipes/rubric/prometheus_experimental.py b/tinker_cookbook/recipes/rubric/prometheus_experimental.py new file mode 100644 index 00000000..2d7eb796 --- /dev/null +++ b/tinker_cookbook/recipes/rubric/prometheus_experimental.py @@ -0,0 +1,140 @@ +import chz +import asyncio +from datetime import datetime +from tinker_cookbook import cli_utils, model_info +from tinker_cookbook.rl.train import AsyncConfig, Config, main +from tinker_cookbook.rl.types import RLDatasetBuilder +from tinker.types import LossFnType +from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder +from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder + + +@chz.chz +class CLIConfig: + """Simple command-line configuration for RL training.""" + + # Model configuration + model_name: str = "meta-llama/Llama-3.1-8B-Instruct" + lora_rank: int = 32 + renderer_name: str | None = None + load_checkpoint_path: str | None = None + + seed: int = 0 # Random seed for data shuffling + + # Training hyperparameters + train_group_size: int = 4 + test_group_size: int = 1 + groups_per_batch: int = 100 + learning_rate: float = 1e-5 + max_tokens: int = 5 + temperature: float = 1.0 + kl_penalty_coef: float = 0.0 + grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507" + # Number of optimizer steps per training iteration. + # Useful for very large batch sizes. 
+ num_substeps: int = 1 + + # Logging configuration + log_path: str | None = None + wandb_project: str | None = None + wandb_name: str | None = None + compute_post_kl: bool = False + + # Evals + eval_every: int = 20 + + # Checkpointing + save_every: int = 20 + + # Service configuration + base_url: str | None = None + + behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask" + + max_steps_off_policy: int | None = None + loss_fn: LossFnType = "importance_sampling" + + +def get_dataset_builder( + batch_size: int, + policy_model_name: str, + renderer_name: str, + grader_llm_name: str, + train_group_size: int, + test_group_size: int = 1, +) -> RLDatasetBuilder: + return RubricGradedDatasetBuilder( + batch_size=batch_size, + model_name_for_tokenizer=policy_model_name, + renderer_name=renderer_name, + grader_llm_name=grader_llm_name, + train_datapoint_list_builder=PrometheusDatapointListBuilder(), + test_datapoint_list_builder=None, + train_group_size=train_group_size, + test_group_size=test_group_size, + ) + + +async def cli_main(cli_config: CLIConfig): + """Convert CLI config to full config and run training.""" + + # Get tokenizer for stop sequences + renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name( + cli_config.model_name + ) + model_name = cli_config.model_name.replace("/", "-") + run_name = f"prometheus_experimental-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}" + # create log path if it doesn't exist + if cli_config.log_path is not None: + log_path = cli_config.log_path + else: + log_path = f"/tmp/tinker-examples/rubric/{run_name}" + + if cli_config.wandb_name is not None: + wandb_name = cli_config.wandb_name + else: + wandb_name = run_name + + # Create full config + config = Config( + learning_rate=cli_config.learning_rate, + dataset_builder=get_dataset_builder( + batch_size=cli_config.groups_per_batch, + policy_model_name=cli_config.model_name, + renderer_name=renderer_name, + grader_llm_name=cli_config.grader_llm_name, + train_group_size=cli_config.train_group_size, + test_group_size=cli_config.test_group_size, + ), + model_name=cli_config.model_name, + lora_rank=cli_config.lora_rank, + max_tokens=cli_config.max_tokens, + temperature=cli_config.temperature, + wandb_project=cli_config.wandb_project, + wandb_name=wandb_name, + log_path=log_path, + base_url=cli_config.base_url, + load_checkpoint_path=cli_config.load_checkpoint_path, + compute_post_kl=cli_config.compute_post_kl, + kl_penalty_coef=cli_config.kl_penalty_coef, + num_substeps=cli_config.num_substeps, + eval_every=cli_config.eval_every, + save_every=cli_config.save_every, + async_config=AsyncConfig( + max_steps_off_policy=cli_config.max_steps_off_policy, + groups_per_batch=cli_config.groups_per_batch, + ) + if cli_config.max_steps_off_policy is not None + else None, + loss_fn=cli_config.loss_fn, + ) + + cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists) + + # Run training + await main(config) + + +if __name__ == "__main__": + cli_config = chz.entrypoint(CLIConfig) + asyncio.run(cli_main(cli_config)) diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py index f5a2039b..26a7faba 100644 --- a/tinker_cookbook/recipes/rubric/train.py +++ b/tinker_cookbook/recipes/rubric/train.py @@ -6,6 +6,7 @@ from tinker_cookbook.rl.types import 
RLDatasetBuilder from tinker.types import LossFnType from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder +from tinker_cookbook.recipes.rubric.data import RubricDatapointListBuilderFromJsonl @chz.chz @@ -72,8 +73,8 @@ def get_dataset_builder( model_name_for_tokenizer=policy_model_name, renderer_name=renderer_name, grader_llm_name=grader_llm_name, - train_jsonl_path=train_jsonl_path, - test_jsonl_path=test_jsonl_path, + train_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=train_jsonl_path), + test_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=test_jsonl_path) if test_jsonl_path is not None else None, train_group_size=train_group_size, test_group_size=test_group_size, ) @@ -92,7 +93,7 @@ async def cli_main(cli_config: CLIConfig): if cli_config.log_path is not None: log_path = cli_config.log_path else: - log_path = f"/tmp/tinker-examples/math_rl/{run_name}" + log_path = f"/tmp/tinker-examples/rubric/{run_name}" if cli_config.wandb_name is not None: wandb_name = cli_config.wandb_name From 42920860a6e596dfd64efed368bc190d921553ee Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 23:37:07 +0000 Subject: [PATCH 08/17] b --- tinker_cookbook/recipes/rubric/data.py | 16 ++++------------ tinker_cookbook/recipes/rubric/debug_env.py | 6 ++++-- tinker_cookbook/recipes/rubric/train.py | 8 ++++++-- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index be556f4e..d30a4aad 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -116,6 +116,7 @@ def from_json(json_str: str) -> "RubricBasedDatapoint": rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]], ) + @chz.chz class RubricDatapointListBuilder: def __call__(self) -> Sequence[RubricBasedDatapoint]: @@ -137,29 +138,26 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]: @chz.chz class PrometheusDatapointListBuilder(RubricDatapointListBuilder): - data_path: str = "prometheus-eval/Feedback-Collection" def __call__(self) -> Sequence[RubricBasedDatapoint]: from datasets import load_dataset + train_dataset = load_dataset(self.data_path)["train"] return [self.build_rubric_datapoint(item) for item in train_dataset] - def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint: - convo = [ - {'role': 'user', 'content': item['orig_instruction']}, + {"role": "user", "content": item["orig_instruction"]}, ] rubric_text = f"Your job is to evalaute the following: {item['orig_criteria']}. 
Your response should be a score between 1 to 5.\n"
-        rubric_text += f"Here is the calibration for each score:\n"
+        rubric_text += "Here is the calibration for each score:\n"
         for i in range(1, 6):
             rubric_text += f"{i}.0: {item[f'orig_score{i}_description']}\n"
 
         rubric_text += f"\nHere is a reference response that achieved a score of 5: {item['orig_reference_answer']}\n"
-
         rubric = Rubric(
             rubric_str=rubric_text,
             extraction_regex=r"<score>(.*)</score>",
@@ -170,9 +168,3 @@ def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint:
             convo=convo,
             rubric_items=[rubric],
         )
-
-
-
-
-
-
diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py
index dbedf963..84950bbc 100644
--- a/tinker_cookbook/recipes/rubric/debug_env.py
+++ b/tinker_cookbook/recipes/rubric/debug_env.py
@@ -7,6 +7,7 @@
 from tinker_cookbook.rl.rollouts import do_single_rollout
 import asyncio
 
+
 def get_addition_datapoint() -> RubricBasedDatapoint:
     datapoint = RubricBasedDatapoint(
         convo=[
@@ -22,14 +23,16 @@ def get_addition_datapoint() -> RubricBasedDatapoint:
     return datapoint
 
 
+
 def get_prometheus_datapoint() -> RubricBasedDatapoint:
     from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder
+
     datapoint = PrometheusDatapointListBuilder()()
     datapoint = datapoint[0]
     return datapoint
 
 
-async def main(datapoint: RubricBasedDatapoint): 
+async def main(datapoint: RubricBasedDatapoint):
     policy_name = "meta-llama/Llama-3.1-8B-Instruct"
     grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507"
     service_client = tinker.ServiceClient()
@@ -59,7 +62,6 @@ async def main(datapoint: RubricBasedDatapoint):
 
 
 if __name__ == "__main__":
-
     dataset = "addition"
 
     if dataset == "addition":
diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py
index 26a7faba..88953d4b 100644
--- a/tinker_cookbook/recipes/rubric/train.py
+++ b/tinker_cookbook/recipes/rubric/train.py
@@ -73,8 +73,12 @@ def get_dataset_builder(
         model_name_for_tokenizer=policy_model_name,
         renderer_name=renderer_name,
         grader_llm_name=grader_llm_name,
-        train_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=train_jsonl_path),
-        test_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=test_jsonl_path) if test_jsonl_path is not None else None,
+        train_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(
+            jsonl_path=train_jsonl_path
+        ),
+        test_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=test_jsonl_path)
+        if test_jsonl_path is not None
+        else None,
         train_group_size=train_group_size,
         test_group_size=test_group_size,
     )

From 56f9bebd834a4bec044dfc820e545a2ac3b4da47 Mon Sep 17 00:00:00 2001
From: Ruiqi Zhong
Date: Sat, 13 Dec 2025 23:53:25 +0000
Subject: [PATCH 09/17] adding readme

---
 tinker_cookbook/recipes/rubric/README.md | 94 ++++++++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 tinker_cookbook/recipes/rubric/README.md

diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
new file mode 100644
index 00000000..e2835975
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -0,0 +1,94 @@
+# Rubric-based Grading for LLMs
+
+- [`data.py`](./data.py) contains the definition for the datapoint class. Each datapoint consists of a conversation prefix, and a list of rubric items.
+- [`generate_data.py`](./generate_data.py) generates some example datapoint if you want to run our demo on addition.
+- [`env.py`](./env.py) determines what each rollout will do. 
It will let the policy read the prefix, generate a response, ask a grader LLM to grade based on a list of rubric items, and finally provide a reward by aggregating the scores across the rubric items.
+- [`train.py`](./train.py) allows you train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.
+- [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train the LLMs based on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. It is experimental though -- even though the reward goes up, there is no guarantee that the model is actually better. We hope our script serves as a starting point, and more research is needed. 
+
+
+## A simple example of using a grader LLM with rubrics
+
+We show how to use rubric-based LLM to provide reward on an addition task. E.g.
+
+```
+**User**: What's 233 + 100?
+**Assistant**: 333
+```
+
+Usually, this could be graded by matching the number to the ground truth 333 without needing an LLLM. However, for pedagogical purposes we will grade the response using a language model with rubric. I.e. We will ask a language mode "Does the assistant answer 333?"
+
+### Generate an example dataset
+
+To run this, first generate a dataset:
+
+```
+python -m tinker_cookbook.recipes.rubric.generate_data
+```
+
+Then you will see two `jsonl` file generated, one for training, one for testing. For example, if you look into ` tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of 
+- a convo (the conversation prefix that the policy sees)
+- rubric_items: a list of rubric items that specify what is a good item, how the grader should format the response, and how the grading result should be extracted.
+
+```
+{
+    "convo": [
+        {
+            "role": "user",
+            "content": "What is 4 + 5?"
+        },
+        {
+            "role": "assistant",
+            "content": "9"
+        },
+        {
+            "role": "user",
+            "content": "What is 122 + 12?"
+        }
+    ],
+    "rubric_items": [
+        {
+            "rubric_str": "Does the chatbot correctly gets the answer 134?",
+            "extraction_regex": "<score>(.*)</score>",
+            "grader_output_format_instruction": "Please output your score between 0 and 1 wrapped in <score> ... </score> "
+        }
+    ]
+}
+```
+
+### Debugging and Printing What Happens During Rollouts
+
+Run 
+```
+python -m tinker_cookbook.recipes.rubric.debug_env
+```
+
+You can see the message that the policy sees, its response, the grader input, and the grader output. 
+
+image
+
+
+### An example training run
+
+To train the LLM to add using a rubric-based grader LLM, run
+```
+python -m tinker_cookbook.recipes.rubric.train
+```
+
+You can see the reward quickly goes up.
+
+image
+
+### A more realistic dataset
+
+We take the `prometheus-eval/Feedback-Collection` dataset from hugingface [link](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubric to grade general chat responses. Run the following to kick off training:
+
+```
+python -m tinker_cookbook.recipes.rubric.prometheus_experimental
+```
+
+We can see that the reward climbs up steadily.
+
+image
+
+Note that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting for you to improve rubric-based grading for training LLMs!
\ No newline at end of file
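The grading contract introduced by this README can be made concrete with a minimal, self-contained sketch of the extraction step. It assumes the `<score> ... </score>` tag convention used by the example datapoints; `extract_score` is an illustrative helper, not part of the recipe's actual API.

```
import re

def extract_score(grader_output: str, extraction_regex: str = r"<score>(.*)</score>") -> float | None:
    # Pull the numeric score out of the grader's response; None if the tags are missing.
    match = re.search(extraction_regex, grader_output)
    if match is None:
        return None
    try:
        return float(match.group(1).strip())
    except ValueError:
        return None

print(extract_score("The answer is correct. <score>1.0</score>"))  # 1.0
print(extract_score("no tags here"))  # None
```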
From 4213c9fe8e89b6793eeed6c71cfb8efb286b46e6 Mon Sep 17 00:00:00 2001
From: Ruiqi Zhong
Date: Sat, 13 Dec 2025 23:53:37 +0000
Subject: [PATCH 10/17] adding readme

---
 tinker_cookbook/recipes/rubric/README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
index e2835975..797f0136 100644
--- a/tinker_cookbook/recipes/rubric/README.md
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -4,7 +4,7 @@
 - [`generate_data.py`](./generate_data.py) generates some example datapoint if you want to run our demo on addition.
 - [`env.py`](./env.py) determines what each rollout will do. It will let the policy read the prefix, generate a response, ask a grader LLM to grade based on a list of rubric items, and finally provide a reward by aggregating the scores across the rubric items.
 - [`train.py`](./train.py) allows you train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.
-- [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train the LLMs based on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. It is experimental though -- even though the reward goes up, there is no guarantee that the model is actually better. We hope our script serves as a starting point, and more research is needed. 
+- [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train the LLMs based on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. It is experimental though -- even though the reward goes up, there is no guarantee that the model is actually better. We hope our script serves as a starting point, and more research is needed.
 
 
 ## A simple example of using a grader LLM with rubrics
@@ -26,7 +26,7 @@ To run this, first generate a dataset:
 python -m tinker_cookbook.recipes.rubric.generate_data
 ```
 
-Then you will see two `jsonl` file generated, one for training, one for testing. For example, if you look into ` tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of 
+Then you will see two `jsonl` file generated, one for training, one for testing. For example, if you look into ` tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of
 - a convo (the conversation prefix that the policy sees)
 - rubric_items: a list of rubric items that specify what is a good item, how the grader should format the response, and how the grading result should be extracted.
 
@@ -58,12 +58,12 @@ Then you will see two `jsonl` file generated, one for training, one for testing
 
 ### Debugging and Printing What Happens During Rollouts
 
-Run 
+Run
 ```
 python -m tinker_cookbook.recipes.rubric.debug_env
 ```
 
-You can see the message that the policy sees, its response, the grader input, and the grader output. 
+You can see the message that the policy sees, its response, the grader input, and the grader output.
 
 image
 
@@ -91,4 +91,4 @@ We can see that the reward climbs up steadily.
 
 image
 
-Note that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting for you to improve rubric-based grading for training LLMs!
\ No newline at end of file
+Note that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting for you to improve rubric-based grading for training LLMs!
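Since the training script reads datapoints from `jsonl` files, here is a small sketch of writing and re-reading one line in the schema shown in the README above. The file path is illustrative; one JSON-encoded datapoint per line is what the `RubricDatapointListBuilderFromJsonl` used by `train.py` expects.

```
import json
from pathlib import Path

datapoint = {
    "convo": [
        {"role": "user", "content": "What is 4 + 5?"},
        {"role": "assistant", "content": "9"},
        {"role": "user", "content": "What is 122 + 12?"},
    ],
    "rubric_items": [
        {
            "rubric_str": "Does the chatbot correctly get the answer 134?",
            "extraction_regex": "<score>(.*)</score>",
            "grader_output_format_instruction": "Please output your score between 0 and 1 wrapped in <score> ... </score>",
        }
    ],
}

path = Path("/tmp/example_rubric_train.jsonl")
path.write_text(json.dumps(datapoint) + "\n")

# Each line of the file is one JSON-encoded datapoint.
for line in path.read_text().splitlines():
    loaded = json.loads(line)
    assert loaded["convo"][-1]["role"] == "user"
```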
From 456a03c6bb57c2077af268dbc3dece59d34a4cd6 Mon Sep 17 00:00:00 2001
From: Ruiqi Zhong
Date: Sat, 13 Dec 2025 23:58:21 +0000
Subject: [PATCH 11/17] b

---
 tinker_cookbook/recipes/rubric/README.md | 4 ++--
 tinker_cookbook/recipes/rubric/data.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
index 797f0136..b3bb0b76 100644
--- a/tinker_cookbook/recipes/rubric/README.md
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -1,6 +1,6 @@
 # Rubric-based Grading for LLMs
 
-- [`data.py`](./data.py) contains the definition for the datapoint class. Each datapoint consists of a conversation prefix, and a list of rubric items.
+- [`data.py`](./data.py) contains the definition for the datapoint class. Each datapoint consists of a conversation prefix and a list of rubric items.
 - [`generate_data.py`](./generate_data.py) generates some example datapoint if you want to run our demo on addition.
 - [`env.py`](./env.py) determines what each rollout will do. It will let the policy read the prefix, generate a response, ask a grader LLM to grade based on a list of rubric items, and finally provide a reward by aggregating the scores across the rubric items.
 - [`train.py`](./train.py) allows you train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.
@@ -81,7 +81,7 @@ You can see the reward quickly goes up.
 
 ### A more realistic dataset
 
-We take the `prometheus-eval/Feedback-Collection` dataset from hugingface [link](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubric to grade general chat responses. Run the following to kick off training:
+We take the `prometheus-eval/Feedback-Collection` dataset from [hugingface](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubric to grade general chat responses. 
Run the following to kick off training: ``` python -m tinker_cookbook.recipes.rubric.prometheus_experimental diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index d30a4aad..bf430106 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -144,10 +144,10 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]: from datasets import load_dataset train_dataset = load_dataset(self.data_path)["train"] - return [self.build_rubric_datapoint(item) for item in train_dataset] + return [self.build_rubric_datapoint(item) for item in train_dataset] # type: ignore def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint: - convo = [ + convo: Conversation = [ {"role": "user", "content": item["orig_instruction"]}, ] From d073bf85a77f93b9f7c96b3c03e02f191f3bb101 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sat, 13 Dec 2025 23:58:34 +0000 Subject: [PATCH 12/17] b --- tinker_cookbook/recipes/rubric/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index bf430106..a0915d7e 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -144,7 +144,7 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]: from datasets import load_dataset train_dataset = load_dataset(self.data_path)["train"] - return [self.build_rubric_datapoint(item) for item in train_dataset] # type: ignore + return [self.build_rubric_datapoint(item) for item in train_dataset] # type: ignore def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint: convo: Conversation = [ From 0fafd14285acc654f2ad24722765b70912421fd1 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Sun, 14 Dec 2025 00:04:07 +0000 Subject: [PATCH 13/17] b --- tinker_cookbook/recipes/rubric/data.py | 5 ++--- tinker_cookbook/recipes/rubric/generate_data.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index a0915d7e..c9011b62 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -131,8 +131,7 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]: datapoints = [] with open(self.jsonl_path, "r") as f: for line in f: - data = json.loads(line) - datapoints.append(RubricBasedDatapoint.from_json(data)) + datapoints.append(RubricBasedDatapoint.from_json(line)) return datapoints @@ -151,7 +150,7 @@ def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint: {"role": "user", "content": item["orig_instruction"]}, ] - rubric_text = f"Your job is to evalaute the following: {item['orig_criteria']}. Your response should be a score between 1 to 5.\n" + rubric_text = f"Your job is to evaluate the following: {item['orig_criteria']}. 
Your response should be a score between 1 to 5.\n"
         rubric_text += "Here is the calibration for each score:\n"
         for i in range(1, 6):
             rubric_text += f"{i}.0: {item[f'orig_score{i}_description']}\n"
 
diff --git a/tinker_cookbook/recipes/rubric/generate_data.py b/tinker_cookbook/recipes/rubric/generate_data.py
index 3ee9587b..922db4d3 100644
--- a/tinker_cookbook/recipes/rubric/generate_data.py
+++ b/tinker_cookbook/recipes/rubric/generate_data.py
@@ -11,7 +11,7 @@ def generate_one(rng: random.Random) -> RubricBasedDatapoint:
             {"role": "assistant", "content": "9"},
             {"role": "user", "content": f"What is {x} + {y}?"},
         ],
-        rubric_items=[Rubric(rubric_str=f"Does the chatbot correctly gets the answer {x + y}?")],
+        rubric_items=[Rubric(rubric_str=f"Does the chatbot correctly get the answer {x + y}?")],
     )
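For reference, here is roughly what one draw of the `generate_one` function fixed above produces, written with plain dicts so the sketch is self-contained. The operand ranges are assumptions, since the sampling bounds are not shown in this hunk.

```
import random

def generate_one_sketch(rng: random.Random) -> dict:
    # Two random addends; the 100-999 range is illustrative.
    x, y = rng.randint(100, 999), rng.randint(100, 999)
    return {
        "convo": [
            {"role": "user", "content": "What is 4 + 5?"},
            {"role": "assistant", "content": "9"},
            {"role": "user", "content": f"What is {x} + {y}?"},
        ],
        # The rubric bakes in the ground-truth sum, so the grader can check it.
        "rubric_items": [{"rubric_str": f"Does the chatbot correctly get the answer {x + y}?"}],
    }

print(generate_one_sketch(random.Random(0)))
```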
From 07a54efc7a25c285d44a22a9e191e57a7a0eb24 Mon Sep 17 00:00:00 2001
From: Ruiqi Zhong
Date: Sun, 14 Dec 2025 00:10:59 +0000
Subject: [PATCH 14/17] b

---
 tinker_cookbook/recipes/rubric/README.md | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
index b3bb0b76..92a3026e 100644
--- a/tinker_cookbook/recipes/rubric/README.md
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -1,22 +1,22 @@
 # Rubric-based Grading for LLMs
 
 - [`data.py`](./data.py) contains the definition for the datapoint class. Each datapoint consists of a conversation prefix and a list of rubric items.
-- [`generate_data.py`](./generate_data.py) generates some example datapoint if you want to run our demo on addition.
+- [`generate_data.py`](./generate_data.py) generates some example datapoints if you want to run our demo on addition.
 - [`env.py`](./env.py) determines what each rollout will do. It will let the policy read the prefix, generate a response, ask a grader LLM to grade based on a list of rubric items, and finally provide a reward by aggregating the scores across the rubric items.
-- [`train.py`](./train.py) allows you train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.
+- [`train.py`](./train.py) allows you to train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.
 - [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train the LLMs based on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. It is experimental though -- even though the reward goes up, there is no guarantee that the model is actually better. We hope our script serves as a starting point, and more research is needed.
 
 
 ## A simple example of using a grader LLM with rubrics
 
-We show how to use rubric-based LLM to provide reward on an addition task. E.g.
+We show how to use a rubric-based LLM to provide a reward for an addition task. E.g.
 
 ```
 **User**: What's 233 + 100?
 **Assistant**: 333
 ```
 
-Usually, this could be graded by matching the number to the ground truth 333 without needing an LLLM. However, for pedagogical purposes we will grade the response using a language model with rubric. I.e. We will ask a language mode "Does the assistant answer 333?"
+Usually, this could be graded by matching the number to the ground truth 333 without needing an LLM. However, for pedagogical purposes, we will grade the response using a language model with a rubric. That is, we will ask a language model "Does the assistant answer 333?"
 
 ### Generate an example dataset
 
@@ -26,9 +26,9 @@ To run this, first generate a dataset:
 python -m tinker_cookbook.recipes.rubric.generate_data
 ```
 
-Then you will see two `jsonl` file generated, one for training, one for testing. For example, if you look into ` tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of
+Then you will see two `jsonl` files generated, one for training, one for testing. For example, if you look into `tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of
 - a convo (the conversation prefix that the policy sees)
-- rubric_items: a list of rubric items that specify what is a good item, how the grader should format the response, and how the grading result should be extracted.
+- rubric_items: a list of rubric items that specify what is a good response, how the grader should format the response, and how the grading result should be extracted.
 
 ```
@@ -48,7 +48,7 @@ Then you will see two `jsonl` file generated, one for training, one for testing
     ],
     "rubric_items": [
         {
-            "rubric_str": "Does the chatbot correctly gets the answer 134?",
+            "rubric_str": "Does the chatbot correctly get the answer 134?",
             "extraction_regex": "<score>(.*)</score>",
             "grader_output_format_instruction": "Please output your score between 0 and 1 wrapped in <score> ... </score> "
         }
@@ -81,7 +81,7 @@ You can see the reward quickly goes up.
 
 ### A more realistic dataset
 
-We take the `prometheus-eval/Feedback-Collection` dataset from [hugingface](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubric to grade general chat responses. Run the following to kick off training:
+We take the `prometheus-eval/Feedback-Collection` dataset from [Hugging Face](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubrics to grade general chat responses. Run the following to kick off training:
 
 ```
 python -m tinker_cookbook.recipes.rubric.prometheus_experimental
@@ -91,4 +91,4 @@ We can see that the reward climbs up steadily.
 
 image
 
-Note that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting for you to improve rubric-based grading for training LLMs!
+Note that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting point for you to improve rubric-based grading for training LLMs!
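The next patch reworks how the grader prompt is assembled. As background, this standalone sketch mirrors the Human/Chatbot flattening that `_flatten_convo` in `data.py` performs; the free function here is illustrative, not the recipe's API.

```
def flatten_convo(convo: list[dict[str, str]]) -> str:
    # user/system turns render as "Human", assistant turns as "Chatbot".
    def convert_role(role: str) -> str:
        return "Human" if role in ("user", "system") else "Chatbot"

    return "\n\n".join(f"{convert_role(m['role'])}: {m['content']}" for m in convo)

print(flatten_convo([
    {"role": "user", "content": "What is 4 + 5?"},
    {"role": "assistant", "content": "9"},
]))
# Human: What is 4 + 5?
#
# Chatbot: 9
```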
From 99feebaa98d8baf5cf52632781c1e7b42ad77a8a Mon Sep 17 00:00:00 2001
From: John Schulman
Date: Sat, 20 Dec 2025 06:28:05 +0000
Subject: [PATCH 15/17] .

---
 tinker_cookbook/recipes/rubric/README.md    |  6 +-
 tinker_cookbook/recipes/rubric/data.py      | 54 ++++++++++-----
 tinker_cookbook/recipes/rubric/debug_env.py |  8 ++-
 tinker_cookbook/recipes/rubric/env.py       | 66 +++++++++++----------
 4 files changed, 83 insertions(+), 51 deletions(-)

diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
index 92a3026e..920b99b9 100644
--- a/tinker_cookbook/recipes/rubric/README.md
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -65,7 +65,7 @@ python -m tinker_cookbook.recipes.rubric.debug_env
 
 You can see the message that the policy sees, its response, the grader input, and the grader output.
 
-image
+Debug output showing the conversation context, policy response, grader prompt, and extracted score
 
 ### An example training run
 
@@ -77,7 +77,7 @@ python -m tinker_cookbook.recipes.rubric.train
 
 You can see the reward quickly goes up.
 
-image
+Training metrics showing reward increasing over training steps for the addition task
 
 ### A more realistic dataset
 
@@ -89,6 +89,6 @@ python -m tinker_cookbook.recipes.rubric.prometheus_experimental
 
 We can see that the reward climbs up steadily.
 
-image
+Training metrics showing reward climbing steadily over training steps for the Prometheus dataset
 
 Note that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting point for you to improve rubric-based grading for training LLMs!
diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py
index c9011b62..ca2e6f14 100644
--- a/tinker_cookbook/recipes/rubric/data.py
+++ b/tinker_cookbook/recipes/rubric/data.py
@@ -2,7 +2,7 @@
     Message,
     Role,
 )
-from typing import TypeAlias
+from typing import Any, TypeAlias
 from dataclasses import dataclass
 from typing import Sequence
 import re
@@ -24,7 +24,7 @@ class Rubric:
         "Please output your score between 0 and 1 wrapped in <score> ... </score> "
     )
 
-    def __convert_role(self, role: Role) -> str:
+    def _convert_role(self, role: Role) -> str:
         return "Human" if role in ("user", "system") else "Chatbot"
 
     def _flatten_convo(self, convo: Conversation) -> str:
@@ -36,22 +36,41 @@ def _flatten_convo(self, convo: Conversation) -> str:
         \n\nChatbot: ...
         """
         return "\n\n".join(
-            [f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo]
+            [f"{self._convert_role(message['role'])}: {message['content']}" for message in convo]
         )
 
     def get_grader_prompt(self, convo: Conversation) -> Conversation:
         """
-        Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric.
+        Create a prompt for the grader to grade the conversation based on the rubric.
+        The prompt separates the context (prior turns) from the completion (last assistant message)
+        so the grader focuses on grading the most recent response.
         """
-
-        prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric."
-
-        prompt += f"Here is the conversation: \n\n{self._flatten_convo(convo)} \n\n\n\nHere is the rubric: \n{self.rubric_str}\n\n"
-        prompt += f"Please grade the conversation based on the rubric. {self.grader_output_format_instruction}"
+        # Separate context from the completion to grade
+        context = convo[:-1]
+        completion = convo[-1]
+
+        lines = [
+            "I will show you a conversation context, a chatbot completion to grade, and a rubric.",
+            "Please grade the chatbot's completion based on the rubric.",
+            "",
+            "<context>",
+            self._flatten_convo(context) if context else "(No prior context)",
+            "</context>",
+            "",
+            "<completion>",
+            f"Chatbot: {completion['content']}",
+            "</completion>",
+            "",
+            "<rubric>",
+            self.rubric_str,
+            "</rubric>",
+            "",
+            f"Please grade the chatbot's completion based on the rubric. {self.grader_output_format_instruction}",
+        ]
         return [
             {
                 "role": "user",
-                "content": prompt,
+                "content": "\n".join(lines),
             }
         ]
@@ -120,6 +139,7 @@ def from_json(json_str: str) -> "RubricBasedDatapoint":
 @chz.chz
 class RubricDatapointListBuilder:
     def __call__(self) -> Sequence[RubricBasedDatapoint]:
+        """Load and return a sequence of rubric-based datapoints."""
         raise NotImplementedError("Subclass must implement this method")
 
@@ -145,17 +165,19 @@ def __call__(self) -> Sequence[RubricBasedDatapoint]:
         from datasets import load_dataset
 
         train_dataset = load_dataset(self.data_path)["train"]
         return [self.build_rubric_datapoint(item) for item in train_dataset]  # type: ignore
 
-    def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint:
+    def build_rubric_datapoint(self, item: dict[str, Any]) -> RubricBasedDatapoint:
         convo: Conversation = [
             {"role": "user", "content": item["orig_instruction"]},
         ]
 
-        rubric_text = f"Your job is to evaluate the following: {item['orig_criteria']}. Your response should be a score between 1 to 5.\n"
-        rubric_text += "Here is the calibration for each score:\n"
+        rubric_lines = [
+            f"Your job is to evaluate the following: {item['orig_criteria']}. Your response should be a score between 1 and 5.",
+            "Here is the calibration for each score:",
+        ]
         for i in range(1, 6):
-            rubric_text += f"{i}.0: {item[f'orig_score{i}_description']}\n"
+            rubric_lines.append(f"{i}.0: {item[f'orig_score{i}_description']}")
 
-        rubric_text += f"\nHere is a reference response that achieved a score of 5: {item['orig_reference_answer']}\n"
+        rubric_lines.append(f"Here is a reference response that achieved a score of 5: {item['orig_reference_answer']}")
+        rubric_text = "\n".join(rubric_lines)
 
         rubric = Rubric(
             rubric_str=rubric_text,
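For orientation, here is roughly the string the rewritten `get_grader_prompt` sends to the grader for a one-turn conversation. The `<context>`/`<completion>`/`<rubric>` section markers are assumed tag names and the values are example data, so treat this as an illustrative rendering rather than exact output:

```
# Illustrative rendering of the grader prompt; tag names are assumptions.
rubric_str = "Does the chatbot correctly get the answer 134?"
completion_content = "134"

prompt = "\n".join([
    "I will show you a conversation context, a chatbot completion to grade, and a rubric.",
    "Please grade the chatbot's completion based on the rubric.",
    "",
    "<context>",
    "Human: What is 122 + 12?",
    "</context>",
    "",
    "<completion>",
    f"Chatbot: {completion_content}",
    "</completion>",
    "",
    "<rubric>",
    rubric_str,
    "</rubric>",
    "",
    "Please grade the chatbot's completion based on the rubric. "
    "Please output your score between 0 and 1 wrapped in <score> ... </score>",
])
print(prompt)
```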
diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py
index 84950bbc..6df88d72 100644
--- a/tinker_cookbook/recipes/rubric/debug_env.py
+++ b/tinker_cookbook/recipes/rubric/debug_env.py
@@ -33,12 +33,16 @@ def get_prometheus_datapoint() -> RubricBasedDatapoint:
 
 
 async def main(datapoint: RubricBasedDatapoint):
+    # Configuration parameters
     policy_name = "meta-llama/Llama-3.1-8B-Instruct"
     grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+    policy_max_tokens = 64
+    grader_max_tokens = 64
+
     service_client = tinker.ServiceClient()
     policy = TinkerTokenCompleter(
         sampling_client=service_client.create_sampling_client(base_model=policy_name),
-        max_tokens=64,
+        max_tokens=policy_max_tokens,
     )
     policy_renderer = get_renderer(
         model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name)
@@ -48,7 +52,7 @@ async def main(datapoint: RubricBasedDatapoint):
         renderer=get_renderer(
             model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name)
         ),
-        max_tokens=64,
+        max_tokens=grader_max_tokens,
     )
 
     env = RubricGradedEnv(
diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py
index 218f052e..0435c308 100644
--- a/tinker_cookbook/recipes/rubric/env.py
+++ b/tinker_cookbook/recipes/rubric/env.py
@@ -24,13 +24,7 @@
     Conversation,
     RubricDatapointListBuilder,
 )
-
-# ANSI color codes
-BLUE = "\033[94m"
-GREEN = "\033[92m"
-YELLOW = "\033[93m"
-MAGENTA = "\033[95m"
-RESET = "\033[0m"
+from termcolor import colored
 
 
 class RubricGradedEnv(Env):
@@ -40,14 +34,17 @@ def __init__(
         self,
         renderer: Renderer,
         datapoint: RubricBasedDatapoint,
         grader_llm: MessageCompleter,
         debug: bool = False,
+        format_coef: float = 0.1,
     ):
         """
-        Initialize the RubricGradedEnv. In this environment, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric.
+        Initialize the RubricGradedEnv. 
In this environment, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric. + Initialize the RubricGradedEnv. In this environment, the policy model sees the conversation, + creates a response, and then the grader language model grades the response based on the rubric. """ self.renderer = renderer self.datapoint = datapoint self.grader_llm = grader_llm self.debug = debug + self.format_coef = format_coef @property def rubric_items(self) -> Sequence[Rubric]: @@ -75,33 +72,35 @@ async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> float assert isinstance(grader_response_content, str), "Grader response content must be a string" score = rubric.extract_score(grader_response_content) if self.debug: - print(f"{YELLOW}{'=' * 80}") - print("DEBUG: First Turn of Grader Prompt") - print(f"{'=' * 80}{RESET}") - print(f"{YELLOW}{grader_prompt[0]['content']}{RESET}\n") - - print(f"{MAGENTA}{'=' * 80}") - print("DEBUG: Score") - print(f"{'=' * 80}{RESET}") - print(f"{MAGENTA}Grader Response: {grader_response_content}{RESET}\n") - print(f"{MAGENTA}Extracted Score: {score}{RESET}\n") + print(colored("=" * 80, "yellow")) + print(colored("DEBUG: First Turn of Grader Prompt", "yellow")) + print(colored("=" * 80, "yellow")) + print(colored(grader_prompt[0]["content"], "yellow") + "\n") + + print(colored("=" * 80, "magenta")) + print(colored("DEBUG: Score", "magenta")) + print(colored("=" * 80, "magenta")) + print(colored(f"Grader Response: {grader_response_content}", "magenta") + "\n") + print(colored(f"Extracted Score: {score}", "magenta") + "\n") return score async def step(self, action: Action) -> StepResult: # obtain the policy action message - (policy_action_message, _parse_success) = self.renderer.parse_response(action) + (policy_action_message, parse_success) = self.renderer.parse_response(action) + correct_format = float(parse_success) if self.debug: - print(f"\n{BLUE}{'=' * 80}") - print("DEBUG: Original Conversation (self.convo)") - print(f"{'=' * 80}{RESET}") - print(f"{BLUE}{json.dumps(self.convo, indent=2)}{RESET}\n") - - print(f"{GREEN}{'=' * 80}") - print("DEBUG: Policy Action Message") - print(f"{'=' * 80}{RESET}") - print(f"{GREEN}{json.dumps(policy_action_message, indent=2)}{RESET}\n") - # this shows the full back-and-forth conversation to the grader + print("\n" + colored("=" * 80, "blue")) + print(colored("DEBUG: Original Conversation (self.convo)", "blue")) + print(colored("=" * 80, "blue")) + print(colored(json.dumps(self.convo, indent=2), "blue") + "\n") + + print(colored("=" * 80, "green")) + print(colored("DEBUG: Policy Action Message", "green")) + print(colored("=" * 80, "green")) + print(colored(json.dumps(policy_action_message, indent=2), "green") + "\n") + print(colored(f"Parse Success: {parse_success}", "green") + "\n") + convo = self.convo + [policy_action_message] scores = await asyncio.gather( @@ -109,11 +108,18 @@ async def step(self, action: Action) -> StepResult: ) avg_score = sum(scores) / len(scores) + # Apply format penalty similar to ProblemEnv + total_reward = self.format_coef * (correct_format - 1) + avg_score + return StepResult( - reward=avg_score, + reward=total_reward, episode_done=True, next_observation=self.renderer.build_generation_prompt(convo), next_stop_condition=self.stop_condition, + metrics={ + "format": correct_format, + "rubric_score": avg_score, + }, ) From 74a2eb87fc4fdc6d8e6118eea8366477e41b3ede Mon Sep 17 00:00:00 2001 From: John Schulman Date: Sat, 20 
Dec 2025 06:29:09 +0000 Subject: [PATCH 16/17] . --- tinker_cookbook/recipes/rubric/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index ca2e6f14..82cf0842 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -176,7 +176,9 @@ def build_rubric_datapoint(self, item: dict[str, Any]) -> RubricBasedDatapoint: ] for i in range(1, 6): rubric_lines.append(f"{i}.0: {item[f'orig_score{i}_description']}") - rubric_lines.append(f"Here is a reference response that achieved a score of 5: {item['orig_reference_answer']}") + rubric_lines.append( + f"Here is a reference response that achieved a score of 5: {item['orig_reference_answer']}" + ) rubric_text = "\n".join(rubric_lines) rubric = Rubric( From 29a6fc339c1e09ba3ef3540de2b1b04b1ca7f88d Mon Sep 17 00:00:00 2001 From: John Schulman Date: Sat, 20 Dec 2025 06:39:05 +0000 Subject: [PATCH 17/17] . --- tinker_cookbook/recipes/rubric/data.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py index 82cf0842..8f61f08e 100644 --- a/tinker_cookbook/recipes/rubric/data.py +++ b/tinker_cookbook/recipes/rubric/data.py @@ -1,13 +1,15 @@ +import json +import os +import re +from dataclasses import dataclass +from typing import Any, Sequence, TypeAlias + +import chz + from tinker_cookbook.renderers import ( Message, Role, ) -from typing import Any, TypeAlias -from dataclasses import dataclass -from typing import Sequence -import re -import json -import chz Conversation: TypeAlias = list[Message] @@ -148,6 +150,12 @@ class RubricDatapointListBuilderFromJsonl(RubricDatapointListBuilder): jsonl_path: str def __call__(self) -> Sequence[RubricBasedDatapoint]: + if not os.path.exists(self.jsonl_path): + raise FileNotFoundError( + f"Data file not found: {self.jsonl_path}\n" + f"Please generate the example data first by running:\n" + f" python -m tinker_cookbook.recipes.rubric.generate_data" + ) datapoints = [] with open(self.jsonl_path, "r") as f: for line in f: