[WIP] HSTU KV Cache Manager V2 #251
base: main
Conversation
Greptile Summary
This PR implements an asynchronous KV cache management system (V2) for HSTU inference, replacing synchronous operations with a multi-threaded approach.
Key Changes
Issues Found
Architecture
The async pattern introduces a pipeline in which cache preparation and onloading happen in background threads while the main thread processes embeddings, enabling better GPU utilization and reduced latency. A minimal sketch of this pattern follows the sequence diagram below.
Confidence Score: 2/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Client as Inference Client
participant Model as InferenceRankingGR
participant AsyncMgr as AsyncHSTUKVCacheManager
participant PrepareExec as Prepare Executor Thread
participant OnloadExec as Onload Executor Thread
participant GPUMgr as GPU KV Cache Manager
participant HostMgr as Host KV Storage
Client->>Model: forward(batch, user_ids, total_history_lengths)
Model->>AsyncMgr: prepare_kvcache_async()
AsyncMgr->>PrepareExec: submit prepare_kvcache task
AsyncMgr->>OnloadExec: submit onload_kvcache task
AsyncMgr->>Model: return futures & buffers
par Async Preparation
PrepareExec->>GPUMgr: prepare_kvcache (allocate pages)
PrepareExec->>HostMgr: lookup cached data
PrepareExec-->>AsyncMgr: metadata ready
and Async Onload
OnloadExec->>HostMgr: get cached KV data
OnloadExec->>GPUMgr: onload to GPU cache
OnloadExec-->>AsyncMgr: onload complete
end
Model->>Model: strip_cached_tokens()
Model->>Model: embedding_collection()
Model->>Model: preprocessor()
Model->>AsyncMgr: prepare_kvcache_wait()
AsyncMgr->>PrepareExec: wait for metadata future
AsyncMgr->>Model: return KVCacheMetadata
Model->>Model: hstu_block.predict()
Note over Model: Uses KV cache with onload handle
Model->>AsyncMgr: offload_kvcache()
AsyncMgr->>GPUMgr: offload_kvcache (async)
GPUMgr->>GPUMgr: gather KV pages
GPUMgr->>HostMgr: copy to host storage
Model->>Model: postprocessor & MLP
Model->>Client: return logits
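
To make the flow concrete, here is a minimal sketch of the prepare / wait / offload cycle, assuming the manager wraps two single-thread executors. The method names echo the diagram (prepare_kvcache_async, prepare_kvcache_wait, offload_kvcache), but the class, bodies, and signatures are illustrative only, not the PR's actual AsyncHSTUKVCacheManager implementation.

from concurrent.futures import ThreadPoolExecutor

class AsyncKVCacheManagerSketch:
    """Illustrative only: mirrors the sequence diagram above, not the PR's code."""

    def __init__(self, gpu_kvcache_mgr, host_kv_mgr):
        self.gpu_kvcache_mgr = gpu_kvcache_mgr
        self.host_kv_mgr = host_kv_mgr
        self.prepare_worker = ThreadPoolExecutor(max_workers=1)  # page allocation + host lookup
        self.onload_worker = ThreadPoolExecutor(max_workers=1)   # host -> GPU copies

    def prepare_kvcache_async(self, user_ids, total_history_lengths, onload_handle):
        # Both tasks run in the background while the caller computes embeddings.
        metadata_fut = self.prepare_worker.submit(
            self.gpu_kvcache_mgr.prepare_kvcache, self.host_kv_mgr,
            user_ids, total_history_lengths)
        onload_fut = self.onload_worker.submit(
            self.gpu_kvcache_mgr.onload_kvcache, user_ids, onload_handle)
        return metadata_fut, onload_fut

    def prepare_kvcache_wait(self, metadata_fut, onload_fut):
        # Block only when the HSTU block actually needs the metadata.
        onload_fut.result()
        return metadata_fut.result()

    def offload_kvcache(self, user_ids):
        # Copy updated KV pages back to host storage after prediction.
        return self.prepare_worker.submit(
            self.gpu_kvcache_mgr.offload_kvcache, self.host_kv_mgr, user_ids)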
14 files reviewed, 6 comments
static_onload_handle.reset()
onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
                                       user_ids, static_onload_handle)
logic: Race condition: static_onload_handle.reset() is called before waiting for the previous onload_fut, which could still be using the handle.
Suggested change:
- static_onload_handle.reset()
- onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
-                                        user_ids, static_onload_handle)
+ kvcache_metadata_fut = self.executor.submit(paged_kvcache_ops.prepare_kvcache,
+                                             self.gpu_kvcache_mgr, self.host_kv_mgr,
+                                             user_ids, total_history_lengths,
+                                             static_page_ids_gpu_buffer, static_offload_page_ids_gpu_buffer,
+                                             offload_uids_buffer,
+                                             metadata_host_buffer, metadata_gpu_buffer)
+ # Wait for previous onload to complete before resetting handle
+ onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache,
+                                        user_ids, static_onload_handle)
+ static_onload_handle.reset()
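
Note that the suggested ordering above still resets the handle without waiting on the prior onload. One way to actually serialize reuse of the handle is to keep the in-flight future on the manager and wait on it first. A self-contained sketch under that assumption; the class, method, and attribute names here are hypothetical, not part of the PR:

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

class OnloadScheduler:
    """Hypothetical helper: serializes reuse of a shared onload handle."""

    def __init__(self, onload_fn):
        self._onload_fn = onload_fn                    # e.g. gpu_kvcache_mgr.onload_kvcache
        self._worker = ThreadPoolExecutor(max_workers=1)
        self._inflight: Optional[Future] = None        # previous onload, if any

    def submit_onload(self, user_ids, handle) -> Future:
        # Wait for the previous onload before resetting the handle it may
        # still be writing into, then reuse the handle for the new request.
        if self._inflight is not None:
            self._inflight.result()
        handle.reset()
        self._inflight = self._worker.submit(self._onload_fn, user_ids, handle)
        return self._inflight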
new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ])
if new_tokens <= 0:
    print(total_history_lengths)
    print(origin_cached_lengths)
logic: When new_tokens is 0 or negative, the code only prints the two length lists and continues execution, which will cause issues downstream.
Suggested change:
- new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ])
- if new_tokens <= 0:
-     print(total_history_lengths)
-     print(origin_cached_lengths)
+ new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ])
+ if new_tokens <= 0:
+     raise ValueError(f"Invalid new_tokens: {new_tokens}. total_history_lengths={total_history_lengths}, origin_cached_lengths={origin_cached_lengths}")
num_cached = torch.maximum(origin_num_cached - num_context, torch.tensor([0], dtype=torch.int32))
num_cached_action = num_cached // 2
num_cached_item = num_cached - num_cached_action
style: torch.tensor([0], dtype=torch.int32) allocates a fresh zero tensor on every call to torch.maximum. Use torch.clamp_min (optionally in place) or cache the zero tensor for better performance.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
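
A rough sketch of both alternatives, assuming origin_num_cached and num_context are int32 tensors of the same shape (the values below are placeholders; only the variable names come from the snippet above):

import torch

origin_num_cached = torch.tensor([5, 0, 7], dtype=torch.int32)  # placeholder values
num_context = torch.tensor([2, 3, 1], dtype=torch.int32)        # placeholder values

# Option 1: clamp_min avoids building a zero tensor on every call.
num_cached = torch.clamp_min(origin_num_cached - num_context, 0)

# Option 2: create the zero tensor once (e.g. at init time) and reuse it.
_ZERO_I32 = torch.zeros(1, dtype=torch.int32)
num_cached_alt = torch.maximum(origin_num_cached - num_context, _ZERO_I32)

num_cached_action = num_cached // 2
num_cached_item = num_cached - num_cached_action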
                          metadata_gpu_buffer,
                          static_onload_handle):

    # onload_fut.result()
logic: The onload_fut.result() call is commented out, but it should run before kvcache_metadata_fut.result(), since the onload must complete before the metadata is used. Was commenting out the onload wait intentional, or should it be executed before the metadata wait?
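
If the wait is meant to stay, the ordering described above would look roughly like the following. This is a sketch: the future names come from the snippets in this review, while the method name and signature are assumed for illustration.

# Sketch only: the wrapper method and its signature are assumed, not the PR's code.
def prepare_kvcache_wait(self, kvcache_metadata_fut, onload_fut):
    # Ensure the host -> GPU onload has finished before the metadata that
    # references those pages is handed to the attention kernels.
    onload_fut.result()
    return kvcache_metadata_fut.result()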
#define cudaCheck(ans) { cudaSuccesAssert((ans), __FILE__, __LINE__); }
inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
        // if (abort) exit(code);
    }
logic: The cudaSuccesAssert function has abort commented out, which means CUDA errors will only print warnings but not halt execution. This can lead to silent failures and corrupted data.
Suggested change:
- #define cudaCheck(ans) { cudaSuccesAssert((ans), __FILE__, __LINE__); }
- inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true)
- {
-     if (code != cudaSuccess)
-     {
-         fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
-         // if (abort) exit(code);
-     }
+ inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true)
+ {
+     if (code != cudaSuccess)
+     {
+         fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
+         if (abort) exit(code);
+     }
+ }
# if int(user_ids[0]) == 0:
#     pg.dmp = True
if not disable_kvcache:
    logits = model.forward(batch, user_ids, total_history_lengths.cpu())
else:
    logits = model.forward_nokvcache(batch)

if pg.dmp:
    if disable_kvcache:
        for lidx in range(model._hstu_config.num_layers):
            if user_ids[0] < 10 or user_ids[0] >= 690:
                shutil.move(f"/tmp/in_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_in_l{lidx}.npy")
                shutil.move(f"/tmp/key_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_key_l{lidx}.npy")
                shutil.move(f"/tmp/value_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_value_l{lidx}.npy")
                shutil.move(f"/tmp/attn_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_attn_l{lidx}.npy")
                shutil.move(f"/tmp/out_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_out_l{lidx}.npy")
            else:
                os.remove(f"/tmp/key_l{lidx}.npy")
                os.remove(f"/tmp/value_l{lidx}.npy")
    else:
        for lidx in range(model._hstu_config.num_layers):
            if user_ids[0] < 10 or user_ids[0] >= 690:
                shutil.move(f"/tmp/in_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_in_l{lidx}.npy")
                shutil.move(f"/tmp/key_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_key_l{lidx}.npy")
                shutil.move(f"/tmp/value_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_value_l{lidx}.npy")
                shutil.move(f"/tmp/attn_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_attn_l{lidx}.npy")
                shutil.move(f"/tmp/out_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_out_l{lidx}.npy")
            else:
                os.remove(f"/tmp/key_l{lidx}.npy")
                os.remove(f"/tmp/value_l{lidx}.npy")
    pg.dmp = False
style: Debugging/profiling code with hardcoded file paths (/tmp/, dump/, cached/) and user ID checks should be removed or guarded behind a debug flag before merging.
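
For illustration, one way to gate this kind of dump logic behind a flag. The environment variable name HSTU_DEBUG_DUMP and the helper function are hypothetical, not part of this PR:

import os
import shutil

# Hypothetical debug flag; HSTU_DEBUG_DUMP is an illustrative name, not defined in this PR.
DEBUG_DUMP = os.environ.get("HSTU_DEBUG_DUMP", "0") == "1"

def maybe_dump_layer_tensors(num_layers, round_id, user_id, dst_dir):
    """Move per-layer debug dumps only when the debug flag is set."""
    if not DEBUG_DUMP:
        return
    os.makedirs(dst_dir, exist_ok=True)
    for lidx in range(num_layers):
        for name in ("in", "key", "value", "attn", "out"):
            src = f"/tmp/{name}_l{lidx}.npy"
            if os.path.exists(src):
                shutil.move(src, f"{dst_dir}/round{round_id}_user{user_id}_{name}_l{lidx}.npy")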
Implement HSTU KVCacheManager V2: