Implement presence and frequency penalties (#95) #27
MitchLewis930 wants to merge 1 commit into `float_comparison_before`
Conversation
📝 Walkthrough

This pull request refactors the data model to centralize per-sequence information into a new `SequenceData` class.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Sampler as Sampler.forward()
    participant InputMeta as InputMetadata
    participant GetPen as _get_penalties()
    participant GetOut as _get_output_tokens()
    participant ApplyPen as _apply_penalties()
    participant LogitProc as Logit Processing

    Sampler->>InputMeta: access seq_data
    Sampler->>GetPen: extract penalties per seq
    GetPen->>Sampler: List[presence_penalties], List[frequency_penalties]
    Sampler->>GetOut: gather output_token_ids per seq
    GetOut->>Sampler: List[List[int]]
    Sampler->>ApplyPen: pass logits + tokens + penalties
    ApplyPen->>ApplyPen: compute token frequencies (numpy bincount)
    ApplyPen->>ApplyPen: apply frequency & presence penalties to logits
    ApplyPen->>Sampler: modified logits
    Sampler->>LogitProc: apply temperature scaling
    LogitProc->>LogitProc: compute softmax probabilities
    Sampler->>Sampler: sample from distribution
```
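The final steps in the diagram — temperature scaling, softmax, and sampling — can be sketched in isolation. This is a hedged illustration in NumPy (the actual sampler operates on batched PyTorch tensors; the function and variable names here are ours, not the PR's):

```python
import numpy as np

def sample_logits(logits: np.ndarray, temperature: float,
                  rng: np.random.Generator) -> int:
    # Assumes presence/frequency penalties were already applied to `logits`,
    # matching the order shown in the diagram (penalties before temperature).
    scaled = logits / max(temperature, 1e-6)     # temperature scaling
    exp = np.exp(scaled - scaled.max())          # numerically stable softmax
    probs = exp / exp.sum()
    return int(rng.choice(len(probs), p=probs))  # sample from the distribution
```

With a very low temperature the distribution collapses onto the arg-max token, which is a quick way to sanity-check the scaling step.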
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cacheflow/sampling_params.py (1)
6-18: Confirmed breaking API change — `SamplingParams.__init__` requires `presence_penalty`, `frequency_penalty`, and `top_k` with no defaults.

Direct instantiation calls in `benchmark/trace.py` (lines 86–110, 9 occurrences) and `benchmark/benchmark_text_completion.py` (line 46) will fail at runtime because they do not provide these three required arguments. The dictionaries `random_sampling_params_dict` and `beam_search_params_dict` are also missing these keys.

Either add default values in `__init__` (e.g., `presence_penalty: float = 0.0`) or update all direct constructor calls to include these arguments. Alternatively, migrate these call sites to use `SamplingParams.from_dict()`, which already provides sensible defaults.
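A minimal sketch of the first option — neutral defaults so existing call sites keep working. This is a hypothetical fragment, not the PR's actual class; the real `SamplingParams` takes more parameters than shown:

```python
class SamplingParams:
    def __init__(self,
                 presence_penalty: float = 0.0,   # 0.0 disables the penalty
                 frequency_penalty: float = 0.0,
                 top_k: int = -1) -> None:        # -1 conventionally means "no top-k cutoff"
        if not -2.0 <= presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2].")
        if not -2.0 <= frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2].")
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.top_k = top_k
```

With defaults in place, `SamplingParams()` and the existing benchmark call sites construct successfully without changes.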
🤖 Fix all issues with AI agents
In `@cacheflow/core/scheduler.py`:
- Around line 250-254: The type annotation for seq_data is incorrect: it is
declared as Dict[int, List[SequenceData]] but is populated with single
SequenceData instances from seq_group.get_seqs(status=SequenceStatus.RUNNING);
change the annotation to Dict[int, SequenceData] (and update any other
occurrences like the similar annotation at the other location around the seq
processing, e.g., line ~260) to match SequenceGroupMetadata and the actual
values returned by seq.seq_id / seq.data.
In `@cacheflow/sequence.py`:
- Around line 38-42: The __repr__ in SequenceData refers to a non-existent
attribute self.prompt, which raises AttributeError; update the
SequenceData.__repr__ to stop referencing self.prompt and instead use existing
attributes (e.g., self.prompt_token_ids and self.output_token_ids) or a
derived/memoized prompt property if intended; modify the string construction in
__repr__ (the method named __repr__ on class SequenceData) to include only valid
attributes such as prompt_token_ids and output_token_ids (or call the proper
accessor if a prompt getter exists).
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- cacheflow/core/scheduler.py
- cacheflow/frontend/fastapi_frontend.py
- cacheflow/frontend/simple_frontend.py
- cacheflow/model_executor/input_metadata.py
- cacheflow/model_executor/layers/sampler.py
- cacheflow/sampling_params.py
- cacheflow/sequence.py
- cacheflow/worker/worker.py
- simple_server.py
🧰 Additional context used
🧬 Code graph analysis (7)
cacheflow/core/scheduler.py (2)
- cacheflow/sequence.py (7): Sequence (45-111), SequenceData (16-42), SequenceGroup (114-149), SequenceGroupMetadata (152-166), SequenceOutputs (169-193), SequenceStatus (9-13), get_seqs (126-133)
- cacheflow/core/block_manager.py (1): get_block_table (238-240)

cacheflow/worker/worker.py (3)
- cacheflow/sequence.py (8): SequenceData (16-42), SequenceGroupMetadata (152-166), get_token_ids (30-31), get_token_ids (97-98), get_last_token_id (33-36), get_last_token_id (100-101), get_len (27-28), get_len (94-95)
- cacheflow/block.py (2): get_token_ids (35-36), get_last_token_id (38-40)
- cacheflow/model_executor/input_metadata.py (1): InputMetadata (10-54)

cacheflow/frontend/simple_frontend.py (1)
- cacheflow/sequence.py (1): Sequence (45-111)

cacheflow/model_executor/input_metadata.py (2)
- cacheflow/sampling_params.py (1): SamplingParams (4-106)
- cacheflow/sequence.py (1): SequenceData (16-42)

cacheflow/model_executor/layers/sampler.py (1)
- cacheflow/model_executor/input_metadata.py (1): InputMetadata (10-54)

cacheflow/sequence.py (3)
- cacheflow/block.py (3): get_token_ids (35-36), get_last_token_id (38-40), LogicalTokenBlock (8-40)
- cacheflow/core/block_manager.py (1): fork (133-139)
- cacheflow/worker/cache_engine.py (1): copy (123-127)

cacheflow/frontend/fastapi_frontend.py (1)
- cacheflow/sequence.py (1): Sequence (45-111)
🪛 Ruff (0.14.13)
cacheflow/sampling_params.py
- 22-23: Avoid specifying long messages outside the exception class (TRY003)
- 25-26: Avoid specifying long messages outside the exception class (TRY003)
- 105-105: Avoid specifying long messages outside the exception class (TRY003)
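TRY003 flags long messages built at the `raise` site. One common remedy — shown here as a generic sketch, not the PR's actual code — is to move the message into a dedicated exception subclass:

```python
class InvalidPenaltyError(ValueError):
    """Raised when a penalty value falls outside the allowed range."""
    def __init__(self, name: str, value: float) -> None:
        # The message lives inside the exception class, satisfying TRY003.
        super().__init__(f"{name} must be in [-2.0, 2.0], got {value}.")

def check_penalty(name: str, value: float) -> None:
    if not -2.0 <= value <= 2.0:
        raise InvalidPenaltyError(name, value)  # short raise site, no inline message
```

The raise site then carries only the data, and the wording of the error lives in one place.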
🔇 Additional comments (19)
simple_server.py (1)
14-15: LGTM — test inputs now exercise the new penalties.

cacheflow/frontend/fastapi_frontend.py (1)
96-100: LGTM — prompt is now propagated into Sequence.

cacheflow/frontend/simple_frontend.py (1)
37-53: LGTM — prompt now flows into Sequence creation.

cacheflow/model_executor/input_metadata.py (1)
14-24: LGTM — InputMetadata now carries `seq_data` consistently.

cacheflow/worker/worker.py (4)
1-12: LGTM! The import changes correctly add `Optional` for type hints and update the sequence imports to include `SequenceData` alongside existing imports, aligning with the new data model.

108-131: LGTM! The prompt handling correctly derives sequence IDs from `seq_data.keys()` and uses `SequenceData.get_token_ids()` to retrieve prompt tokens. The logic flow is preserved while adopting the new data model.

142-166: LGTM! The generation token handling correctly uses `SequenceData` accessors: `get_last_token_id()` for retrieving the token to process and `get_len()` for context length calculations. The position calculation remains correct.

188-200: LGTM! The seq_data aggregation correctly merges per-group sequence data into a combined dictionary for `InputMetadata`. This aligns with the updated `InputMetadata` signature that expects `seq_data: Dict[int, SequenceData]`.

cacheflow/model_executor/layers/sampler.py (6)
35-43: LGTM! The penalty application is correctly positioned before temperature scaling, following the standard order for LLM sampling. The assertions properly validate that the penalty arrays match the batch dimension.

84-102: LGTM! The `_get_penalties` function correctly follows the same pattern as `_get_temperatures`, properly distinguishing between prompt inputs (single penalty per group) and generation tokens (penalty per sequence).

105-123: LGTM! The function correctly gathers output tokens per sequence. The comment on lines 113-114 properly documents the recomputation edge case where prompts may have output tokens.

126-168: LGTM! The `_apply_penalties` function correctly implements the OpenAI API penalty formula. The early return optimization for zero penalties is efficient, and the use of `np.bincount` for frequency computation is appropriate.
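The OpenAI-style formula referenced here subtracts `frequency_penalty * count(token)` plus a flat `presence_penalty` for every token that has appeared at least once. A self-contained NumPy sketch of the idea (names and shapes are illustrative, not the PR's exact implementation, which works on batched tensors):

```python
import numpy as np

def apply_penalties(logits: np.ndarray, output_tokens: list,
                    presence_penalty: float, frequency_penalty: float) -> np.ndarray:
    # Early return when both penalties are zero — nothing to do.
    if presence_penalty == 0.0 and frequency_penalty == 0.0:
        return logits
    vocab_size = logits.shape[-1]
    # Occurrence count of each vocab token among the generated tokens.
    counts = np.bincount(output_tokens, minlength=vocab_size)
    logits = logits - frequency_penalty * counts        # grows with repetition count
    logits = logits - presence_penalty * (counts > 0)   # flat, once per seen token
    return logits
```

`np.bincount` with `minlength=vocab_size` gives the per-token frequencies in one vectorized pass, which is why the reviewer calls it out as the appropriate choice here.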
218-224: LGTM! The signature change to accept lists and perform internal tensor conversion is consistent with other helper functions in this module, improving API consistency.

385-387: LGTM! The cumulative_logprobs access correctly uses the new `SequenceData`-based path, aligning with the refactored data model.

cacheflow/sequence.py (5)
45-64: LGTM! The `Sequence` constructor correctly initializes the new `SequenceData` with prompt token IDs and stores the prompt string separately. The initialization flow is clear and consistent.

73-85: LGTM! The renamed `_append_tokens_to_blocks` method correctly indicates its internal nature with the underscore prefix, and the block management logic is sound.

87-92: LGTM! The `append_token` method correctly maintains consistency by updating both the logical token blocks and the `SequenceData` output token list, along with cumulative log probabilities.

94-106: LGTM! The delegation methods cleanly forward to `SequenceData`, and `fork` correctly uses deep copy to ensure child sequences have independent data, which is essential for beam search scenarios.

152-166: LGTM! The `SequenceGroupMetadata` class correctly adopts the `seq_data: Dict[int, SequenceData]` pattern, aligning with the refactored data model used throughout the codebase.
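The deep-copy behaviour called out for `fork` can be illustrated with a stripped-down stand-in (hypothetical minimal classes, not the real `cacheflow` ones):

```python
import copy

class SequenceData:
    """Minimal stand-in for cacheflow's SequenceData."""
    def __init__(self, prompt_token_ids):
        self.prompt_token_ids = list(prompt_token_ids)
        self.output_token_ids = []

class Sequence:
    def __init__(self, seq_id, prompt_token_ids):
        self.seq_id = seq_id
        self.data = SequenceData(prompt_token_ids)

    def fork(self, child: "Sequence") -> None:
        # Deep copy so the child's token lists are independent: appending to
        # the parent afterwards must not leak into the forked child.
        child.data = copy.deepcopy(self.data)
```

A shallow copy would share the `output_token_ids` list between parent and child, silently corrupting beam-search candidates as they diverge.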
```python
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
    seq_id = seq.seq_id
    seq_data[seq_id] = seq.data
```
Fix seq_data type annotation to match actual values.
seq_data is populated with SequenceData instances, not lists. This should be Dict[int, SequenceData] to match SequenceGroupMetadata.
🔧 Suggested fix

```diff
-    seq_data: Dict[int, List[SequenceData]] = {}
+    seq_data: Dict[int, SequenceData] = {}
```

Also applies to: 260-260
```python
def __repr__(self) -> str:
    return (f"SequenceData("
            f"prompt={self.prompt}, "
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"output_token_ids={self.output_token_ids})")
```
__repr__ references non-existent self.prompt attribute.
The __repr__ method references self.prompt on line 40, but SequenceData only has prompt_token_ids. This will raise an AttributeError when the representation is accessed (e.g., during debugging or logging).
🐛 Proposed fix

```diff
 def __repr__(self) -> str:
     return (f"SequenceData("
-            f"prompt={self.prompt}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
             f"output_token_ids={self.output_token_ids})")
```
Summary by CodeRabbit

New Features
- Presence and frequency penalties can now be configured and are applied to logits during sampling.

Improvements
- Per-sequence state (prompt and output token IDs) is centralized in a new `SequenceData` class used across the scheduler, worker, and sampler.