

Implement presence and frequency penalties (#95) #27

Open
MitchLewis930 wants to merge 1 commit into float_comparison_before from float_comparison_after

Conversation

MitchLewis930 (Collaborator) commented Jan 24, 2026

Summary by CodeRabbit

  • New Features

    • Added support for presence and frequency penalties in text generation, enabling fine-tuned control over output diversity and token repetition through new sampling parameters.
  • Improvements

    • Extended sampling configuration to support presence_penalty and frequency_penalty settings (values between -2 and 2).
    • Improved internal sequence data management for better performance and code organization.


coderabbitai bot commented Jan 24, 2026

📝 Walkthrough

This pull request refactors the data model to centralize per-sequence information into a new SequenceData class, replacing scattered fields across multiple components. Additionally, it introduces penalty-based decoding via presence and frequency penalties in the sampler, and adds corresponding parameters to SamplingParams.

Changes

Data Model Refactoring - Core Classes
cacheflow/sequence.py
Introduced new SequenceData class encapsulating prompt_token_ids, output_token_ids, and cumulative_logprobs. Updated Sequence to accept prompt: str parameter, store self.data, and delegate token-related queries to SequenceData. Modified SequenceGroupMetadata to use unified seq_data: Dict[int, SequenceData] instead of separate input_tokens, context_len, and seq_logprobs fields.
Data Model Propagation - Metadata
cacheflow/model_executor/input_metadata.py
Replaced seq_logprobs: Dict[int, float] parameter with seq_data: Dict[int, SequenceData] in InputMetadata constructor. Updated corresponding attribute assignment and imports.
Data Model Propagation - Scheduler & Worker
cacheflow/core/scheduler.py, cacheflow/worker/worker.py
Updated scheduler to construct SequenceGroupMetadata with new seq_data parameter. Worker now aggregates seq_data across sequence groups and passes it to InputMetadata instead of seq_logprobs. Both modules now access sequence information via SequenceData accessors (get_token_ids(), get_last_token_id(), get_len()).
Data Model Propagation - Frontends
cacheflow/frontend/fastapi_frontend.py, cacheflow/frontend/simple_frontend.py
Updated Sequence constructor calls to pass prompt as an argument. Simple frontend also updated _add_query method signature to accept prompt: str parameter before token_ids.
Penalty-Based Decoding
cacheflow/model_executor/layers/sampler.py
Added new helper functions: _get_penalties() to extract presence/frequency penalties per sequence, _get_output_tokens() to gather output token IDs, and _apply_penalties() to compute and apply frequency/presence penalties to logits. Updated _apply_top_p_top_k() signature to accept lists instead of tensors. Modified forward pass to apply penalties before softmax and use seq_data.cumulative_logprobs for generation tokens.
Sampling Parameters Enhancement
cacheflow/sampling_params.py
Added presence_penalty and frequency_penalty parameters (range [-2, 2]) to SamplingParams. Updated __init__, __repr__, and from_dict() to handle new parameters with validation.
Test Case Updates
simple_server.py
Added presence_penalty and frequency_penalty to test case dictionaries for validation.
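Based on the summaries above, a minimal sketch of what the new SequenceData container could look like. The accessor names (get_token_ids(), get_last_token_id(), get_len()) and field names come from the walkthrough; the method bodies and the treatment of cumulative_logprobs are assumptions, not the PR's actual code:

```python
from typing import List


class SequenceData:
    """Sketch of a per-sequence token container, per the walkthrough."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        # Type assumed: a running sum of log-probabilities.
        self.cumulative_logprobs = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprobs += logprob

    def get_len(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]
```

Centralizing this state is what lets the scheduler, worker, and sampler all consume the same Dict[int, SequenceData] mapping instead of separate token/length/logprob fields.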

Sequence Diagram(s)

sequenceDiagram
    participant Sampler as Sampler.forward()
    participant InputMeta as InputMetadata
    participant GetPen as _get_penalties()
    participant GetOut as _get_output_tokens()
    participant ApplyPen as _apply_penalties()
    participant LogitProc as Logit Processing

    Sampler->>InputMeta: access seq_data
    Sampler->>GetPen: extract penalties per seq
    GetPen->>Sampler: List[presence_penalties], List[frequency_penalties]
    
    Sampler->>GetOut: gather output_token_ids per seq
    GetOut->>Sampler: List[List[int]]
    
    Sampler->>ApplyPen: pass logits + tokens + penalties
    ApplyPen->>ApplyPen: compute token frequencies (numpy bincount)
    ApplyPen->>ApplyPen: apply frequency & presence penalties to logits
    ApplyPen->>Sampler: modified logits
    
    Sampler->>LogitProc: apply temperature scaling
    LogitProc->>LogitProc: compute softmax probabilities
    Sampler->>Sampler: sample from distribution

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Whiskers twitching with delight,
Data flows through, bundled tight,
SequenceData brings order and grace,
Penalties dance at sampling's pace,
Refactored code, organized right!

🚥 Pre-merge checks (2 passed, 1 failed)

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 2.78%, below the required 80.00% threshold. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped because CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The PR title 'Implement presence and frequency penalties' directly and clearly describes the main objective of the changeset, which introduces presence_penalty and frequency_penalty parameters to SamplingParams and implements penalty-based decoding logic. |



coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cacheflow/sampling_params.py (1)

6-18: Confirmed breaking API change — SamplingParams.__init__ requires presence_penalty, frequency_penalty, and top_k with no defaults.

Direct instantiation calls in benchmark/trace.py (lines 86–110, 9 occurrences) and benchmark/benchmark_text_completion.py (line 46) will fail at runtime because they do not provide these three required arguments. The dictionaries random_sampling_params_dict and beam_search_params_dict are also missing these keys.

Either add default values to __init__ (e.g., presence_penalty: float = 0.0) or update all direct constructor calls to include these arguments. Alternatively, migrate these call sites to use SamplingParams.from_dict(), which already provides sensible defaults.
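One way to reconcile the two suggestions (neutral defaults in __init__, plus from_dict defaults) might look like the sketch below. The [-2, 2] validation range comes from the PR summary; the default values, error wording, and the shape of from_dict are assumptions for illustration only:

```python
from typing import Any, Dict


class SamplingParams:
    """Sketch covering only the penalty-related parts of SamplingParams."""

    def __init__(self,
                 presence_penalty: float = 0.0,
                 frequency_penalty: float = 0.0) -> None:
        # Neutral (no-op) defaults keep older call sites working; the PR's
        # actual __init__ has no defaults, per the review comment above.
        for name, value in (("presence_penalty", presence_penalty),
                            ("frequency_penalty", frequency_penalty)):
            if not -2.0 <= value <= 2.0:
                raise ValueError(f"{name} must be in [-2, 2], got {value}.")
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "SamplingParams":
        # Missing keys fall back to no-op penalties.
        return cls(
            presence_penalty=d.get("presence_penalty", 0.0),
            frequency_penalty=d.get("frequency_penalty", 0.0),
        )
```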

🤖 Fix all issues with AI agents
In `@cacheflow/core/scheduler.py`:
- Around line 250-254: The type annotation for seq_data is incorrect: it is
declared as Dict[int, List[SequenceData]] but is populated with single
SequenceData instances from seq_group.get_seqs(status=SequenceStatus.RUNNING);
change the annotation to Dict[int, SequenceData] (and update any other
occurrences like the similar annotation at the other location around the seq
processing, e.g., line ~260) to match SequenceGroupMetadata and the actual
values returned by seq.seq_id / seq.data.

In `@cacheflow/sequence.py`:
- Around line 38-42: The __repr__ in SequenceData refers to a non-existent
attribute self.prompt, which raises AttributeError; update the
SequenceData.__repr__ to stop referencing self.prompt and instead use existing
attributes (e.g., self.prompt_token_ids and self.output_token_ids) or a
derived/memoized prompt property if intended; modify the string construction in
__repr__ (the method named __repr__ on class SequenceData) to include only valid
attributes such as prompt_token_ids and output_token_ids (or call the proper
accessor if a prompt getter exists).
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f88db3 and 55f8b0a.

📒 Files selected for processing (9)
  • cacheflow/core/scheduler.py
  • cacheflow/frontend/fastapi_frontend.py
  • cacheflow/frontend/simple_frontend.py
  • cacheflow/model_executor/input_metadata.py
  • cacheflow/model_executor/layers/sampler.py
  • cacheflow/sampling_params.py
  • cacheflow/sequence.py
  • cacheflow/worker/worker.py
  • simple_server.py
🧰 Additional context used
🧬 Code graph analysis (7)
cacheflow/core/scheduler.py (2)
cacheflow/sequence.py (7)
  • Sequence (45-111)
  • SequenceData (16-42)
  • SequenceGroup (114-149)
  • SequenceGroupMetadata (152-166)
  • SequenceOutputs (169-193)
  • SequenceStatus (9-13)
  • get_seqs (126-133)
cacheflow/core/block_manager.py (1)
  • get_block_table (238-240)
cacheflow/worker/worker.py (3)
cacheflow/sequence.py (8)
  • SequenceData (16-42)
  • SequenceGroupMetadata (152-166)
  • get_token_ids (30-31)
  • get_token_ids (97-98)
  • get_last_token_id (33-36)
  • get_last_token_id (100-101)
  • get_len (27-28)
  • get_len (94-95)
cacheflow/block.py (2)
  • get_token_ids (35-36)
  • get_last_token_id (38-40)
cacheflow/model_executor/input_metadata.py (1)
  • InputMetadata (10-54)
cacheflow/frontend/simple_frontend.py (1)
cacheflow/sequence.py (1)
  • Sequence (45-111)
cacheflow/model_executor/input_metadata.py (2)
cacheflow/sampling_params.py (1)
  • SamplingParams (4-106)
cacheflow/sequence.py (1)
  • SequenceData (16-42)
cacheflow/model_executor/layers/sampler.py (1)
cacheflow/model_executor/input_metadata.py (1)
  • InputMetadata (10-54)
cacheflow/sequence.py (3)
cacheflow/block.py (3)
  • get_token_ids (35-36)
  • get_last_token_id (38-40)
  • LogicalTokenBlock (8-40)
cacheflow/core/block_manager.py (1)
  • fork (133-139)
cacheflow/worker/cache_engine.py (1)
  • copy (123-127)
cacheflow/frontend/fastapi_frontend.py (1)
cacheflow/sequence.py (1)
  • Sequence (45-111)
🪛 Ruff (0.14.13)
cacheflow/sampling_params.py

22-23: Avoid specifying long messages outside the exception class (TRY003)

25-26: Avoid specifying long messages outside the exception class (TRY003)

105-105: Avoid specifying long messages outside the exception class (TRY003)

🔇 Additional comments (19)
simple_server.py (1)

14-15: LGTM — test inputs now exercise the new penalties.

cacheflow/frontend/fastapi_frontend.py (1)

96-100: LGTM — prompt is now propagated into Sequence.

cacheflow/frontend/simple_frontend.py (1)

37-53: LGTM — prompt now flows into Sequence creation.

cacheflow/model_executor/input_metadata.py (1)

14-24: LGTM — InputMetadata now carries seq_data consistently.

cacheflow/worker/worker.py (4)

1-12: LGTM!

The import changes correctly add Optional for type hints and update the sequence imports to include SequenceData alongside existing imports, aligning with the new data model.


108-131: LGTM!

The prompt handling correctly derives sequence IDs from seq_data.keys() and uses SequenceData.get_token_ids() to retrieve prompt tokens. The logic flow is preserved while adopting the new data model.


142-166: LGTM!

The generation token handling correctly uses SequenceData accessors: get_last_token_id() for retrieving the token to process and get_len() for context length calculations. The position calculation remains correct.


188-200: LGTM!

The seq_data aggregation correctly merges per-group sequence data into a combined dictionary for InputMetadata. This aligns with the updated InputMetadata signature that expects seq_data: Dict[int, SequenceData].

cacheflow/model_executor/layers/sampler.py (6)

35-43: LGTM!

The penalty application is correctly positioned before temperature scaling, following the standard order for LLM sampling. The assertions properly validate that the penalty arrays match the batch dimension.
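The ordering this comment describes (penalties subtracted first, then temperature scaling, then softmax) can be sketched as follows. The function name and shapes are illustrative, not the PR's actual code; penalties are assumed to have already been applied to the logits:

```python
import numpy as np


def scale_and_softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Temperature scaling followed by a numerically stable softmax.

    `logits` has shape [num_seqs, vocab_size] and is assumed to already
    include any presence/frequency penalty adjustments.
    """
    scaled = logits / temperature
    # Subtract the row max before exponentiating to avoid overflow.
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    probs = np.exp(scaled)
    return probs / probs.sum(axis=-1, keepdims=True)
```

Lower temperatures sharpen the distribution, which is why scaling must happen after penalties but before the softmax normalization.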


84-102: LGTM!

The _get_penalties function correctly follows the same pattern as _get_temperatures, properly distinguishing between prompt inputs (single penalty per group) and generation tokens (penalty per sequence).


105-123: LGTM!

The function correctly gathers output tokens per sequence. The comment on lines 113-114 properly documents the recomputation edge case where prompts may have output tokens.


126-168: LGTM!

The _apply_penalties function correctly implements the OpenAI API penalty formula. The early return optimization for zero penalties is efficient, and the use of np.bincount for frequency computation is appropriate.


218-224: LGTM!

The signature change to accept lists and perform internal tensor conversion is consistent with other helper functions in this module, improving API consistency.


385-387: LGTM!

The cumulative_logprobs access correctly uses the new SequenceData-based path, aligning with the refactored data model.

cacheflow/sequence.py (5)

45-64: LGTM!

The Sequence constructor correctly initializes the new SequenceData with prompt token IDs and stores the prompt string separately. The initialization flow is clear and consistent.


73-85: LGTM!

The renamed _append_tokens_to_blocks method correctly indicates its internal nature with the underscore prefix, and the block management logic is sound.


87-92: LGTM!

The append_token method correctly maintains consistency by updating both the logical token blocks and the SequenceData output token list, along with cumulative log probabilities.


94-106: LGTM!

The delegation methods cleanly forward to SequenceData, and fork correctly uses deep copy to ensure child sequences have independent data, which is essential for beam search scenarios.
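The deep-copy fork described here can be sketched as below; the class shapes are simplified and only the data-copy behavior is shown:

```python
import copy
from typing import List


class SequenceData:
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []


class Sequence:
    def __init__(self, seq_id: int, prompt_token_ids: List[int]) -> None:
        self.seq_id = seq_id
        self.data = SequenceData(prompt_token_ids)

    def fork(self, child_seq: "Sequence") -> None:
        # Deep copy so the child's token lists evolve independently of the
        # parent's, which matters when beam search branches a sequence.
        child_seq.data = copy.deepcopy(self.data)
```

A shallow copy would alias the token lists, so appending to one beam would silently corrupt its siblings.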


152-166: LGTM!

The SequenceGroupMetadata class correctly adopts the seq_data: Dict[int, SequenceData] pattern, aligning with the refactored data model used throughout the codebase.


Comment on lines +250 to +254
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
    seq_id = seq.seq_id
    seq_data[seq_id] = seq.data
⚠️ Potential issue | 🟡 Minor

Fix seq_data type annotation to match actual values.

seq_data is populated with SequenceData instances, not lists. This should be Dict[int, SequenceData] to match SequenceGroupMetadata.

🔧 Suggested fix
-            seq_data: Dict[int, List[SequenceData]] = {}
+            seq_data: Dict[int, SequenceData] = {}

Also applies to: 260-260


Comment on lines +38 to +42
def __repr__(self) -> str:
    return (f"SequenceData("
            f"prompt={self.prompt}, "
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"output_token_ids={self.output_token_ids})")
⚠️ Potential issue | 🔴 Critical

__repr__ references non-existent self.prompt attribute.

The __repr__ method references self.prompt on line 40, but SequenceData only has prompt_token_ids. This will raise an AttributeError when the representation is accessed (e.g., during debugging or logging).

🐛 Proposed fix
     def __repr__(self) -> str:
         return (f"SequenceData("
-                f"prompt={self.prompt}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"output_token_ids={self.output_token_ids})")
