Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions cacheflow/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from typing import Dict, List, Optional, Tuple

from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.logger import init_logger
from cacheflow.core.policy import PolicyFactory
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceOutputs, SequenceStatus)
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)

logger = init_logger(__name__)

Expand Down Expand Up @@ -246,27 +247,17 @@ def step(self) -> List[SequenceGroup]:
group_id = seq_group.group_id
is_prompt = group_id in prompt_group_ids

input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
Comment on lines +250 to +254
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix seq_data type annotation to match actual values.

seq_data is populated with SequenceData instances, not lists. This should be Dict[int, SequenceData] to match SequenceGroupMetadata.

🔧 Suggested fix
-            seq_data: Dict[int, List[SequenceData]] = {}
+            seq_data: Dict[int, SequenceData] = {}

Also applies to: 260-260

🤖 Prompt for AI Agents
In `@cacheflow/core/scheduler.py` around lines 250 - 254, The type annotation for
seq_data is incorrect: it is declared as Dict[int, List[SequenceData]] but is
populated with single SequenceData instances from
seq_group.get_seqs(status=SequenceStatus.RUNNING); change the annotation to
Dict[int, SequenceData] (and update any other occurrences like the similar
annotation at the other location around the seq processing, e.g., line ~260) to
match SequenceGroupMetadata and the actual values returned by seq.seq_id /
seq.data.

block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
input_tokens[seq_id] = seq.get_token_ids()
else:
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()

seq_group_metadata = SequenceGroupMetadata(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
seq_data=seq_data,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
Expand Down
2 changes: 1 addition & 1 deletion cacheflow/frontend/fastapi_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def generate(self, request_dict: Dict):
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
seqs.append(seq)

arrival_time = time.time()
Expand Down
5 changes: 3 additions & 2 deletions cacheflow/frontend/simple_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def query(
sampling_params: SamplingParams,
) -> None:
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)
self._add_query(prompt, token_ids, sampling_params)

def _add_query(
self,
prompt: str,
token_ids: List[int],
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
Expand All @@ -48,7 +49,7 @@ def _add_query(
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
seqs.append(seq)

group_id = next(self.seq_group_counter)
Expand Down
10 changes: 6 additions & 4 deletions cacheflow/model_executor/input_metadata.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
from typing import List, Dict, Tuple
from typing import Dict, List, Tuple

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceData


class InputMetadata:

def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
Expand All @@ -39,6 +40,7 @@ def __init__(
assert context_lens.shape[0] == self.num_generation_tokens

def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
f'num_valid_tokens={self.num_valid_tokens}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
Expand Down
114 changes: 106 additions & 8 deletions cacheflow/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn

Expand Down Expand Up @@ -31,6 +32,16 @@ def forward(
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]

# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(
logits, output_tokens, presence_penalties, frequency_penalties,
self.vocab_size)

# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
Expand All @@ -43,16 +54,14 @@ def forward(
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
# Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs)

# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
probs = _apply_top_p_top_k(probs, p, k)
probs = _apply_top_p_top_k(probs, top_ps, top_ks)

# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
Expand All @@ -72,6 +81,93 @@ def _prune_hidden_states(
return hidden_states[last_token_indicies]


def _get_penalties(
input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
if i < input_metadata.num_prompts:
# A prompt input.
presence_penalties.append(p)
frequency_penalties.append(f)
else:
# A generation token.
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties


def _get_output_tokens(
input_metadata: InputMetadata,
) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
# NOTE: While the prompt input usually has no output tokens,
# it may have output tokens in the case of recomputation.
seq_id = seq_ids[0]
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
else:
# A generation token.
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
return output_tokens


def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
presence_penalties: List[float],
frequency_penalties: List[float],
vocab_size: int,
) -> torch.Tensor:
num_seqs = logits.shape[0]
# Collect the indices of sequences that have non-zero penalties.
indices = []
for i in range(num_seqs):
if not output_tokens[i]:
continue
p = presence_penalties[i]
f = frequency_penalties[i]
if p == 0.0 and f == 0.0:
continue
indices.append(i)

# Return early if all sequences have zero penalties.
if not indices:
return logits

bin_counts = []
for i in indices:
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)

frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(
frequency_penalties, dtype=logits.dtype, device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(
presence_penalties, dtype=logits.dtype, device=logits.device)

# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
return logits


def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
Expand Down Expand Up @@ -121,10 +217,11 @@ def _get_top_p_top_k(

def _apply_top_p_top_k(
probs: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
top_ps: List[float],
top_ks: List[int],
) -> torch.Tensor:
# TODO(woosuk): Optimize.
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

# Apply top-p.
Expand Down Expand Up @@ -286,7 +383,8 @@ def _sample(

# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
input_metadata.seq_data[seq_id].cumulative_logprobs
for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)

Expand Down
35 changes: 26 additions & 9 deletions cacheflow/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class SamplingParams:
def __init__(
self,
n: int,
presence_penalty: float,
frequency_penalty: float,
temperature: float,
top_p: float,
top_k: int,
Expand All @@ -16,6 +18,12 @@ def __init__(
) -> None:
if n < 1:
raise ValueError(f"n must be at least 1, got {n}.")
if not -2.0 <= presence_penalty <= 2.0:
raise ValueError(
f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
if not -2.0 <= frequency_penalty <= 2.0:
raise ValueError(
f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
if temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {temperature}.")
Expand Down Expand Up @@ -57,6 +65,8 @@ def __init__(
"top_k must be -1 when using greedy sampling.")

self.n = n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
Expand All @@ -67,6 +77,8 @@ def __init__(

def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k},"
Expand All @@ -77,13 +89,18 @@ def __repr__(self) -> str:

@classmethod
def from_dict(cls, d: Dict) -> "SamplingParams":
    """Build a SamplingParams from a dict of options.

    Missing keys take their documented defaults; any key left over after
    the known ones are consumed raises ValueError. Operates on a shallow
    copy so the caller's dict is NOT mutated (the original implementation
    popped keys out of the caller's dict as a side effect).

    Raises:
        ValueError: if `d` contains unrecognized keys (and whatever the
            constructor raises for out-of-range values).
    """
    # Copy first: the pops below would otherwise destroy the caller's dict.
    d = dict(d)
    sampling_params = cls(
        n=d.pop("n", 1),
        presence_penalty=d.pop("presence_penalty", 0.0),
        frequency_penalty=d.pop("frequency_penalty", 0.0),
        temperature=d.pop("temperature", 1.0),
        top_p=d.pop("top_p", 1.0),
        top_k=d.pop("top_k", -1),
        use_beam_search=d.pop("use_beam_search", False),
        stop_token_ids=set(d.pop("stop_token_ids", set())),
        max_num_steps=d.pop("max_num_steps", 16),
        num_logprobs=d.pop("num_logprobs", 0),
    )
    # Anything still left in the copy was not a recognized option.
    if d:
        raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
    return sampling_params
Loading