diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py index 93cfb1c1..aaf3fcc6 100644 --- a/src/mistral_inference/cache.py +++ b/src/mistral_inference/cache.py @@ -55,26 +55,25 @@ def update(self, xk: torch.Tensor, xv: torch.Tensor) -> None: def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - This is a naive implementation and not optimized for speed. + Optimized implementation of interleave_kv to reduce memory overhead. """ assert xk.ndim == xv.ndim == 3 # (B * T, H, D) assert xk.shape == xv.shape - - if all([s == 0 for s in self.metadata.seqlens]): - # No cache to interleave + + if all(seqlen == 0 for seqlen in self.metadata.seqlens): + # No data to interleave return xk, xv - # Make it a list of [(T, H, D)] - xk: Tuple[torch.Tensor] = torch.split(xk, self.metadata.seqlens) # type: ignore - xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore - assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}" - - # Retrieve cache - cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)] - cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)] - - interleaved_k = interleave_list(cache_k, list(xk)) - interleaved_v = interleave_list(cache_v, list(xv)) + # Efficiently split the tensors based on seqlens + xk_splits = torch.split(xk, self.metadata.seqlens) + xv_splits = torch.split(xv, self.metadata.seqlens) + + # Retrieve cached values up to the valid sequence length + cache_k_splits = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)] + cache_v_splits = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)] + + interleaved_k = [item for pair in zip(cache_k_splits, xk_splits) for item in pair] + interleaved_v = [item for pair in zip(cache_v_splits, xv_splits) for item in pair] return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0) @@ -198,4 +197,4 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata: prefill=first_prefill or subsequent_prefill, mask=mask, seqlens=seqlens, - ) + ) \ No newline at end of file