diff --git a/src/parallax/p2p/message_util.py b/src/parallax/p2p/message_util.py
index 1181c988..d6728f0d 100644
--- a/src/parallax/p2p/message_util.py
+++ b/src/parallax/p2p/message_util.py
@@ -50,6 +50,14 @@ def request_to_proto(
         if request.next_token_id is not None:
             proto_req.next_token_id = request.next_token_id

+        # Add token_prob if available
+        if hasattr(request, "token_prob") and request.token_prob is not None:
+            proto_req.token_prob = request.token_prob
+
+        # Add return_probs flag
+        if hasattr(request, "return_probs"):
+            proto_req.return_probs = request.return_probs
+
         forward_request.reqs.append(proto_req)

     return forward_request
@@ -86,6 +94,14 @@ def proto_to_request(

         sampling_params = proto_to_sampling_params(proto_req.sampling_params)

+        # Extract token_prob if present
+        token_prob = None
+        if proto_req.HasField("token_prob"):
+            token_prob = proto_req.token_prob
+
+        # Extract return_probs (defaults to False if not present)
+        return_probs = proto_req.return_probs if hasattr(proto_req, "return_probs") else False
+
         request = IntermediateRequest(
             request_id=proto_req.rid,
             current_position=current_position,
@@ -96,6 +112,8 @@ def proto_to_request(
             next_token_id=next_token_id,
             sampling_params=sampling_params,
             lora_path=proto_req.lora_path if proto_req.lora_path != "" else None,
+            token_prob=token_prob,
+            return_probs=return_probs,
         )

         requests.append(request)
diff --git a/src/parallax/p2p/proto/forward.proto b/src/parallax/p2p/proto/forward.proto
index 4854ab57..7a76e148 100644
--- a/src/parallax/p2p/proto/forward.proto
+++ b/src/parallax/p2p/proto/forward.proto
@@ -36,6 +36,8 @@ message Req {
   int32 next_token_id = 6;
   bytes hidden_states = 7;
   string lora_path = 8;
+  optional float token_prob = 9; // Probability value for the sampled token
+  bool return_probs = 10; // Whether to return probabilities
 }

 message SamplingParams {
diff --git a/src/parallax/p2p/proto/forward_pb2.py b/src/parallax/p2p/proto/forward_pb2.py
index 9e4e3f69..f25e3f3c 100644
--- a/src/parallax/p2p/proto/forward_pb2.py
+++ b/src/parallax/p2p/proto/forward_pb2.py
@@ -24,15 +24,15 @@



-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xc7\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\x85\x02\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x17\n\ntoken_prob\x18\t \x01(\x02H\x00\x88\x01\x01\x12\x14\n\x0creturn_probs\x18\n \x01(\x08\x42\r\n\x0b_token_prob\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.parallax.p2p.proto.forward_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
   DESCRIPTOR._loaded_options = None
-  _globals['_FORWARDMODE']._serialized_start=723
-  _globals['_FORWARDMODE']._serialized_end=771
+  _globals['_FORWARDMODE']._serialized_start=785
+  _globals['_FORWARDMODE']._serialized_end=833
   _globals['_FORWARDREQUEST']._serialized_start=50
   _globals['_FORWARDREQUEST']._serialized_end=140
   _globals['_FORWARDRESPONSE']._serialized_start=142
@@ -42,7 +42,7 @@
   _globals['_ABORTRESPONSE']._serialized_start=206
   _globals['_ABORTRESPONSE']._serialized_end=221
   _globals['_REQ']._serialized_start=224
-  _globals['_REQ']._serialized_end=423
-  _globals['_SAMPLINGPARAMS']._serialized_start=426
-  _globals['_SAMPLINGPARAMS']._serialized_end=721
+  _globals['_REQ']._serialized_end=485
+  _globals['_SAMPLINGPARAMS']._serialized_start=488
+  _globals['_SAMPLINGPARAMS']._serialized_end=783
 # @@protoc_insertion_point(module_scope)
diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py
index f71bc051..a56fde2b 100755
--- a/src/parallax/server/executor/base_executor.py
+++ b/src/parallax/server/executor/base_executor.py
@@ -358,10 +358,25 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict
         }

     def prepare_next_batch_requests(
-        self, requests: List[Request], hidden_states: Any, context_lengths: Any
+        self, requests: List[Request], batch_output: Any, context_lengths: Any
     ) -> List[Request]:
-        """Prepares a batch of requests for the next stage of the pipeline."""
+        """Prepares a batch of requests for the next stage of the pipeline.
+
+        Args:
+            requests: List of requests in the batch
+            batch_output: Output from process_batch. Always a dict with:
+                - 'hidden_states': token IDs (last peer) or hidden states tensor (intermediate peer)
+                - 'probs': list of probabilities (last peer) or None (intermediate peer)
+            context_lengths: Context lengths for each request
+        """
         if self.tp_rank == 0:
+            # Extract hidden_states and probs from output (always a dict now)
+            assert isinstance(
+                batch_output, dict
+            ), f"Expected dict from process_batch, got {type(batch_output)}"
+            hidden_states = batch_output["hidden_states"]
+            token_probs = batch_output["probs"]
+
             batched_requests = []
             pre_length = 0
             for i, src_request in enumerate(requests):
@@ -386,7 +401,16 @@ def prepare_next_batch_requests(
                     hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
                     pre_length += 1

-                next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
+                # Get prob for this request if available
+                token_prob = (
+                    token_probs[i]
+                    if (self.is_last_peer and token_probs and i < len(token_probs))
+                    else None
+                )
+
+                next_req = self._prepare_next_single_request(
+                    src_request, hidden_state_for_req, token_prob
+                )
                 batched_requests.append(next_req)
         else:
             batched_requests = None
@@ -502,7 +526,7 @@ def run_loop(self):
                 # 7. Prepare requests for the next stage in the pipeline
                 next_batch = self.prepare_next_batch_requests(
                     requests=prepared_inputs["requests"],
-                    hidden_states=output,
+                    batch_output=output,
                     context_lengths=prepared_inputs.get("context_lengths"),
                 )

@@ -612,6 +636,7 @@ def _handle_raw_request(self, raw_request: Dict):
         logger.debug(f"Final input token length for request ID {rid}: {input_token_num}")

         lora_path = raw_request.get("lora_path")
+        return_probs = raw_request.get("return_probs", False)  # Get return_probs parameter

         raw_sampling_params = raw_request.get("sampling_params")
         if raw_sampling_params is None:
@@ -636,6 +661,7 @@ def _handle_raw_request(self, raw_request: Dict):
             max_new_tokens=max_new_tokens,
             max_total_length=max_total_length,
             lora_path=lora_path,
+            return_probs=return_probs,
         )
         if "routing_table" in raw_request:
             req.routing_table = raw_request["routing_table"]
@@ -669,7 +695,9 @@ def _notify_http_request_error(self, raw_request: Optional[Dict], error: Excepti
         except Exception:  # pragma: no cover - best effort notification
             logger.debug("Failed to send error notification to HTTP handler", exc_info=True)

-    def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request:
+    def _prepare_next_single_request(
+        self, request: Request, hidden_states: Any, token_prob: Optional[float] = None
+    ) -> Request:
         """Handle request state changes both inter and intra peers.

         This function prepares the request object to be sent to the *next* peer in the
@@ -678,6 +706,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
         Args:
             request: The request that was just processed by this peer.
             hidden_states: The output hidden_states/output_ids from the model for this request.
+            token_prob: The probability value for the sampled token (optional).

         Returns:
             A new Request object ready to be sent to the next destination.
@@ -698,6 +727,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
                 next_token_id=next_token_id,
                 routing_table=request.routing_table,
                 lora_path=request.lora_path,
+                token_prob=token_prob,
             )
             if self.is_last_peer:
                 # Last peer decodes a token and sends it back to the first peer.
@@ -716,6 +746,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
                 next_token_id=next_token_id,
                 routing_table=request.routing_table,
                 lora_path=request.lora_path,
+                token_prob=token_prob,
             )
         # This peer is the first or an intermediate peer.
         if self.is_first_peer:
diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py
index 6b5ecf9d..5aeac1b7 100755
--- a/src/parallax/server/executor/mlx_executor.py
+++ b/src/parallax/server/executor/mlx_executor.py
@@ -267,6 +267,11 @@ def handle_input_requests(self, requests: List[Request]):
                     req_dict["eos"] = True
                 if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
                     req_dict["length"] = True
+
+                # Add prob value for the sampled token (if requested and available)
+                if original_req.return_probs and req.token_prob is not None:
+                    req_dict["probs"] = req.token_prob
+
                 if self.enable_weight_refit:
                     req_dict["weight_version"] = self.weight_version
                 if hasattr(self, "send_to_ipc_socket"):
@@ -350,11 +355,48 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
         # Process last peer: need additional sampling + detokenization
         if return_decoded_tokens:
             sampling_info = SamplingBatchInfo.from_reqs(requests)
-            return mx.array(
+
+            # For MLX, hidden_states at last shard is already logits (after lm_head)
+            # hidden_states shape: [batch_size, seq_len, vocab_size]
+            token_ids = mx.array(
                 self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info)
             )

-        return hidden_states
+            needs_probs = any(
+                (isinstance(req, InitialRequest) and req.return_probs)
+                or (isinstance(req, IntermediateRequest) and req.return_probs)
+                for req in requests
+            )
+
+            token_probs = None
+            if needs_probs:
+                # Extract probability values for sampled tokens
+                try:
+                    # Get last position logits for each request
+                    batch_probs = []
+                    for i, req in enumerate(requests):
+                        if lengths[i] > 0:
+                            # Get logit at last position
+                            last_idx = int(lengths[i]) - 1
+                            last_logits = hidden_states[i, last_idx, :]  # [vocab_size]
+                            probs = last_logits / sampling_info.temperatures.reshape(-1, 1)
+                            probs[:] = mx.softmax(probs, axis=-1)
+                            # logit_value = float(last_logits[token_id])
+                            # batch_logits.append(logit_value)
+                            # Extract probability for the sampled token
+                            token_id = int(token_ids[i])
+                            batch_probs.append(float(probs[i, token_id]))
+
+                    token_probs = batch_probs if batch_probs else None
+                except Exception as e:
+                    logger.debug(f"Failed to extract token probs: {e}")
+                    token_probs = None
+
+            # Return dict with token_ids and optional probs
+            return {"hidden_states": token_ids, "probs": token_probs}
+
+        # Intermediate peer: return hidden states without probs
+        return {"hidden_states": hidden_states, "probs": None}

     def _release_request(self, rid: str):
         """Release per-request resources in MLX."""
diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py
index c795935e..85b5c67c 100755
--- a/src/parallax/server/executor/sglang_executor.py
+++ b/src/parallax/server/executor/sglang_executor.py
@@ -307,6 +307,11 @@ def handle_input_requests(self, requests: List[Request]):
                     req_dict["eos"] = True
                 if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
                     req_dict["length"] = True
+
+                # Add prob value for the sampled token (if requested and available)
+                if original_req.return_probs and req.token_prob is not None:
+                    req_dict["probs"] = req.token_prob
+
                 if self.enable_weight_refit:
                     req_dict["weight_version"] = self.weight_version
                 if hasattr(self, "send_to_ipc_socket"):
@@ -336,6 +341,7 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:

         forward_batch = prepared_inputs["forward_batch"]
         pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"]
+        requests = prepared_inputs.get("requests", [])

         # Execute model with SGLang
         logits_output, _ = self.model_runner.forward(
@@ -359,14 +365,33 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
         if return_decoded_tokens:
             # Last peer: sample and return token IDs
             next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            return next_token_ids
+
+            # Only compute probs if any request in the batch needs it
+            # Check if any InitialRequest has return_probs=True
+            needs_probs = any(
+                (isinstance(req, InitialRequest) and req.return_probs)
+                or (isinstance(req, IntermediateRequest) and req.return_probs)
+                for req in requests
+            )
+
+            token_probs = None
+            # Extract probs for the sampled tokens only if needed
+            if needs_probs and hasattr(logits_output, "next_token_logits"):
+                # Get probs for sampled tokens (next_token_logits contains probabilities)
+                real_probs = logits_output.next_token_logits[
+                    torch.arange(len(next_token_ids)), next_token_ids
+                ]
+                token_probs = real_probs.cpu().float().tolist()
+
+            # Return dict with token_ids and optional probs
+            return {"hidden_states": next_token_ids, "probs": token_probs}
         else:
             # Intermediate peer: return hidden states for next peer
             # Note: SGLang stores hidden_states + residual separately
             final_hidden_states = (
                 logits_output.tensors["hidden_states"] + logits_output.tensors["residual"]
             )
-            return final_hidden_states
+            return {"hidden_states": final_hidden_states, "probs": None}

     def _release_request(self, rid: str):
         """Release per-request resources in SGLang."""
diff --git a/src/parallax/server/executor/vllm_executor.py b/src/parallax/server/executor/vllm_executor.py
index b18e739a..c3ea96e9 100755
--- a/src/parallax/server/executor/vllm_executor.py
+++ b/src/parallax/server/executor/vllm_executor.py
@@ -237,13 +237,15 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
                 for seq in sampled_token_ids:
                     padded_seq = seq + [-1] * (max_len - len(seq))  # Pad with -1
                     padded_tokens.append(padded_seq)
-                return torch.tensor(padded_tokens, dtype=torch.int64)
+                token_ids = torch.tensor(padded_tokens, dtype=torch.int64)
             else:
-                return torch.tensor(sampled_token_ids, dtype=torch.int64)
+                token_ids = torch.tensor(sampled_token_ids, dtype=torch.int64)
+            # vLLM doesn't support probs yet
+            return {"hidden_states": token_ids, "probs": None}
         else:
             # Intermediate peer: return hidden states for next peer
             final_hidden_states = output.tensors["hidden_states"] + output.tensors["residual"]
-            return final_hidden_states
+            return {"hidden_states": final_hidden_states, "probs": None}

     def _release_request(self, rid: str):
         """Release per-request resources in vLLM."""
diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py
index d529af52..f55c468d 100644
--- a/src/parallax/server/http_server.py
+++ b/src/parallax/server/http_server.py
@@ -22,7 +22,7 @@
 import uuid
 from dataclasses import dataclass, field
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import fastapi
 import uvicorn
@@ -87,6 +87,10 @@ class HTTPRequestInfo:
     error_message: Optional[str] = None
     error_type: Optional[str] = None
     error_status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR
+    # probs support
+    return_probs: bool = False  # Whether to return probabilities
+    probs_list: List = field(default_factory=list)  # Store probs for each token
+    token_ids_list: List = field(default_factory=list)  # Store token IDs for each token
     # Weight version for RL
     weight_version: Optional[int] = None

@@ -130,6 +134,7 @@ def create_request(self, request: Dict):
         rid = request["rid"]
         stream = request.get("stream", False)
         model = request.get("model", "default")
+        return_probs = request.get("return_probs", False)  # Check if probs requested
         chat_object = "chat.completion.chunk" if stream else "chat.completion"
         detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap)
         create_time = time.time()
@@ -142,6 +147,7 @@ def create_request(self, request: Dict):
             create_time=create_time,
             update_time=update_time,
             detokenizer=detokenizer,
+            return_probs=return_probs,
         )
         if stream:
             request_info.token_queue = asyncio.Queue()
@@ -153,6 +159,11 @@ def release_request(self, rid: str):

     def send_request(self, request: Dict):
         """Sends the request to model executor using IPC."""
+        # Ensure return_probs is included in the request sent to executor
+        rid = request.get("rid")
+        if rid and rid in self.processing_requests:
+            request_info = self.processing_requests[rid]
+            request["return_probs"] = request_info.return_probs
         self.send_to_executor.send_pyobj(request)

     def abort_request(self, request_id: str):
@@ -215,6 +226,12 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False):
         }
         choice = response["choices"][0]
         choice["delta"] = {"role": role, "content": content}
+        # Add probs in the last chunk if requested (convert to object array format)
+        if is_last and request_info.return_probs:
+            choice["probs"] = [
+                {self.tokenizer.decode([token_id]): prob}
+                for token_id, prob in zip(request_info.token_ids_list, request_info.probs_list)
+            ]
         if request_info.weight_version:
             response["weight_version"] = request_info.weight_version
         response_json = json.dumps(response, separators=(",", ":"))
@@ -284,6 +301,12 @@ def generate_non_stream_response(self, rid):
             "reasoning_content": None,
             "tool_calls": None,
         }
+        # Add probs if requested (convert to object array format)
+        if request_info.return_probs:
+            choice["probs"] = [
+                {self.tokenizer.decode([token_id]): prob}
+                for token_id, prob in zip(request_info.token_ids_list, request_info.probs_list)
+            ]
         if request_info.weight_version:
             response["weight_version"] = request_info.weight_version
         return response
@@ -338,6 +361,11 @@ async def _handle_loop(self):
                     output = request_info.detokenizer.last_segment
                     request_info.weight_version = recv_dict.get("weight_version", None)

+                    # Store probs and token IDs if requested
+                    if request_info.return_probs and "probs" in recv_dict:
+                        request_info.probs_list.append(recv_dict["probs"])
+                        request_info.token_ids_list.append(next_token_id)
+
                     is_finished = recv_dict.get("eos", False) or recv_dict.get("length", False)

                     # Only process and send non-EOS tokens
diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py
index f76accdc..d5ab84fb 100644
--- a/src/parallax/server/request.py
+++ b/src/parallax/server/request.py
@@ -158,6 +158,7 @@ def __init__(
         max_total_length: int = 1024,
         status: RequestStatus = RequestStatus.PREFILLING,
         lora_path: Optional[str] = None,
+        return_probs: bool = False,
     ):
         if not prompt and not input_ids:
             raise ValueError("prompt or input_ids cannot be empty.")
@@ -170,6 +171,7 @@ def __init__(
             lora_path=lora_path,
         )
         self.prompt = prompt
+        self.return_probs = return_probs

         if max_new_tokens < 1:
             raise ValueError("max_new_tokens must be at least 1.")
@@ -262,6 +264,8 @@ def __init__(
         routing_table: Optional[List[str]] = [],
         sampling_params: Optional[SamplingParams] = None,
         lora_path: Optional[str] = None,
+        token_prob: Optional[float] = None,
+        return_probs: bool = False,
     ):
         super().__init__(
             request_id=request_id,
@@ -283,6 +287,8 @@ def __init__(
         self.current_position = current_position
         self.hidden_states = hidden_states
         self.next_token_id = next_token_id
+        self.token_prob = token_prob
+        self.return_probs = return_probs

     @property
     def input_length(self) -> int:
@@ -301,6 +307,7 @@ def from_initial_request(
         initial_request: InitialRequest,
         hidden_states: Optional[Any] = None,
         lora_path: Optional[str] = None,
+        token_prob: Optional[float] = None,
     ) -> "IntermediateRequest":
         """Convert an InitialRequest to an IntermediateRequest.

@@ -333,6 +340,8 @@ def from_initial_request(
             sampling_params=initial_request.sampling_params,
             routing_table=initial_request.routing_table,
             lora_path=lora_path,
+            token_prob=token_prob,
+            return_probs=initial_request.return_probs,
         )

     @classmethod
@@ -341,6 +350,7 @@ def from_intermediate_request(
         old_request: "IntermediateRequest",
         new_hidden_states: Any,
         lora_path: Optional[str] = None,
+        token_prob: Optional[float] = None,
     ) -> "IntermediateRequest":
         """
         Creates a new IntermediateRequest from an old one, with updated hidden states.
@@ -356,6 +366,8 @@ def from_intermediate_request(
             routing_table=old_request.routing_table,
             sampling_params=old_request.sampling_params,
             lora_path=lora_path,
+            token_prob=token_prob,
+            return_probs=old_request.return_probs,
         )

     def __repr__(self):
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 5fb734ba..f918d1be 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -69,13 +69,13 @@ def run_executor_pipeline_stage(executor, requests, batch_type, is_last_peer):
     prepared_batch = executor.prepare_batch_inputs(input_batch)
     assert prepared_batch is not None, "Failed to prepare batch inputs"
     batch_data = prepared_batch[batch_type]
-    hidden_states = executor.process_batch(batch_data, return_decoded_tokens=is_last_peer)
+    batch_output = executor.process_batch(batch_data, return_decoded_tokens=is_last_peer)
     output_reqs = executor.prepare_next_batch_requests(
         requests=batch_data["requests"],
-        hidden_states=hidden_states,
+        batch_output=batch_output,
         context_lengths=batch_data.get("context_lengths"),
     )
-    return output_reqs, hidden_states
+    return output_reqs, batch_output


 @pytest.mark.parametrize(
@@ -190,7 +190,8 @@ def test_decode_pipeline_multiple_steps(pipeline_devices, pp_end_layers, num_dec
         print(f"prompt: {prompt}")
         print(f"mlx-lm reference generation: {ref_output_text}")
         output_tokens_for_prompt = [
-            gen_step_tokens[i].item() for gen_step_tokens in generated_tokens_pipeline
+            gen_step_tokens["hidden_states"][i].item()
+            for gen_step_tokens in generated_tokens_pipeline
         ]

         # Decode the token IDs into a string
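
For reference, a minimal sketch of how the new flag and field surface at the HTTP layer. Only "return_probs", "stream", "model", and the "probs" response field come from this diff; the endpoint shape, remaining payload fields, token strings, and probability values are illustrative assumptions.

    # Hypothetical client-side payload: "return_probs" is read by create_request()
    # and forwarded to the executor by send_request().
    payload = {
        "model": "default",
        "stream": False,
        "return_probs": True,
        # remaining fields (prompt/messages, sampling_params, ...) as usual
    }

    # Shape of the "probs" field built by generate_non_stream_response() and the
    # final chunk of _generate_stream_chunk(): one single-key object per generated
    # token, keyed by the detokenized token string. Values below are made up.
    example_probs = [
        {"Hello": 0.92},
        {" there": 0.81},
    ]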