Conversation

@geoffreyQiu
Collaborator

Implement HSTU KVCacheManager V2:

  • Asynchronous kvcache manager operations
  • Optimized onloading and offloading

@geoffreyQiu geoffreyQiu marked this pull request as draft December 8, 2025 07:17
@greptile-apps

greptile-apps bot commented Jan 27, 2026

Greptile Summary

This PR implements an asynchronous KV cache management system (V2) for HSTU inference, replacing synchronous operations with a multi-threaded approach using ThreadPoolExecutor for improved performance.

Key Changes

  • Added AsyncHSTUKVCacheManager class that handles KV cache preparation, onloading, and offloading asynchronously using separate thread pools (a rough skeleton is sketched after this list)
  • Implemented extensive C++ backend (HostKVStorageImpl and GPUKVCacheMangerImpl) with ~1100 lines of new CUDA code for memory management
  • Added GatherPagedKVCacheAllLayers CUDA kernel to gather KV cache across all layers simultaneously for optimization
  • Integrated async manager into InferenceRankingGR.forward() to overlap KV cache operations with embedding and preprocessing
  • Added test scripts and evaluation code for consistency checking
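
For orientation, a rough sketch of the manager's surface as implied by this summary and the sequence diagram below. The class and method names come from the PR; the constructor arguments, attribute names, and bodies here are assumptions for illustration only.

from concurrent.futures import ThreadPoolExecutor

class AsyncHSTUKVCacheManager:
    # Sketch only: wraps the GPU cache manager and host KV storage with background workers.
    def __init__(self, gpu_kvcache_mgr, host_kv_mgr):
        self.gpu_kvcache_mgr = gpu_kvcache_mgr
        self.host_kv_mgr = host_kv_mgr
        self.executor = ThreadPoolExecutor(max_workers=1)       # runs prepare_kvcache tasks
        self.onload_worker = ThreadPoolExecutor(max_workers=1)  # runs onload_kvcache tasks

    def prepare_kvcache_async(self, user_ids, total_history_lengths):
        # Submit page allocation/lookup and the host-to-GPU onload to background threads.
        ...

    def prepare_kvcache_wait(self):
        # Block until the metadata (and onload) futures complete; return KVCacheMetadata.
        ...

    def offload_kvcache(self, user_ids):
        # Gather newly written KV pages on the GPU and copy them back to host storage.
        ...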

Issues Found

  • Race condition in async_kvcache_manager.py:97-99: static_onload_handle.reset() is called before the previous onload future has completed, so the handle may still be in use
  • Silent error handling in CUDA code: the abort in the error-checking helper is commented out (lines 43-50 of paged_kvcache_ops_cuda.cpp), so CUDA errors only print a warning
  • Invalid state handling: when new_tokens <= 0 the code only prints debug info instead of raising an error
  • Commented-out synchronization (onload_fut.result()) at line 115 in prepare_kvcache_wait(), which may let the KV cache metadata be consumed before onloading finishes
  • Debug/profiling code with hardcoded paths left in production code

Architecture

The async pattern introduces a pipeline where cache preparation and onloading happen in background threads while the main thread processes embeddings, enabling better GPU utilization and reduced latency.
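
To make the overlap concrete, here is a minimal sketch of the forward-pass ordering. The call names follow the sequence diagram below, but the exact signatures and return values are assumptions, not the PR's code.

def forward_with_async_kvcache(model, async_mgr, batch, user_ids, total_history_lengths):
    # Kick off KV cache preparation and host-to-GPU onloading in background threads.
    async_mgr.prepare_kvcache_async(user_ids, total_history_lengths)
    # The main thread keeps working while the cache is being prepared.
    batch = model.strip_cached_tokens(batch)
    embeddings = model.embedding_collection(batch)
    hidden = model.preprocessor(embeddings)
    # Block only at the point where the KV cache metadata is actually needed.
    kvcache_metadata = async_mgr.prepare_kvcache_wait()
    out = model.hstu_block.predict(hidden, kvcache_metadata)
    # Write newly produced KV pages back to host storage asynchronously.
    async_mgr.offload_kvcache(user_ids)
    return model.postprocessor(out)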

Confidence Score: 2/5

  • This PR has several critical issues including race conditions and disabled error handling that need resolution before merging.
  • Score reflects multiple logic issues: a race condition with handle reset timing, commented-out error abort in CUDA code allowing silent failures, missing synchronization that could cause data corruption, and incomplete error handling for edge cases. While the architecture is sound, these issues pose runtime reliability risks.
  • Pay close attention to examples/hstu/modules/async_kvcache_manager.py for threading synchronization issues and examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp for error handling.

Important Files Changed

Filename | Overview
examples/hstu/modules/async_kvcache_manager.py | New async KV cache manager with ThreadPoolExecutor for async operations; potential race condition with onload handle reset
examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp | Major CUDA implementation with 1100+ new lines, including host storage and GPU cache manager implementations
examples/hstu/model/inference_ranking_gr.py | Integrated the async KV cache manager into the inference pipeline, replacing synchronous operations with the async pattern
examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu | New CUDA kernel that gathers the KV cache across all layers simultaneously

Sequence Diagram

sequenceDiagram
    participant Client as Inference Client
    participant Model as InferenceRankingGR
    participant AsyncMgr as AsyncHSTUKVCacheManager
    participant PrepareExec as Prepare Executor Thread
    participant OnloadExec as Onload Executor Thread
    participant GPUMgr as GPU KV Cache Manager
    participant HostMgr as Host KV Storage
    
    Client->>Model: forward(batch, user_ids, total_history_lengths)
    Model->>AsyncMgr: prepare_kvcache_async()
    AsyncMgr->>PrepareExec: submit prepare_kvcache task
    AsyncMgr->>OnloadExec: submit onload_kvcache task
    AsyncMgr->>Model: return futures & buffers
    
    par Async Preparation
        PrepareExec->>GPUMgr: prepare_kvcache (allocate pages)
        PrepareExec->>HostMgr: lookup cached data
        PrepareExec-->>AsyncMgr: metadata ready
    and Async Onload
        OnloadExec->>HostMgr: get cached KV data
        OnloadExec->>GPUMgr: onload to GPU cache
        OnloadExec-->>AsyncMgr: onload complete
    end
    
    Model->>Model: strip_cached_tokens()
    Model->>Model: embedding_collection()
    Model->>Model: preprocessor()
    
    Model->>AsyncMgr: prepare_kvcache_wait()
    AsyncMgr->>PrepareExec: wait for metadata future
    AsyncMgr->>Model: return KVCacheMetadata
    
    Model->>Model: hstu_block.predict()
    Note over Model: Uses KV cache with onload handle
    
    Model->>AsyncMgr: offload_kvcache()
    AsyncMgr->>GPUMgr: offload_kvcache (async)
    GPUMgr->>GPUMgr: gather KV pages
    GPUMgr->>HostMgr: copy to host storage
    
    Model->>Model: postprocessor & MLP
    Model->>Client: return logits

@greptile-apps greptile-apps bot left a comment


14 files reviewed, 6 comments


Comment on lines +97 to +99
static_onload_handle.reset()
onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
user_ids, static_onload_handle)

logic: Race condition: static_onload_handle.reset() is called before waiting for the previous onload_fut, which could still be using the handle.

Suggested change
-static_onload_handle.reset()
-onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
-                                       user_ids, static_onload_handle)
+kvcache_metadata_fut = self.executor.submit(paged_kvcache_ops.prepare_kvcache,
+                                            self.gpu_kvcache_mgr, self.host_kv_mgr,
+                                            user_ids, total_history_lengths,
+                                            static_page_ids_gpu_buffer, static_offload_page_ids_gpu_buffer,
+                                            offload_uids_buffer,
+                                            metadata_host_buffer, metadata_gpu_buffer)
+# Wait for previous onload to complete before resetting handle
+onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
+                                       user_ids, static_onload_handle)
+static_onload_handle.reset()
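
One way to actually close the race, as a minimal sketch: keep a reference to the previous onload future and wait on it before touching the handle. The self._prev_onload_fut attribute is hypothetical and not part of this PR.

# Wait for the previous onload to finish before reusing the shared handle.
prev_fut = getattr(self, "_prev_onload_fut", None)
if prev_fut is not None:
    prev_fut.result()
static_onload_handle.reset()
onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
                                       user_ids, static_onload_handle)
self._prev_onload_fut = onload_fut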

Comment on lines +81 to +84
new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ])
if new_tokens <= 0:
    print(total_history_lengths)
    print(origin_cached_lengths)

logic: When new_tokens is 0 or negative, the code only prints the two length lists and continues execution instead of raising an error, which will cause issues downstream.

Suggested change
-new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ])
-if new_tokens <= 0:
-    print(total_history_lengths)
-    print(origin_cached_lengths)
+new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ])
+if new_tokens <= 0:
+    raise ValueError(f"Invalid new_tokens: {new_tokens}. total_history_lengths={total_history_lengths}, origin_cached_lengths={origin_cached_lengths}")

Comment on lines +184 to +186
num_cached = torch.maximum(origin_num_cached - num_context, torch.tensor([0], dtype=torch.int32))
num_cached_action = num_cached // 2
num_cached_item = num_cached - num_cached_action

style: This line allocates a new torch.tensor([0], dtype=torch.int32) on every call. Use torch.clamp_min (or an in-place clamp_min_), or cache the zero tensor, for better performance.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
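
For reference, a minimal runnable sketch of the clamp-based alternative; the example tensor values are made up, and only the last three assignments mirror the snippet above.

import torch

origin_num_cached = torch.tensor([128, 64], dtype=torch.int32)  # example values
num_context = torch.tensor([100, 80], dtype=torch.int32)        # example values

# clamp_min avoids allocating a fresh zero tensor on every call.
num_cached = torch.clamp_min(origin_num_cached - num_context, 0)
num_cached_action = num_cached // 2
num_cached_item = num_cached - num_cached_action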

metadata_gpu_buffer,
static_onload_handle):

# onload_fut.result()

logic: The onload_fut.result() call is commented out, but it should run before kvcache_metadata_fut.result(), since the onload must complete before the metadata is used. Was commenting out the onload wait intentional, or should it be executed before the metadata wait?
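
For illustration, the ordering this comment asks for inside prepare_kvcache_wait would look roughly like this (variable names as in the snippets above; not the PR's exact code):

# Ensure the host-to-GPU onload has finished before the metadata that references it is consumed.
onload_fut.result()
kvcache_metadata = kvcache_metadata_fut.result()
return kvcache_metadata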

Comment on lines +43 to +50
#define cudaCheck(ans) { cudaSuccesAssert((ans), __FILE__, __LINE__); }
inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
        // if (abort) exit(code);
    }

logic: The cudaSuccesAssert function has abort commented out, which means CUDA errors will only print warnings but not halt execution. This can lead to silent failures and corrupted data.

Suggested change
-#define cudaCheck(ans) { cudaSuccesAssert((ans), __FILE__, __LINE__); }
-inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true)
-{
-    if (code != cudaSuccess)
-    {
-        fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
-        // if (abort) exit(code);
-    }
+inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true)
+{
+    if (code != cudaSuccess)
+    {
+        fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
+        if (abort) exit(code);
+    }
+}

Comment on lines +329 to +360
# if int(user_ids[0]) == 0:
#     pg.dmp = True
if not disable_kvcache:
    logits = model.forward(batch, user_ids, total_history_lengths.cpu())
else:
    logits = model.forward_nokvcache(batch)

if pg.dmp:
    if disable_kvcache:
        for lidx in range(model._hstu_config.num_layers):
            if user_ids[0] < 10 or user_ids[0] >= 690:
                shutil.move(f"/tmp/in_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_in_l{lidx}.npy")
                shutil.move(f"/tmp/key_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_key_l{lidx}.npy")
                shutil.move(f"/tmp/value_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_value_l{lidx}.npy")
                shutil.move(f"/tmp/attn_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_attn_l{lidx}.npy")
                shutil.move(f"/tmp/out_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_out_l{lidx}.npy")
            else:
                os.remove(f"/tmp/key_l{lidx}.npy")
                os.remove(f"/tmp/value_l{lidx}.npy")
    else:
        for lidx in range(model._hstu_config.num_layers):
            if user_ids[0] < 10 or user_ids[0] >= 690:
                shutil.move(f"/tmp/in_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_in_l{lidx}.npy")
                shutil.move(f"/tmp/key_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_key_l{lidx}.npy")
                shutil.move(f"/tmp/value_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_value_l{lidx}.npy")
                shutil.move(f"/tmp/attn_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_attn_l{lidx}.npy")
                shutil.move(f"/tmp/out_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_out_l{lidx}.npy")
            else:
                os.remove(f"/tmp/key_l{lidx}.npy")
                os.remove(f"/tmp/value_l{lidx}.npy")
    pg.dmp = False

style: Debugging/profiling code with hardcoded file paths (/tmp/, dump/, cached/) and user ID checks should be removed or guarded behind a debug flag before merging.
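
If the dump path is worth keeping for debugging, a flag-guarded helper could look roughly like the sketch below; the HSTU_DEBUG_DUMP / HSTU_DEBUG_DUMP_DIR environment variables and the helper name are hypothetical, not part of this PR.

import os
import shutil

DEBUG_DUMP = os.environ.get("HSTU_DEBUG_DUMP", "0") == "1"                # hypothetical opt-in flag
DUMP_DIR = os.environ.get("HSTU_DEBUG_DUMP_DIR", "/tmp/hstu_debug_dump")  # hypothetical, avoids hardcoded dump/ and cached/ paths

def maybe_dump_layer_tensors(round_id, user_id, num_layers, prefix):
    # No-op unless the debug flag is set, so production runs never touch the filesystem.
    if not DEBUG_DUMP:
        return
    os.makedirs(DUMP_DIR, exist_ok=True)
    for lidx in range(num_layers):
        for name in ("in", "key", "value", "attn", "out"):
            src = f"/tmp/{name}_l{lidx}.npy"
            if os.path.exists(src):
                shutil.move(src, f"{DUMP_DIR}/{prefix}_round{round_id}_user{user_id}_{name}_l{lidx}.npy")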

