From 7e57d4842fe3c7390625d81587618895495bc3bb Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 6 Dec 2025 18:02:56 +0800 Subject: [PATCH 01/11] logits on mac is ok --- src/parallax/p2p/message_util.py | 10 +++++ src/parallax/p2p/proto/forward.proto | 1 + src/parallax/p2p/proto/forward_pb2.py | 12 +++--- src/parallax/server/executor/base_executor.py | 19 ++++++++- src/parallax/server/executor/mlx_executor.py | 42 ++++++++++++++++++- .../server/executor/sglang_executor.py | 26 ++++++++++++ src/parallax/server/http_server.py | 19 ++++++++- src/parallax/server/request.py | 8 ++++ 8 files changed, 127 insertions(+), 10 deletions(-) diff --git a/src/parallax/p2p/message_util.py b/src/parallax/p2p/message_util.py index 1181c988..725e0d9d 100644 --- a/src/parallax/p2p/message_util.py +++ b/src/parallax/p2p/message_util.py @@ -50,6 +50,10 @@ def request_to_proto( if request.next_token_id is not None: proto_req.next_token_id = request.next_token_id + # Add token_logit if available + if hasattr(request, "token_logit") and request.token_logit is not None: + proto_req.token_logit = request.token_logit + forward_request.reqs.append(proto_req) return forward_request @@ -86,6 +90,11 @@ def proto_to_request( sampling_params = proto_to_sampling_params(proto_req.sampling_params) + # Extract token_logit if present + token_logit = None + if proto_req.HasField("token_logit"): + token_logit = proto_req.token_logit + request = IntermediateRequest( request_id=proto_req.rid, current_position=current_position, @@ -96,6 +105,7 @@ def proto_to_request( next_token_id=next_token_id, sampling_params=sampling_params, lora_path=proto_req.lora_path if proto_req.lora_path != "" else None, + token_logit=token_logit, ) requests.append(request) diff --git a/src/parallax/p2p/proto/forward.proto b/src/parallax/p2p/proto/forward.proto index 4854ab57..bac1ac37 100644 --- a/src/parallax/p2p/proto/forward.proto +++ b/src/parallax/p2p/proto/forward.proto @@ -36,6 +36,7 @@ message Req { int32 next_token_id = 6; bytes hidden_states = 7; string lora_path = 8; + optional float token_logit = 9; // Logit value for the sampled token } message SamplingParams { diff --git a/src/parallax/p2p/proto/forward_pb2.py b/src/parallax/p2p/proto/forward_pb2.py index 9e4e3f69..c60b0994 100644 --- a/src/parallax/p2p/proto/forward_pb2.py +++ b/src/parallax/p2p/proto/forward_pb2.py @@ -24,15 +24,15 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xc7\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t 
\x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xf1\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x18\n\x0btoken_logit\x18\t \x01(\x02H\x00\x88\x01\x01\x42\x0e\n\x0c_token_logit\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.parallax.p2p.proto.forward_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_FORWARDMODE']._serialized_start=723 - _globals['_FORWARDMODE']._serialized_end=771 + _globals['_FORWARDMODE']._serialized_start=765 + _globals['_FORWARDMODE']._serialized_end=813 _globals['_FORWARDREQUEST']._serialized_start=50 _globals['_FORWARDREQUEST']._serialized_end=140 _globals['_FORWARDRESPONSE']._serialized_start=142 @@ -42,7 +42,7 @@ _globals['_ABORTRESPONSE']._serialized_start=206 _globals['_ABORTRESPONSE']._serialized_end=221 _globals['_REQ']._serialized_start=224 - _globals['_REQ']._serialized_end=423 - _globals['_SAMPLINGPARAMS']._serialized_start=426 - _globals['_SAMPLINGPARAMS']._serialized_end=721 + _globals['_REQ']._serialized_end=465 + _globals['_SAMPLINGPARAMS']._serialized_start=468 + _globals['_SAMPLINGPARAMS']._serialized_end=763 # @@protoc_insertion_point(module_scope) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index a092ac1e..81d3f7b5 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -369,7 +369,15 @@ def prepare_next_batch_requests( hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :] pre_length += 1 - next_req = self._prepare_next_single_request(src_request, hidden_state_for_req) + # Get logit for this request if available + token_logit = None + if self.is_last_peer and hasattr(self, 
"_latest_token_logits"): + if self._latest_token_logits is not None and len(self._latest_token_logits) > i: + token_logit = self._latest_token_logits[i] + + next_req = self._prepare_next_single_request( + src_request, hidden_state_for_req, token_logit + ) batched_requests.append(next_req) else: batched_requests = None @@ -576,6 +584,7 @@ def _handle_raw_request(self, raw_request: Dict): max_total_length = len(prompt) + max_new_tokens lora_path = raw_request.get("lora_path") + return_logits = raw_request.get("return_logits", False) # Get return_logits parameter raw_sampling_params = raw_request.get("sampling_params") if raw_sampling_params is None: @@ -600,6 +609,7 @@ def _handle_raw_request(self, raw_request: Dict): max_new_tokens=max_new_tokens, max_total_length=max_total_length, lora_path=lora_path, + return_logits=return_logits, ) if "routing_table" in raw_request: req.routing_table = raw_request["routing_table"] @@ -633,7 +643,9 @@ def _notify_http_request_error(self, raw_request: Optional[Dict], error: Excepti except Exception: # pragma: no cover - best effort notification logger.debug("Failed to send error notification to HTTP handler", exc_info=True) - def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request: + def _prepare_next_single_request( + self, request: Request, hidden_states: Any, token_logit: Optional[float] = None + ) -> Request: """Handle request state changes both inter and intra peers. This function prepares the request object to be sent to the *next* peer in the @@ -642,6 +654,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Args: request: The request that was just processed by this peer. hidden_states: The output hidden_states/output_ids from the model for this request. + token_logit: The logit value for the sampled token (optional). Returns: A new Request object ready to be sent to the next destination. @@ -662,6 +675,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> next_token_id=next_token_id, routing_table=request.routing_table, lora_path=request.lora_path, + token_logit=token_logit, ) if self.is_last_peer: # Last peer decodes a token and sends it back to the first peer. @@ -680,6 +694,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> next_token_id=next_token_id, routing_table=request.routing_table, lora_path=request.lora_path, + token_logit=token_logit, ) # This peer is the first or an intermediate peer. 
if self.is_first_peer: diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 67c14ca7..3561398f 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -198,6 +198,9 @@ def __init__( f"KVCacheManager ready; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}" ) + # Store latest sampled token logit values (not full distribution) + self._latest_token_logits = None + def handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" if not requests: @@ -253,6 +256,12 @@ def handle_input_requests(self, requests: List[Request]): req_dict["eos"] = True if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: req_dict["length"] = True + + # Add logit value for the sampled token (if requested and available) + if hasattr(original_req, "return_logits") and original_req.return_logits: + if hasattr(req, "token_logit") and req.token_logit is not None: + req_dict["logits"] = req.token_logit + if hasattr(self, "send_to_ipc_socket"): self.send_to_ipc_socket.send_pyobj(req_dict) else: @@ -330,10 +339,41 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Process last peer: need additional sampling + detokenization if return_decoded_tokens: sampling_info = SamplingBatchInfo.from_reqs(requests) - return mx.array( + + # For MLX, hidden_states at last shard is already logits (after lm_head) + # hidden_states shape: [batch_size, seq_len, vocab_size] + token_ids = mx.array( self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info) ) + # Extract logit values for sampled tokens only if requested + need_logits = any( + hasattr(req, "return_logits") and req.return_logits for req in requests + ) + + if need_logits: + try: + # Get last position logits for each request + batch_logits = [] + for i, req in enumerate(requests): + if lengths[i] > 0: + # Get logit at last position + last_idx = int(lengths[i]) - 1 + last_logits = hidden_states[i, last_idx, :] # [vocab_size] + # Extract logit for the sampled token + token_id = int(token_ids[i]) + logit_value = float(last_logits[token_id]) + batch_logits.append(logit_value) + + self._latest_token_logits = batch_logits if batch_logits else None + except Exception as e: + logger.debug(f"Failed to extract token logits: {e}") + self._latest_token_logits = None + else: + self._latest_token_logits = None + + return token_ids + return hidden_states def _release_request(self, rid: str): diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 7637765c..46198d9b 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -159,6 +159,9 @@ def __init__( self.tp_group = self.model_runner.tp_group self.tp_cpu_group = self.tp_group.cpu_group + # Store latest sampled token logits (not full distribution) + self._latest_token_logits = None + def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" @@ -299,6 +302,12 @@ def handle_input_requests(self, requests: List[Request]): req_dict["eos"] = True if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: req_dict["length"] = True + + # Add logit value for the sampled token (if requested and available) + if hasattr(original_req, "return_logits") and original_req.return_logits: + if hasattr(req, "token_logit") and 
req.token_logit is not None: + req_dict["logits"] = req.token_logit + if hasattr(self, "send_to_ipc_socket"): self.send_to_ipc_socket.send_pyobj(req_dict) else: @@ -349,6 +358,23 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: if return_decoded_tokens: # Last peer: sample and return token IDs next_token_ids = self.model_runner.sample(logits_output, forward_batch) + + # Extract logits for the sampled tokens only if requested + # Check if any request in the batch needs logits + need_logits = any( + hasattr(req, "return_logits") and req.return_logits + for req in prepared_inputs["requests"] + ) + + if need_logits and hasattr(logits_output, "next_token_logits"): + # Get logits for sampled tokens + real_logits = logits_output.next_token_logits[ + torch.arange(len(next_token_ids)), next_token_ids + ] + self._latest_token_logits = real_logits.cpu().float().tolist() + else: + self._latest_token_logits = None + return next_token_ids else: # Intermediate peer: return hidden states for next peer diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index de1e9573..d8583cfd 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -22,7 +22,7 @@ import uuid from dataclasses import dataclass, field from http import HTTPStatus -from typing import Dict, Optional +from typing import Dict, List, Optional import fastapi import uvicorn @@ -87,6 +87,9 @@ class HTTPRequestInfo: error_message: Optional[str] = None error_type: Optional[str] = None error_status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR + # logits support + return_logits: bool = False # Whether to return logits + logits_list: List = field(default_factory=list) # Store logits for each token class HTTPHandler: @@ -128,6 +131,7 @@ def create_request(self, request: Dict): rid = request["rid"] stream = request.get("stream", False) model = request.get("model", "default") + return_logits = request.get("return_logits", False) # Check if logits requested chat_object = "chat.completion.chunk" if stream else "chat.completion" detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap) create_time = time.time() @@ -140,6 +144,7 @@ def create_request(self, request: Dict): create_time=create_time, update_time=update_time, detokenizer=detokenizer, + return_logits=return_logits, ) if stream: request_info.token_queue = asyncio.Queue() @@ -151,6 +156,11 @@ def release_request(self, rid: str): def send_request(self, request: Dict): """Sends the request to model executor using IPC.""" + # Ensure return_logits is included in the request sent to executor + rid = request.get("rid") + if rid and rid in self.processing_requests: + request_info = self.processing_requests[rid] + request["return_logits"] = request_info.return_logits self.send_to_executor.send_pyobj(request) def abort_request(self, request_id: str): @@ -280,6 +290,9 @@ def generate_non_stream_response(self, rid): "reasoning_content": None, "tool_calls": None, } + # Add logits if requested + if request_info.return_logits: + choice["logits"] = request_info.logits_list return response async def _handle_executor_error(self, rid: str, recv_dict: Dict): @@ -331,6 +344,10 @@ async def _handle_loop(self): request_info.detokenizer.add_token(next_token_id) output = request_info.detokenizer.last_segment + # Store logits if requested + if request_info.return_logits and "logits" in recv_dict: + request_info.logits_list.append(recv_dict["logits"]) + is_finished = recv_dict.get("eos", False) or 
recv_dict.get("length", False) # Only process and send non-EOS tokens diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index f76accdc..841f1b37 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -158,6 +158,7 @@ def __init__( max_total_length: int = 1024, status: RequestStatus = RequestStatus.PREFILLING, lora_path: Optional[str] = None, + return_logits: bool = False, ): if not prompt and not input_ids: raise ValueError("prompt or input_ids cannot be empty.") @@ -170,6 +171,7 @@ def __init__( lora_path=lora_path, ) self.prompt = prompt + self.return_logits = return_logits if max_new_tokens < 1: raise ValueError("max_new_tokens must be at least 1.") @@ -262,6 +264,7 @@ def __init__( routing_table: Optional[List[str]] = [], sampling_params: Optional[SamplingParams] = None, lora_path: Optional[str] = None, + token_logit: Optional[float] = None, ): super().__init__( request_id=request_id, @@ -283,6 +286,7 @@ def __init__( self.current_position = current_position self.hidden_states = hidden_states self.next_token_id = next_token_id + self.token_logit = token_logit @property def input_length(self) -> int: @@ -301,6 +305,7 @@ def from_initial_request( initial_request: InitialRequest, hidden_states: Optional[Any] = None, lora_path: Optional[str] = None, + token_logit: Optional[float] = None, ) -> "IntermediateRequest": """Convert an InitialRequest to an IntermediateRequest. @@ -333,6 +338,7 @@ def from_initial_request( sampling_params=initial_request.sampling_params, routing_table=initial_request.routing_table, lora_path=lora_path, + token_logit=token_logit, ) @classmethod @@ -341,6 +347,7 @@ def from_intermediate_request( old_request: "IntermediateRequest", new_hidden_states: Any, lora_path: Optional[str] = None, + token_logit: Optional[float] = None, ) -> "IntermediateRequest": """ Creates a new IntermediateRequest from an old one, with updated hidden states. 
@@ -356,6 +363,7 @@ def from_intermediate_request( routing_table=old_request.routing_table, sampling_params=old_request.sampling_params, lora_path=lora_path, + token_logit=token_logit, ) def __repr__(self): From e246fbcc680f4b6fe37d70d8a08db7f9fbeb1278 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 6 Dec 2025 18:55:03 +0800 Subject: [PATCH 02/11] not check return_logits in executor --- src/parallax/server/executor/mlx_executor.py | 41 ++++++++----------- .../server/executor/sglang_executor.py | 10 +---- 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 3561398f..2caefead 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -346,30 +346,23 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info) ) - # Extract logit values for sampled tokens only if requested - need_logits = any( - hasattr(req, "return_logits") and req.return_logits for req in requests - ) - - if need_logits: - try: - # Get last position logits for each request - batch_logits = [] - for i, req in enumerate(requests): - if lengths[i] > 0: - # Get logit at last position - last_idx = int(lengths[i]) - 1 - last_logits = hidden_states[i, last_idx, :] # [vocab_size] - # Extract logit for the sampled token - token_id = int(token_ids[i]) - logit_value = float(last_logits[token_id]) - batch_logits.append(logit_value) - - self._latest_token_logits = batch_logits if batch_logits else None - except Exception as e: - logger.debug(f"Failed to extract token logits: {e}") - self._latest_token_logits = None - else: + # Extract logit values for sampled tokens + try: + # Get last position logits for each request + batch_logits = [] + for i, req in enumerate(requests): + if lengths[i] > 0: + # Get logit at last position + last_idx = int(lengths[i]) - 1 + last_logits = hidden_states[i, last_idx, :] # [vocab_size] + # Extract logit for the sampled token + token_id = int(token_ids[i]) + logit_value = float(last_logits[token_id]) + batch_logits.append(logit_value) + + self._latest_token_logits = batch_logits if batch_logits else None + except Exception as e: + logger.debug(f"Failed to extract token logits: {e}") self._latest_token_logits = None return token_ids diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 46198d9b..48f5bb95 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -359,14 +359,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Last peer: sample and return token IDs next_token_ids = self.model_runner.sample(logits_output, forward_batch) - # Extract logits for the sampled tokens only if requested - # Check if any request in the batch needs logits - need_logits = any( - hasattr(req, "return_logits") and req.return_logits - for req in prepared_inputs["requests"] - ) - - if need_logits and hasattr(logits_output, "next_token_logits"): + # Extract logits for the sampled tokens + if hasattr(logits_output, "next_token_logits"): # Get logits for sampled tokens real_logits = logits_output.next_token_logits[ torch.arange(len(next_token_ids)), next_token_ids From c59399248cb58466e74f39f07ba53546341234e4 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 13 Dec 2025 17:38:49 +0800 Subject: [PATCH 03/11] mac and gpu 
return probs --- src/parallax/server/executor/mlx_executor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 2caefead..193b45b8 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -355,10 +355,13 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Get logit at last position last_idx = int(lengths[i]) - 1 last_logits = hidden_states[i, last_idx, :] # [vocab_size] + logprobe = last_logits / sampling_info.temperatures.reshape(-1, 1) + logprobe[:] = mx.softmax(logprobe, axis=-1) # Extract logit for the sampled token token_id = int(token_ids[i]) - logit_value = float(last_logits[token_id]) - batch_logits.append(logit_value) + # logit_value = float(last_logits[token_id]) + # batch_logits.append(logit_value) + batch_logits.append(float(logprobe[i, token_id])) self._latest_token_logits = batch_logits if batch_logits else None except Exception as e: From 1b3e6795ac88684120be7ffb4417408e3cf0bc5d Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 13 Dec 2025 18:24:35 +0800 Subject: [PATCH 04/11] return_logits change to return_probs --- src/parallax/p2p/message_util.py | 16 ++++++------- src/parallax/p2p/proto/forward.proto | 2 +- src/parallax/p2p/proto/forward_pb2.py | 12 +++++----- src/parallax/server/executor/base_executor.py | 24 +++++++++---------- src/parallax/server/executor/mlx_executor.py | 16 ++++++------- .../server/executor/sglang_executor.py | 20 ++++++++-------- src/parallax/server/http_server.py | 24 +++++++++---------- src/parallax/server/request.py | 16 ++++++------- 8 files changed, 65 insertions(+), 65 deletions(-) diff --git a/src/parallax/p2p/message_util.py b/src/parallax/p2p/message_util.py index 725e0d9d..0bb97205 100644 --- a/src/parallax/p2p/message_util.py +++ b/src/parallax/p2p/message_util.py @@ -50,9 +50,9 @@ def request_to_proto( if request.next_token_id is not None: proto_req.next_token_id = request.next_token_id - # Add token_logit if available - if hasattr(request, "token_logit") and request.token_logit is not None: - proto_req.token_logit = request.token_logit + # Add token_prob if available + if hasattr(request, "token_prob") and request.token_prob is not None: + proto_req.token_prob = request.token_prob forward_request.reqs.append(proto_req) @@ -90,10 +90,10 @@ def proto_to_request( sampling_params = proto_to_sampling_params(proto_req.sampling_params) - # Extract token_logit if present - token_logit = None - if proto_req.HasField("token_logit"): - token_logit = proto_req.token_logit + # Extract token_prob if present + token_prob = None + if proto_req.HasField("token_prob"): + token_prob = proto_req.token_prob request = IntermediateRequest( request_id=proto_req.rid, @@ -105,7 +105,7 @@ def proto_to_request( next_token_id=next_token_id, sampling_params=sampling_params, lora_path=proto_req.lora_path if proto_req.lora_path != "" else None, - token_logit=token_logit, + token_prob=token_prob, ) requests.append(request) diff --git a/src/parallax/p2p/proto/forward.proto b/src/parallax/p2p/proto/forward.proto index bac1ac37..11078074 100644 --- a/src/parallax/p2p/proto/forward.proto +++ b/src/parallax/p2p/proto/forward.proto @@ -36,7 +36,7 @@ message Req { int32 next_token_id = 6; bytes hidden_states = 7; string lora_path = 8; - optional float token_logit = 9; // Logit value for the sampled token + optional float token_prob = 9; // Probability value for the 
sampled token } message SamplingParams { diff --git a/src/parallax/p2p/proto/forward_pb2.py b/src/parallax/p2p/proto/forward_pb2.py index c60b0994..6235b299 100644 --- a/src/parallax/p2p/proto/forward_pb2.py +++ b/src/parallax/p2p/proto/forward_pb2.py @@ -24,15 +24,15 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xf1\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x18\n\x0btoken_logit\x18\t \x01(\x02H\x00\x88\x01\x01\x42\x0e\n\x0c_token_logit\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xef\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x17\n\ntoken_prob\x18\t \x01(\x02H\x00\x88\x01\x01\x42\r\n\x0b_token_prob\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 
'src.parallax.p2p.proto.forward_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_FORWARDMODE']._serialized_start=765 - _globals['_FORWARDMODE']._serialized_end=813 + _globals['_FORWARDMODE']._serialized_start=763 + _globals['_FORWARDMODE']._serialized_end=811 _globals['_FORWARDREQUEST']._serialized_start=50 _globals['_FORWARDREQUEST']._serialized_end=140 _globals['_FORWARDRESPONSE']._serialized_start=142 @@ -42,7 +42,7 @@ _globals['_ABORTRESPONSE']._serialized_start=206 _globals['_ABORTRESPONSE']._serialized_end=221 _globals['_REQ']._serialized_start=224 - _globals['_REQ']._serialized_end=465 - _globals['_SAMPLINGPARAMS']._serialized_start=468 - _globals['_SAMPLINGPARAMS']._serialized_end=763 + _globals['_REQ']._serialized_end=463 + _globals['_SAMPLINGPARAMS']._serialized_start=466 + _globals['_SAMPLINGPARAMS']._serialized_end=761 # @@protoc_insertion_point(module_scope) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 81d3f7b5..6c1768c6 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -369,14 +369,14 @@ def prepare_next_batch_requests( hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :] pre_length += 1 - # Get logit for this request if available - token_logit = None - if self.is_last_peer and hasattr(self, "_latest_token_logits"): - if self._latest_token_logits is not None and len(self._latest_token_logits) > i: - token_logit = self._latest_token_logits[i] + # Get prob for this request if available + token_prob = None + if self.is_last_peer and hasattr(self, "_latest_token_probs"): + if self._latest_token_probs is not None and len(self._latest_token_probs) > i: + token_prob = self._latest_token_probs[i] next_req = self._prepare_next_single_request( - src_request, hidden_state_for_req, token_logit + src_request, hidden_state_for_req, token_prob ) batched_requests.append(next_req) else: @@ -584,7 +584,7 @@ def _handle_raw_request(self, raw_request: Dict): max_total_length = len(prompt) + max_new_tokens lora_path = raw_request.get("lora_path") - return_logits = raw_request.get("return_logits", False) # Get return_logits parameter + return_probs = raw_request.get("return_probs", False) # Get return_probs parameter raw_sampling_params = raw_request.get("sampling_params") if raw_sampling_params is None: @@ -609,7 +609,7 @@ def _handle_raw_request(self, raw_request: Dict): max_new_tokens=max_new_tokens, max_total_length=max_total_length, lora_path=lora_path, - return_logits=return_logits, + return_probs=return_probs, ) if "routing_table" in raw_request: req.routing_table = raw_request["routing_table"] @@ -644,7 +644,7 @@ def _notify_http_request_error(self, raw_request: Optional[Dict], error: Excepti logger.debug("Failed to send error notification to HTTP handler", exc_info=True) def _prepare_next_single_request( - self, request: Request, hidden_states: Any, token_logit: Optional[float] = None + self, request: Request, hidden_states: Any, token_prob: Optional[float] = None ) -> Request: """Handle request state changes both inter and intra peers. @@ -654,7 +654,7 @@ def _prepare_next_single_request( Args: request: The request that was just processed by this peer. hidden_states: The output hidden_states/output_ids from the model for this request. - token_logit: The logit value for the sampled token (optional). + token_prob: The probability value for the sampled token (optional). 
Returns: A new Request object ready to be sent to the next destination. @@ -675,7 +675,7 @@ def _prepare_next_single_request( next_token_id=next_token_id, routing_table=request.routing_table, lora_path=request.lora_path, - token_logit=token_logit, + token_prob=token_prob, ) if self.is_last_peer: # Last peer decodes a token and sends it back to the first peer. @@ -694,7 +694,7 @@ def _prepare_next_single_request( next_token_id=next_token_id, routing_table=request.routing_table, lora_path=request.lora_path, - token_logit=token_logit, + token_prob=token_prob, ) # This peer is the first or an intermediate peer. if self.is_first_peer: diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 193b45b8..0490bb1c 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -199,7 +199,7 @@ def __init__( ) # Store latest sampled token logit values (not full distribution) - self._latest_token_logits = None + self._latest_token_probs = None def handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" @@ -257,10 +257,10 @@ def handle_input_requests(self, requests: List[Request]): if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: req_dict["length"] = True - # Add logit value for the sampled token (if requested and available) - if hasattr(original_req, "return_logits") and original_req.return_logits: - if hasattr(req, "token_logit") and req.token_logit is not None: - req_dict["logits"] = req.token_logit + # Add prob value for the sampled token (if requested and available) + if hasattr(original_req, "return_probs") and original_req.return_probs: + if hasattr(req, "token_prob") and req.token_prob is not None: + req_dict["probs"] = req.token_prob if hasattr(self, "send_to_ipc_socket"): self.send_to_ipc_socket.send_pyobj(req_dict) @@ -363,10 +363,10 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # batch_logits.append(logit_value) batch_logits.append(float(logprobe[i, token_id])) - self._latest_token_logits = batch_logits if batch_logits else None + self._latest_token_probs = batch_logits if batch_logits else None except Exception as e: - logger.debug(f"Failed to extract token logits: {e}") - self._latest_token_logits = None + logger.debug(f"Failed to extract token probs: {e}") + self._latest_token_probs = None return token_ids diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 48f5bb95..ad3d9dc4 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -160,7 +160,7 @@ def __init__( self.tp_cpu_group = self.tp_group.cpu_group # Store latest sampled token logits (not full distribution) - self._latest_token_logits = None + self._latest_token_probs = None def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" @@ -303,10 +303,10 @@ def handle_input_requests(self, requests: List[Request]): if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: req_dict["length"] = True - # Add logit value for the sampled token (if requested and available) - if hasattr(original_req, "return_logits") and original_req.return_logits: - if hasattr(req, "token_logit") and req.token_logit is not None: - req_dict["logits"] = req.token_logit + # Add prob value for the sampled token (if requested and available) + if hasattr(original_req, 
"return_probs") and original_req.return_probs: + if hasattr(req, "token_prob") and req.token_prob is not None: + req_dict["probs"] = req.token_prob if hasattr(self, "send_to_ipc_socket"): self.send_to_ipc_socket.send_pyobj(req_dict) @@ -359,15 +359,15 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Last peer: sample and return token IDs next_token_ids = self.model_runner.sample(logits_output, forward_batch) - # Extract logits for the sampled tokens + # Extract probs for the sampled tokens if hasattr(logits_output, "next_token_logits"): - # Get logits for sampled tokens - real_logits = logits_output.next_token_logits[ + # Get probs for sampled tokens + real_probs = logits_output.next_token_logits[ torch.arange(len(next_token_ids)), next_token_ids ] - self._latest_token_logits = real_logits.cpu().float().tolist() + self._latest_token_probs = real_probs.cpu().float().tolist() else: - self._latest_token_logits = None + self._latest_token_probs = None return next_token_ids else: diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index d8583cfd..9cedb44f 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -88,8 +88,8 @@ class HTTPRequestInfo: error_type: Optional[str] = None error_status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR # logits support - return_logits: bool = False # Whether to return logits - logits_list: List = field(default_factory=list) # Store logits for each token + return_probs: bool = False # Whether to return probabilities + probs_list: List = field(default_factory=list) # Store probs for each token class HTTPHandler: @@ -131,7 +131,7 @@ def create_request(self, request: Dict): rid = request["rid"] stream = request.get("stream", False) model = request.get("model", "default") - return_logits = request.get("return_logits", False) # Check if logits requested + return_probs = request.get("return_probs", False) # Check if probs requested chat_object = "chat.completion.chunk" if stream else "chat.completion" detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap) create_time = time.time() @@ -144,7 +144,7 @@ def create_request(self, request: Dict): create_time=create_time, update_time=update_time, detokenizer=detokenizer, - return_logits=return_logits, + return_probs=return_probs, ) if stream: request_info.token_queue = asyncio.Queue() @@ -156,11 +156,11 @@ def release_request(self, rid: str): def send_request(self, request: Dict): """Sends the request to model executor using IPC.""" - # Ensure return_logits is included in the request sent to executor + # Ensure return_probs is included in the request sent to executor rid = request.get("rid") if rid and rid in self.processing_requests: request_info = self.processing_requests[rid] - request["return_logits"] = request_info.return_logits + request["return_probs"] = request_info.return_probs self.send_to_executor.send_pyobj(request) def abort_request(self, request_id: str): @@ -290,9 +290,9 @@ def generate_non_stream_response(self, rid): "reasoning_content": None, "tool_calls": None, } - # Add logits if requested - if request_info.return_logits: - choice["logits"] = request_info.logits_list + # Add probs if requested + if request_info.return_probs: + choice["probs"] = request_info.probs_list return response async def _handle_executor_error(self, rid: str, recv_dict: Dict): @@ -344,9 +344,9 @@ async def _handle_loop(self): request_info.detokenizer.add_token(next_token_id) output = 
request_info.detokenizer.last_segment - # Store logits if requested - if request_info.return_logits and "logits" in recv_dict: - request_info.logits_list.append(recv_dict["logits"]) + # Store probs if requested + if request_info.return_probs and "probs" in recv_dict: + request_info.probs_list.append(recv_dict["probs"]) is_finished = recv_dict.get("eos", False) or recv_dict.get("length", False) diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 841f1b37..e902a423 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -158,7 +158,7 @@ def __init__( max_total_length: int = 1024, status: RequestStatus = RequestStatus.PREFILLING, lora_path: Optional[str] = None, - return_logits: bool = False, + return_probs: bool = False, ): if not prompt and not input_ids: raise ValueError("prompt or input_ids cannot be empty.") @@ -171,7 +171,7 @@ def __init__( lora_path=lora_path, ) self.prompt = prompt - self.return_logits = return_logits + self.return_probs = return_probs if max_new_tokens < 1: raise ValueError("max_new_tokens must be at least 1.") @@ -264,7 +264,7 @@ def __init__( routing_table: Optional[List[str]] = [], sampling_params: Optional[SamplingParams] = None, lora_path: Optional[str] = None, - token_logit: Optional[float] = None, + token_prob: Optional[float] = None, ): super().__init__( request_id=request_id, @@ -286,7 +286,7 @@ def __init__( self.current_position = current_position self.hidden_states = hidden_states self.next_token_id = next_token_id - self.token_logit = token_logit + self.token_prob = token_prob @property def input_length(self) -> int: @@ -305,7 +305,7 @@ def from_initial_request( initial_request: InitialRequest, hidden_states: Optional[Any] = None, lora_path: Optional[str] = None, - token_logit: Optional[float] = None, + token_prob: Optional[float] = None, ) -> "IntermediateRequest": """Convert an InitialRequest to an IntermediateRequest. @@ -338,7 +338,7 @@ def from_initial_request( sampling_params=initial_request.sampling_params, routing_table=initial_request.routing_table, lora_path=lora_path, - token_logit=token_logit, + token_prob=token_prob, ) @classmethod @@ -347,7 +347,7 @@ def from_intermediate_request( old_request: "IntermediateRequest", new_hidden_states: Any, lora_path: Optional[str] = None, - token_logit: Optional[float] = None, + token_prob: Optional[float] = None, ) -> "IntermediateRequest": """ Creates a new IntermediateRequest from an old one, with updated hidden states. 
@@ -363,7 +363,7 @@ def from_intermediate_request( routing_table=old_request.routing_table, sampling_params=old_request.sampling_params, lora_path=lora_path, - token_logit=token_logit, + token_prob=token_prob, ) def __repr__(self): From f77b9440b833b32c0c54a79b264d8a44722ac506 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 13 Dec 2025 18:34:10 +0800 Subject: [PATCH 05/11] logprobe to probs --- src/parallax/server/executor/mlx_executor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 0490bb1c..56a6e793 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -355,13 +355,13 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Get logit at last position last_idx = int(lengths[i]) - 1 last_logits = hidden_states[i, last_idx, :] # [vocab_size] - logprobe = last_logits / sampling_info.temperatures.reshape(-1, 1) - logprobe[:] = mx.softmax(logprobe, axis=-1) + probs = last_logits / sampling_info.temperatures.reshape(-1, 1) + probs[:] = mx.softmax(probs, axis=-1) # Extract logit for the sampled token token_id = int(token_ids[i]) # logit_value = float(last_logits[token_id]) # batch_logits.append(logit_value) - batch_logits.append(float(logprobe[i, token_id])) + batch_logits.append(float(probs[i, token_id])) self._latest_token_probs = batch_logits if batch_logits else None except Exception as e: From d91bb5338b4a2bcfd03e98f4c67c06e9f6d51832 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 13 Dec 2025 20:22:32 +0800 Subject: [PATCH 06/11] probs add token --- src/parallax/server/http_server.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 9cedb44f..b93c6db9 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -87,9 +87,10 @@ class HTTPRequestInfo: error_message: Optional[str] = None error_type: Optional[str] = None error_status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR - # logits support + # probs support return_probs: bool = False # Whether to return probabilities probs_list: List = field(default_factory=list) # Store probs for each token + token_ids_list: List = field(default_factory=list) # Store token IDs for each token class HTTPHandler: @@ -223,6 +224,12 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False): } choice = response["choices"][0] choice["delta"] = {"role": role, "content": content} + # Add probs in the last chunk if requested (convert to object array format) + if is_last and request_info.return_probs: + choice["probs"] = [ + {self.tokenizer.decode([token_id]): prob} + for token_id, prob in zip(request_info.token_ids_list, request_info.probs_list) + ] response_json = json.dumps(response, separators=(",", ":")) return f"data: {response_json}\n\n".encode() @@ -290,9 +297,12 @@ def generate_non_stream_response(self, rid): "reasoning_content": None, "tool_calls": None, } - # Add probs if requested + # Add probs if requested (convert to object array format) if request_info.return_probs: - choice["probs"] = request_info.probs_list + choice["probs"] = [ + {self.tokenizer.decode([token_id]): prob} + for token_id, prob in zip(request_info.token_ids_list, request_info.probs_list) + ] return response async def _handle_executor_error(self, rid: str, recv_dict: Dict): @@ -344,9 +354,10 @@ async def 
_handle_loop(self): request_info.detokenizer.add_token(next_token_id) output = request_info.detokenizer.last_segment - # Store probs if requested + # Store probs and token IDs if requested if request_info.return_probs and "probs" in recv_dict: request_info.probs_list.append(recv_dict["probs"]) + request_info.token_ids_list.append(next_token_id) is_finished = recv_dict.get("eos", False) or recv_dict.get("length", False) From 8d1191f3df39477ca5f6134585cea2ee1eac89c1 Mon Sep 17 00:00:00 2001 From: jason Date: Mon, 15 Dec 2025 20:46:59 +0800 Subject: [PATCH 07/11] add needs_probs check --- src/parallax/server/executor/mlx_executor.py | 47 +++++++++++-------- .../server/executor/sglang_executor.py | 13 +++-- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 56a6e793..a1703b56 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -346,26 +346,33 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info) ) - # Extract logit values for sampled tokens - try: - # Get last position logits for each request - batch_logits = [] - for i, req in enumerate(requests): - if lengths[i] > 0: - # Get logit at last position - last_idx = int(lengths[i]) - 1 - last_logits = hidden_states[i, last_idx, :] # [vocab_size] - probs = last_logits / sampling_info.temperatures.reshape(-1, 1) - probs[:] = mx.softmax(probs, axis=-1) - # Extract logit for the sampled token - token_id = int(token_ids[i]) - # logit_value = float(last_logits[token_id]) - # batch_logits.append(logit_value) - batch_logits.append(float(probs[i, token_id])) - - self._latest_token_probs = batch_logits if batch_logits else None - except Exception as e: - logger.debug(f"Failed to extract token probs: {e}") + needs_probs = any( + isinstance(req, InitialRequest) and req.return_probs for req in requests + ) + + if needs_probs: + # Extract logit values for sampled tokens + try: + # Get last position logits for each request + batch_logits = [] + for i, req in enumerate(requests): + if lengths[i] > 0: + # Get logit at last position + last_idx = int(lengths[i]) - 1 + last_logits = hidden_states[i, last_idx, :] # [vocab_size] + probs = last_logits / sampling_info.temperatures.reshape(-1, 1) + probs[:] = mx.softmax(probs, axis=-1) + # Extract logit for the sampled token + token_id = int(token_ids[i]) + # logit_value = float(last_logits[token_id]) + # batch_logits.append(logit_value) + batch_logits.append(float(probs[i, token_id])) + + self._latest_token_probs = batch_logits if batch_logits else None + except Exception as e: + logger.debug(f"Failed to extract token probs: {e}") + self._latest_token_probs = None + else: self._latest_token_probs = None return token_ids diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index ad3d9dc4..1615a2a2 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -335,6 +335,7 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: forward_batch = prepared_inputs["forward_batch"] pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] + requests = prepared_inputs.get("requests", []) # Execute model with SGLang logits_output, _ = self.model_runner.forward( @@ -359,9 +360,15 @@ def process_batch(self, prepared_inputs: 
Dict[str, Any], return_decoded_tokens: # Last peer: sample and return token IDs next_token_ids = self.model_runner.sample(logits_output, forward_batch) - # Extract probs for the sampled tokens - if hasattr(logits_output, "next_token_logits"): - # Get probs for sampled tokens + # Only compute probs if any request in the batch needs it + # Check if any InitialRequest has return_probs=True + needs_probs = any( + isinstance(req, InitialRequest) and req.return_probs for req in requests + ) + + # Extract log probs for the sampled tokens only if needed + if needs_probs and hasattr(logits_output, "next_token_logits"): + # Get probs for sampled tokens (next_token_logits contains probabilities) real_probs = logits_output.next_token_logits[ torch.arange(len(next_token_ids)), next_token_ids ] From 17f4195e807da46a6ab8f7ec0bf77ce7faa5bca0 Mon Sep 17 00:00:00 2001 From: jason Date: Mon, 15 Dec 2025 21:13:48 +0800 Subject: [PATCH 08/11] executor remove _latest_token_probs --- src/parallax/server/executor/base_executor.py | 27 ++++++++++++++----- src/parallax/server/executor/mlx_executor.py | 23 +++++++--------- .../server/executor/sglang_executor.py | 13 ++++----- 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 6c1768c6..5b11755e 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -341,10 +341,26 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict } def prepare_next_batch_requests( - self, requests: List[Request], hidden_states: Any, context_lengths: Any + self, requests: List[Request], batch_output: Any, context_lengths: Any ) -> List[Request]: - """Prepares a batch of requests for the next stage of the pipeline.""" + """Prepares a batch of requests for the next stage of the pipeline. + + Args: + requests: List of requests in the batch + batch_output: Output from process_batch. Can be: + - For intermediate peers: hidden_states tensor + - For last peer: dict with 'hidden_states' and optional 'probs' + context_lengths: Context lengths for each request + """ if self.tp_rank == 0: + # Extract hidden_states and probs from output + if isinstance(batch_output, dict): + hidden_states = batch_output["hidden_states"] + token_probs = batch_output.get("probs", None) + else: + hidden_states = batch_output + token_probs = None + batched_requests = [] pre_length = 0 for i, src_request in enumerate(requests): @@ -371,9 +387,8 @@ def prepare_next_batch_requests( # Get prob for this request if available token_prob = None - if self.is_last_peer and hasattr(self, "_latest_token_probs"): - if self._latest_token_probs is not None and len(self._latest_token_probs) > i: - token_prob = self._latest_token_probs[i] + if self.is_last_peer and token_probs is not None and len(token_probs) > i: + token_prob = token_probs[i] next_req = self._prepare_next_single_request( src_request, hidden_state_for_req, token_prob @@ -490,7 +505,7 @@ def run_loop(self): # 7. 
Prepare requests for the next stage in the pipeline next_batch = self.prepare_next_batch_requests( requests=prepared_inputs["requests"], - hidden_states=output, + batch_output=output, context_lengths=prepared_inputs.get("context_lengths"), ) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index a1703b56..a4b9e55e 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -198,9 +198,6 @@ def __init__( f"KVCacheManager ready; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}" ) - # Store latest sampled token logit values (not full distribution) - self._latest_token_probs = None - def handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" if not requests: @@ -350,11 +347,12 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: isinstance(req, InitialRequest) and req.return_probs for req in requests ) + token_probs = None if needs_probs: - # Extract logit values for sampled tokens + # Extract probability values for sampled tokens try: # Get last position logits for each request - batch_logits = [] + batch_probs = [] for i, req in enumerate(requests): if lengths[i] > 0: # Get logit at last position @@ -362,20 +360,19 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: last_logits = hidden_states[i, last_idx, :] # [vocab_size] probs = last_logits / sampling_info.temperatures.reshape(-1, 1) probs[:] = mx.softmax(probs, axis=-1) - # Extract logit for the sampled token - token_id = int(token_ids[i]) # logit_value = float(last_logits[token_id]) # batch_logits.append(logit_value) - batch_logits.append(float(probs[i, token_id])) + # Extract probability for the sampled token + token_id = int(token_ids[i]) + batch_probs.append(float(probs[i, token_id])) - self._latest_token_probs = batch_logits if batch_logits else None + token_probs = batch_probs if batch_probs else None except Exception as e: logger.debug(f"Failed to extract token probs: {e}") - self._latest_token_probs = None - else: - self._latest_token_probs = None + token_probs = None - return token_ids + # Return dict with token_ids and optional probs + return {"hidden_states": token_ids, "probs": token_probs} return hidden_states diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 1615a2a2..d020ee5e 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -159,9 +159,6 @@ def __init__( self.tp_group = self.model_runner.tp_group self.tp_cpu_group = self.tp_group.cpu_group - # Store latest sampled token logits (not full distribution) - self._latest_token_probs = None - def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" @@ -366,17 +363,17 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: isinstance(req, InitialRequest) and req.return_probs for req in requests ) - # Extract log probs for the sampled tokens only if needed + token_probs = None + # Extract probs for the sampled tokens only if needed if needs_probs and hasattr(logits_output, "next_token_logits"): # Get probs for sampled tokens (next_token_logits contains probabilities) real_probs = logits_output.next_token_logits[ torch.arange(len(next_token_ids)), next_token_ids ] - self._latest_token_probs = 
real_probs.cpu().float().tolist() - else: - self._latest_token_probs = None + token_probs = real_probs.cpu().float().tolist() - return next_token_ids + # Return dict with token_ids and optional probs + return {"hidden_states": next_token_ids, "probs": token_probs} else: # Intermediate peer: return hidden states for next peer # Note: SGLang stores hidden_states + residual separately From 2d8a80292e377092417fb2a502f0eeb36f0c935a Mon Sep 17 00:00:00 2001 From: jason Date: Tue, 16 Dec 2025 21:41:58 +0800 Subject: [PATCH 09/11] for many nodes --- src/parallax/p2p/message_util.py | 8 ++++++++ src/parallax/p2p/proto/forward.proto | 1 + src/parallax/p2p/proto/forward_pb2.py | 12 ++++++------ src/parallax/server/executor/mlx_executor.py | 4 +++- src/parallax/server/executor/sglang_executor.py | 4 +++- src/parallax/server/request.py | 4 ++++ 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/parallax/p2p/message_util.py b/src/parallax/p2p/message_util.py index 0bb97205..d6728f0d 100644 --- a/src/parallax/p2p/message_util.py +++ b/src/parallax/p2p/message_util.py @@ -54,6 +54,10 @@ def request_to_proto( if hasattr(request, "token_prob") and request.token_prob is not None: proto_req.token_prob = request.token_prob + # Add return_probs flag + if hasattr(request, "return_probs"): + proto_req.return_probs = request.return_probs + forward_request.reqs.append(proto_req) return forward_request @@ -95,6 +99,9 @@ def proto_to_request( if proto_req.HasField("token_prob"): token_prob = proto_req.token_prob + # Extract return_probs (defaults to False if not present) + return_probs = proto_req.return_probs if hasattr(proto_req, "return_probs") else False + request = IntermediateRequest( request_id=proto_req.rid, current_position=current_position, @@ -106,6 +113,7 @@ def proto_to_request( sampling_params=sampling_params, lora_path=proto_req.lora_path if proto_req.lora_path != "" else None, token_prob=token_prob, + return_probs=return_probs, ) requests.append(request) diff --git a/src/parallax/p2p/proto/forward.proto b/src/parallax/p2p/proto/forward.proto index 11078074..7a76e148 100644 --- a/src/parallax/p2p/proto/forward.proto +++ b/src/parallax/p2p/proto/forward.proto @@ -37,6 +37,7 @@ message Req { bytes hidden_states = 7; string lora_path = 8; optional float token_prob = 9; // Probability value for the sampled token + bool return_probs = 10; // Whether to return probabilities } message SamplingParams { diff --git a/src/parallax/p2p/proto/forward_pb2.py b/src/parallax/p2p/proto/forward_pb2.py index 6235b299..f25e3f3c 100644 --- a/src/parallax/p2p/proto/forward_pb2.py +++ b/src/parallax/p2p/proto/forward_pb2.py @@ -24,15 +24,15 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xef\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x17\n\ntoken_prob\x18\t 
\x01(\x02H\x00\x88\x01\x01\x42\r\n\x0b_token_prob\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\x85\x02\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x17\n\ntoken_prob\x18\t \x01(\x02H\x00\x88\x01\x01\x12\x14\n\x0creturn_probs\x18\n \x01(\x08\x42\r\n\x0b_token_prob\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.parallax.p2p.proto.forward_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_FORWARDMODE']._serialized_start=763 - _globals['_FORWARDMODE']._serialized_end=811 + _globals['_FORWARDMODE']._serialized_start=785 + _globals['_FORWARDMODE']._serialized_end=833 _globals['_FORWARDREQUEST']._serialized_start=50 _globals['_FORWARDREQUEST']._serialized_end=140 _globals['_FORWARDRESPONSE']._serialized_start=142 @@ -42,7 +42,7 @@ _globals['_ABORTRESPONSE']._serialized_start=206 _globals['_ABORTRESPONSE']._serialized_end=221 _globals['_REQ']._serialized_start=224 - _globals['_REQ']._serialized_end=463 - _globals['_SAMPLINGPARAMS']._serialized_start=466 - _globals['_SAMPLINGPARAMS']._serialized_end=761 + _globals['_REQ']._serialized_end=485 + _globals['_SAMPLINGPARAMS']._serialized_start=488 + _globals['_SAMPLINGPARAMS']._serialized_end=783 # @@protoc_insertion_point(module_scope) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 
a4b9e55e..032a8a5c 100755
--- a/src/parallax/server/executor/mlx_executor.py
+++ b/src/parallax/server/executor/mlx_executor.py
@@ -344,7 +344,9 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
             )

             needs_probs = any(
-                isinstance(req, InitialRequest) and req.return_probs for req in requests
+                (isinstance(req, InitialRequest) and req.return_probs)
+                or (isinstance(req, IntermediateRequest) and req.return_probs)
+                for req in requests
             )

             token_probs = None

diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py
index d020ee5e..d78b77d7 100755
--- a/src/parallax/server/executor/sglang_executor.py
+++ b/src/parallax/server/executor/sglang_executor.py
@@ -360,7 +360,9 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
             # Only compute probs if any request in the batch needs it
             # Check if any InitialRequest has return_probs=True
             needs_probs = any(
-                isinstance(req, InitialRequest) and req.return_probs for req in requests
+                (isinstance(req, InitialRequest) and req.return_probs)
+                or (isinstance(req, IntermediateRequest) and req.return_probs)
+                for req in requests
             )

             token_probs = None

diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py
index e902a423..d5ab84fb 100644
--- a/src/parallax/server/request.py
+++ b/src/parallax/server/request.py
@@ -265,6 +265,7 @@ def __init__(
         sampling_params: Optional[SamplingParams] = None,
         lora_path: Optional[str] = None,
         token_prob: Optional[float] = None,
+        return_probs: bool = False,
     ):
         super().__init__(
             request_id=request_id,
@@ -287,6 +288,7 @@ def __init__(
         self.hidden_states = hidden_states
         self.next_token_id = next_token_id
         self.token_prob = token_prob
+        self.return_probs = return_probs

     @property
     def input_length(self) -> int:
@@ -339,6 +341,7 @@ def from_initial_request(
             routing_table=initial_request.routing_table,
             lora_path=lora_path,
             token_prob=token_prob,
+            return_probs=initial_request.return_probs,
         )

     @classmethod
@@ -364,6 +367,7 @@ def from_intermediate_request(
             sampling_params=old_request.sampling_params,
             lora_path=lora_path,
             token_prob=token_prob,
+            return_probs=old_request.return_probs,
         )

     def __repr__(self):

From 48df15f23a1fc421b99d3b868e2d3d192df480e1 Mon Sep 17 00:00:00 2001
From: jason
Date: Wed, 17 Dec 2025 21:57:09 +0800
Subject: [PATCH 10/11] return dict

---
 src/parallax/server/executor/base_executor.py | 27 ++++++++++---------
 src/parallax/server/executor/mlx_executor.py  |  8 +++---
 .../server/executor/sglang_executor.py        |  7 +++--
 src/parallax/server/executor/vllm_executor.py |  8 +++---
 4 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py
index 5b11755e..4f947421 100755
--- a/src/parallax/server/executor/base_executor.py
+++ b/src/parallax/server/executor/base_executor.py
@@ -347,19 +347,18 @@ def prepare_next_batch_requests(
         Args:
             requests: List of requests in the batch
-            batch_output: Output from process_batch. Can be:
-                - For intermediate peers: hidden_states tensor
-                - For last peer: dict with 'hidden_states' and optional 'probs'
+            batch_output: Output from process_batch. Always a dict with:
+                - 'hidden_states': token IDs (last peer) or hidden states tensor (intermediate peer)
+                - 'probs': list of probabilities (last peer) or None (intermediate peer)
             context_lengths: Context lengths for each request
         """
         if self.tp_rank == 0:
-            # Extract hidden_states and probs from output
-            if isinstance(batch_output, dict):
-                hidden_states = batch_output["hidden_states"]
-                token_probs = batch_output.get("probs", None)
-            else:
-                hidden_states = batch_output
-                token_probs = None
+            # Extract hidden_states and probs from output (always a dict now)
+            assert isinstance(
+                batch_output, dict
+            ), f"Expected dict from process_batch, got {type(batch_output)}"
+            hidden_states = batch_output["hidden_states"]
+            token_probs = batch_output["probs"]

             batched_requests = []
             pre_length = 0
@@ -386,9 +385,11 @@ def prepare_next_batch_requests(
                 pre_length += 1

             # Get prob for this request if available
-            token_prob = None
-            if self.is_last_peer and token_probs is not None and len(token_probs) > i:
-                token_prob = token_probs[i]
+            token_prob = (
+                token_probs[i]
+                if (self.is_last_peer and token_probs and i < len(token_probs))
+                else None
+            )

             next_req = self._prepare_next_single_request(
                 src_request, hidden_state_for_req, token_prob

diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py
index 032a8a5c..0314eccf 100755
--- a/src/parallax/server/executor/mlx_executor.py
+++ b/src/parallax/server/executor/mlx_executor.py
@@ -255,9 +255,8 @@ def handle_input_requests(self, requests: List[Request]):
                 req_dict["length"] = True

                 # Add prob value for the sampled token (if requested and available)
-                if hasattr(original_req, "return_probs") and original_req.return_probs:
-                    if hasattr(req, "token_prob") and req.token_prob is not None:
-                        req_dict["probs"] = req.token_prob
+                if original_req.return_probs and req.token_prob is not None:
+                    req_dict["probs"] = req.token_prob

                 if hasattr(self, "send_to_ipc_socket"):
                     self.send_to_ipc_socket.send_pyobj(req_dict)
@@ -376,7 +375,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
             # Return dict with token_ids and optional probs
             return {"hidden_states": token_ids, "probs": token_probs}

-        return hidden_states
+        # Intermediate peer: return hidden states without probs
+        return {"hidden_states": hidden_states, "probs": None}

     def _release_request(self, rid: str):
         """Release per-request resources in MLX."""

diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py
index d78b77d7..8429f8fa 100755
--- a/src/parallax/server/executor/sglang_executor.py
+++ b/src/parallax/server/executor/sglang_executor.py
@@ -301,9 +301,8 @@ def handle_input_requests(self, requests: List[Request]):
                 req_dict["length"] = True

                 # Add prob value for the sampled token (if requested and available)
-                if hasattr(original_req, "return_probs") and original_req.return_probs:
-                    if hasattr(req, "token_prob") and req.token_prob is not None:
-                        req_dict["probs"] = req.token_prob
+                if original_req.return_probs and req.token_prob is not None:
+                    req_dict["probs"] = req.token_prob

                 if hasattr(self, "send_to_ipc_socket"):
                     self.send_to_ipc_socket.send_pyobj(req_dict)
@@ -382,7 +381,7 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
             final_hidden_states = (
                 logits_output.tensors["hidden_states"] + logits_output.tensors["residual"]
             )
-            return final_hidden_states
+            return {"hidden_states": final_hidden_states, "probs": None}

     def _release_request(self, rid: str):
        """Release per-request resources in SGLang."""

diff --git a/src/parallax/server/executor/vllm_executor.py b/src/parallax/server/executor/vllm_executor.py
index 221c270b..bd2a37ba 100755
--- a/src/parallax/server/executor/vllm_executor.py
+++ b/src/parallax/server/executor/vllm_executor.py
@@ -232,13 +232,15 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
                 for seq in sampled_token_ids:
                     padded_seq = seq + [-1] * (max_len - len(seq))  # Pad with -1
                     padded_tokens.append(padded_seq)
-                return torch.tensor(padded_tokens, dtype=torch.int64)
+                token_ids = torch.tensor(padded_tokens, dtype=torch.int64)
             else:
-                return torch.tensor(sampled_token_ids, dtype=torch.int64)
+                token_ids = torch.tensor(sampled_token_ids, dtype=torch.int64)
+            # vLLM doesn't support probs yet
+            return {"hidden_states": token_ids, "probs": None}
         else:
             # Intermediate peer: return hidden states for next peer
             final_hidden_states = output.tensors["hidden_states"] + output.tensors["residual"]
-            return final_hidden_states
+            return {"hidden_states": final_hidden_states, "probs": None}

     def _release_request(self, rid: str):
         """Release per-request resources in vLLM."""

From 99590ff1c4dcec7ca582bafc6b2a17da22a66e7c Mon Sep 17 00:00:00 2001
From: jason
Date: Wed, 17 Dec 2025 23:55:06 +0800
Subject: [PATCH 11/11] correct unittest

---
 tests/test_executor.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/test_executor.py b/tests/test_executor.py
index 5fb734ba..f918d1be 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -69,13 +69,13 @@ def run_executor_pipeline_stage(executor, requests, batch_type, is_last_peer):
     prepared_batch = executor.prepare_batch_inputs(input_batch)
     assert prepared_batch is not None, "Failed to prepare batch inputs"
     batch_data = prepared_batch[batch_type]
-    hidden_states = executor.process_batch(batch_data, return_decoded_tokens=is_last_peer)
+    batch_output = executor.process_batch(batch_data, return_decoded_tokens=is_last_peer)
     output_reqs = executor.prepare_next_batch_requests(
         requests=batch_data["requests"],
-        hidden_states=hidden_states,
+        batch_output=batch_output,
         context_lengths=batch_data.get("context_lengths"),
     )
-    return output_reqs, hidden_states
+    return output_reqs, batch_output


 @pytest.mark.parametrize(
@@ -190,7 +190,8 @@ def test_decode_pipeline_multiple_steps(pipeline_devices, pp_end_layers, num_dec
     print(f"prompt: {prompt}")
     print(f"mlx-lm reference generation: {ref_output_text}")
     output_tokens_for_prompt = [
-        gen_step_tokens[i].item() for gen_step_tokens in generated_tokens_pipeline
+        gen_step_tokens["hidden_states"][i].item()
+        for gen_step_tokens in generated_tokens_pipeline
     ]
     # Decode the token IDs into a string