18 changes: 18 additions & 0 deletions src/parallax/p2p/message_util.py
@@ -50,6 +50,14 @@ def request_to_proto(
if request.next_token_id is not None:
proto_req.next_token_id = request.next_token_id

# Add token_prob if available
if hasattr(request, "token_prob") and request.token_prob is not None:
proto_req.token_prob = request.token_prob

# Add return_probs flag
if hasattr(request, "return_probs"):
proto_req.return_probs = request.return_probs

forward_request.reqs.append(proto_req)

return forward_request
@@ -86,6 +94,14 @@ def proto_to_request(

sampling_params = proto_to_sampling_params(proto_req.sampling_params)

# Extract token_prob if present
token_prob = None
if proto_req.HasField("token_prob"):
token_prob = proto_req.token_prob

# Extract return_probs (defaults to False if not present)
return_probs = proto_req.return_probs if hasattr(proto_req, "return_probs") else False

request = IntermediateRequest(
request_id=proto_req.rid,
current_position=current_position,
@@ -96,6 +112,8 @@
next_token_id=next_token_id,
sampling_params=sampling_params,
lora_path=proto_req.lora_path if proto_req.lora_path != "" else None,
token_prob=token_prob,
return_probs=return_probs,
)

requests.append(request)
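For readers of this diff: the conversion added above only copies `token_prob` onto the wire when it is actually set, and always carries `return_probs`. A minimal, self-contained sketch of that presence-checked round trip, using a dataclass and a plain dict as a stand-in for the generated `forward_pb2.Req` message (names below are illustrative, not the real API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:
    rid: str
    token_prob: Optional[float] = None  # probability of the sampled token, if any
    return_probs: bool = False          # whether the caller asked for probabilities


def to_wire(req: FakeRequest) -> dict:
    """Mirrors request_to_proto: token_prob is serialized only when it is set."""
    msg = {"rid": req.rid, "return_probs": req.return_probs}
    if req.token_prob is not None:
        msg["token_prob"] = req.token_prob
    return msg


def from_wire(msg: dict) -> FakeRequest:
    """Mirrors proto_to_request: an absent token_prob stays None; return_probs defaults to False."""
    return FakeRequest(
        rid=msg["rid"],
        token_prob=msg.get("token_prob"),
        return_probs=msg.get("return_probs", False),
    )


restored = from_wire(to_wire(FakeRequest(rid="r1", token_prob=0.42, return_probs=True)))
assert restored.token_prob == 0.42 and restored.return_probs is True
```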
2 changes: 2 additions & 0 deletions src/parallax/p2p/proto/forward.proto
@@ -36,6 +36,8 @@ message Req {
int32 next_token_id = 6;
bytes hidden_states = 7;
string lora_path = 8;
optional float token_prob = 9; // Probability value for the sampled token
bool return_probs = 10; // Whether to return probabilities
}

message SamplingParams {
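The `optional` keyword on `token_prob` matters here: proto3 tracks explicit presence only for `optional` scalars, which is what lets `proto_to_request` call `HasField("token_prob")` and distinguish "never set" from a legitimate 0.0, while the plain `bool return_probs` just reads back its default. A small sketch, assuming the regenerated module imports as `parallax.p2p.proto.forward_pb2` and that `Req` exposes the `rid` field used elsewhere in this diff:

```python
# Assumes the regenerated forward_pb2 module is importable; field names follow the diff above.
from parallax.p2p.proto import forward_pb2

req = forward_pb2.Req(rid="r1")
assert not req.HasField("token_prob")  # never set: presence is False, value reads as 0.0

req.token_prob = 0.0                   # explicitly set to zero
assert req.HasField("token_prob")      # presence is recorded, so 0.0 survives the round trip

# Plain scalar fields such as `bool return_probs = 10;` have no presence tracking;
# they simply read back their default (False) when never set.
assert req.return_probs is False
```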
12 changes: 6 additions & 6 deletions src/parallax/p2p/proto/forward_pb2.py

Some generated files are not rendered by default.

41 changes: 36 additions & 5 deletions src/parallax/server/executor/base_executor.py
@@ -358,10 +358,25 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict
}

def prepare_next_batch_requests(
self, requests: List[Request], hidden_states: Any, context_lengths: Any
self, requests: List[Request], batch_output: Any, context_lengths: Any
) -> List[Request]:
"""Prepares a batch of requests for the next stage of the pipeline."""
"""Prepares a batch of requests for the next stage of the pipeline.

Args:
requests: List of requests in the batch
batch_output: Output from process_batch. Always a dict with:
- 'hidden_states': token IDs (last peer) or hidden states tensor (intermediate peer)
- 'probs': list of sampled-token probabilities (last peer, when any request sets return_probs) or None otherwise
context_lengths: Context lengths for each request
"""
if self.tp_rank == 0:
# Extract hidden_states and probs from output (always a dict now)
assert isinstance(
batch_output, dict
), f"Expected dict from process_batch, got {type(batch_output)}"
hidden_states = batch_output["hidden_states"]
token_probs = batch_output["probs"]

batched_requests = []
pre_length = 0
for i, src_request in enumerate(requests):
@@ -386,7 +401,16 @@
hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
pre_length += 1

next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
# Get prob for this request if available
token_prob = (
token_probs[i]
if (self.is_last_peer and token_probs and i < len(token_probs))
else None
)

next_req = self._prepare_next_single_request(
src_request, hidden_state_for_req, token_prob
)
batched_requests.append(next_req)
else:
batched_requests = None
@@ -502,7 +526,7 @@ def run_loop(self):
# 7. Prepare requests for the next stage in the pipeline
next_batch = self.prepare_next_batch_requests(
requests=prepared_inputs["requests"],
hidden_states=output,
batch_output=output,
context_lengths=prepared_inputs.get("context_lengths"),
)

@@ -612,6 +636,7 @@ def _handle_raw_request(self, raw_request: Dict):
logger.debug(f"Final input token length for request ID {rid}: {input_token_num}")

lora_path = raw_request.get("lora_path")
return_probs = raw_request.get("return_probs", False) # Get return_probs parameter

raw_sampling_params = raw_request.get("sampling_params")
if raw_sampling_params is None:
@@ -636,6 +661,7 @@
max_new_tokens=max_new_tokens,
max_total_length=max_total_length,
lora_path=lora_path,
return_probs=return_probs,
)
if "routing_table" in raw_request:
req.routing_table = raw_request["routing_table"]
@@ -669,7 +695,9 @@ def _notify_http_request_error(self, raw_request: Optional[Dict], error: Excepti
except Exception: # pragma: no cover - best effort notification
logger.debug("Failed to send error notification to HTTP handler", exc_info=True)

def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request:
def _prepare_next_single_request(
self, request: Request, hidden_states: Any, token_prob: Optional[float] = None
) -> Request:
"""Handle request state changes both inter and intra peers.

This function prepares the request object to be sent to the *next* peer in the
@@ -678,6 +706,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
Args:
request: The request that was just processed by this peer.
hidden_states: The output hidden_states/output_ids from the model for this request.
token_prob: The probability value for the sampled token (optional).

Returns:
A new Request object ready to be sent to the next destination.
@@ -698,6 +727,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
next_token_id=next_token_id,
routing_table=request.routing_table,
lora_path=request.lora_path,
token_prob=token_prob,
)
if self.is_last_peer:
# Last peer decodes a token and sends it back to the first peer.
@@ -716,6 +746,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
next_token_id=next_token_id,
routing_table=request.routing_table,
lora_path=request.lora_path,
token_prob=token_prob,
)
# This peer is the first or an intermediate peer.
if self.is_first_peer:
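The change above turns the `process_batch` output into a uniform dict contract that `prepare_next_batch_requests` consumes by batch index, attaching a probability only on the last peer. A self-contained sketch of that lookup logic (not the actual executor code):

```python
from typing import Any, Dict, Optional


def pick_token_prob(batch_output: Dict[str, Any], i: int, is_last_peer: bool) -> Optional[float]:
    """Return the i-th sampled-token probability on the last peer, None everywhere else."""
    token_probs = batch_output["probs"]
    if is_last_peer and token_probs and i < len(token_probs):
        return token_probs[i]
    return None


# Last peer: 'hidden_states' holds sampled token IDs, 'probs' holds their probabilities.
last_peer_output = {"hidden_states": [17, 4022], "probs": [0.91, 0.33]}
assert pick_token_prob(last_peer_output, 1, is_last_peer=True) == 0.33

# Intermediate peer: 'hidden_states' holds activations, 'probs' is always None.
mid_peer_output = {"hidden_states": object(), "probs": None}
assert pick_token_prob(mid_peer_output, 0, is_last_peer=False) is None
```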
46 changes: 44 additions & 2 deletions src/parallax/server/executor/mlx_executor.py
@@ -267,6 +267,11 @@ def handle_input_requests(self, requests: List[Request]):
req_dict["eos"] = True
if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
req_dict["length"] = True

# Add prob value for the sampled token (if requested and available)
if original_req.return_probs and req.token_prob is not None:
req_dict["probs"] = req.token_prob

if self.enable_weight_refit:
req_dict["weight_version"] = self.weight_version
if hasattr(self, "send_to_ipc_socket"):
@@ -350,11 +355,48 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
# Process last peer: need additional sampling + detokenization
if return_decoded_tokens:
sampling_info = SamplingBatchInfo.from_reqs(requests)
return mx.array(

# For MLX, hidden_states at last shard is already logits (after lm_head)
# hidden_states shape: [batch_size, seq_len, vocab_size]
token_ids = mx.array(
self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info)
)

return hidden_states
needs_probs = any(
(isinstance(req, InitialRequest) and req.return_probs)
or (isinstance(req, IntermediateRequest) and req.return_probs)
for req in requests
)

token_probs = None
if needs_probs:
# Extract probability values for sampled tokens
try:
# Get last position logits for each request
batch_probs = []
for i, req in enumerate(requests):
if lengths[i] > 0:
# Get logit at last position
last_idx = int(lengths[i]) - 1
last_logits = hidden_states[i, last_idx, :] # [vocab_size]
# Scale by this request's temperature and softmax to get probabilities
probs = mx.softmax(last_logits / sampling_info.temperatures[i], axis=-1)
# Extract probability for the sampled token
token_id = int(token_ids[i])
batch_probs.append(float(probs[token_id]))

token_probs = batch_probs if batch_probs else None
except Exception as e:
logger.debug(f"Failed to extract token probs: {e}")
token_probs = None

# Return dict with token_ids and optional probs
return {"hidden_states": token_ids, "probs": token_probs}

# Intermediate peer: return hidden states without probs
return {"hidden_states": hidden_states, "probs": None}

def _release_request(self, rid: str):
"""Release per-request resources in MLX."""
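The probability the MLX loop extracts is simply the temperature-scaled softmax of the last-position logits, evaluated at the sampled token id. A NumPy sketch of that computation (NumPy standing in for `mx`; the helper name is illustrative):

```python
import numpy as np


def sampled_token_prob(last_logits: np.ndarray, token_id: int, temperature: float) -> float:
    """Probability of token_id under a temperature-scaled softmax over one row of logits."""
    scaled = last_logits / max(temperature, 1e-6)  # guard against temperature == 0 (greedy)
    scaled = scaled - scaled.max()                 # shift for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return float(probs[token_id])


logits = np.array([2.0, 0.5, -1.0])
assert abs(sampled_token_prob(logits, token_id=0, temperature=1.0) - 0.786) < 0.01
```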
29 changes: 27 additions & 2 deletions src/parallax/server/executor/sglang_executor.py
@@ -307,6 +307,11 @@ def handle_input_requests(self, requests: List[Request]):
req_dict["eos"] = True
if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
req_dict["length"] = True

# Add prob value for the sampled token (if requested and available)
if original_req.return_probs and req.token_prob is not None:
req_dict["probs"] = req.token_prob

if self.enable_weight_refit:
req_dict["weight_version"] = self.weight_version
if hasattr(self, "send_to_ipc_socket"):
@@ -336,6 +341,7 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:

forward_batch = prepared_inputs["forward_batch"]
pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"]
requests = prepared_inputs.get("requests", [])

# Execute model with SGLang
logits_output, _ = self.model_runner.forward(
@@ -359,14 +365,33 @@
if return_decoded_tokens:
# Last peer: sample and return token IDs
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
return next_token_ids

# Only compute probs if any request in the batch needs it
# (either an InitialRequest or an IntermediateRequest with return_probs=True)
needs_probs = any(
(isinstance(req, InitialRequest) and req.return_probs)
or (isinstance(req, IntermediateRequest) and req.return_probs)
for req in requests
)

token_probs = None
# Extract probs for the sampled tokens only if needed
if needs_probs and hasattr(logits_output, "next_token_logits"):
# Get probs for sampled tokens (next_token_logits contains probabilities)
real_probs = logits_output.next_token_logits[
torch.arange(len(next_token_ids)), next_token_ids
]
token_probs = real_probs.cpu().float().tolist()

# Return dict with token_ids and optional probs
return {"hidden_states": next_token_ids, "probs": token_probs}
else:
# Intermediate peer: return hidden states for next peer
# Note: SGLang stores hidden_states + residual separately
final_hidden_states = (
logits_output.tensors["hidden_states"] + logits_output.tensors["residual"]
)
return final_hidden_states
return {"hidden_states": final_hidden_states, "probs": None}

def _release_request(self, rid: str):
"""Release per-request resources in SGLang."""
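The SGLang branch reads each sampled token's probability out of `next_token_logits` with plain advanced indexing: row `i` of the `[batch, vocab]` tensor at column `next_token_ids[i]`. A standalone PyTorch sketch of that indexing pattern (tensor values are made up):

```python
import torch

# probability of every vocab entry, one row per request in the batch
next_token_probs_matrix = torch.tensor([[0.1, 0.7, 0.2],
                                        [0.5, 0.3, 0.2]])
next_token_ids = torch.tensor([1, 0])  # token sampled for each request

# row i at column next_token_ids[i] -- the same indexing used in the diff above
real_probs = next_token_probs_matrix[torch.arange(len(next_token_ids)), next_token_ids]
assert torch.allclose(real_probs, torch.tensor([0.7, 0.5]))

token_probs = real_probs.cpu().float().tolist()  # e.g. [0.699..., 0.5] after float32 rounding
```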
8 changes: 5 additions & 3 deletions src/parallax/server/executor/vllm_executor.py
@@ -237,13 +237,15 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
for seq in sampled_token_ids:
padded_seq = seq + [-1] * (max_len - len(seq)) # Pad with -1
padded_tokens.append(padded_seq)
return torch.tensor(padded_tokens, dtype=torch.int64)
token_ids = torch.tensor(padded_tokens, dtype=torch.int64)
else:
return torch.tensor(sampled_token_ids, dtype=torch.int64)
token_ids = torch.tensor(sampled_token_ids, dtype=torch.int64)
# vLLM doesn't support probs yet
return {"hidden_states": token_ids, "probs": None}
else:
# Intermediate peer: return hidden states for next peer
final_hidden_states = output.tensors["hidden_states"] + output.tensors["residual"]
return final_hidden_states
return {"hidden_states": final_hidden_states, "probs": None}

def _release_request(self, rid: str):
"""Release per-request resources in vLLM."""
Expand Down
Loading