From 6bfd496766e3f5c623f92448a2a418207ae1d389 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Thu, 11 Dec 2025 11:10:25 +0800
Subject: [PATCH] [GSA] Fix prefetch bug

---
 ucm/sparse/gsa/gsa.py                       | 22 ++++++++++------------
 ucm/sparse/gsa/prefetch/prefetch_engine.py  | 13 ++++++-------
 ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp | 17 ++++++-----------
 3 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py
index b1bf1e5c1..32e5cbee3 100644
--- a/ucm/sparse/gsa/gsa.py
+++ b/ucm/sparse/gsa/gsa.py
@@ -119,6 +119,7 @@ def set_block_hashes(self, token_ids):
             hash_value = self.request_hasher(
                 (parent_block_hash_value, curr_block_token_ids_tuple)
             )
+            self.block_hashes.append(str(hash_value))
             parent_block_hash_value = hash_value
 
         if self.rank != 0 and not self.use_mla:
@@ -421,7 +422,7 @@ def cal_topk(self, intermediate_q, current_layer_id):
         dot_product_weights.masked_fill_(self.exclude_mask == 1, float("-inf"))
         selected_block_nums = self.topk_len_list[0]
         _, top_indices = torch.topk(
-            dot_product_weights, selected_block_nums, dim=-1, sorted=False
+            dot_product_weights, selected_block_nums, dim=-1, sorted=True
         )
 
         self.topk_caches[current_layer_id][self.cal_topk_id] = top_indices
@@ -582,7 +583,9 @@ def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
         if not self.use_mla:
             self.gsa_q_cache[current_layer_id][: len(ids)].copy_(query[ids])
         else:
-            self.gsa_q_cache[current_layer_id][self.decode_index].copy_(query)
+            self.gsa_q_cache[current_layer_id][: len(self.decode_index)].copy_(
+                query
+            )
         is_cal_kpre = len(self.model_input["calc_block_table"]) > 0
         self.gsa_offload_ops.add_copy_req(
             is_cal_kpre, current_layer_id, ids, self.gsa_q_cache[current_layer_id]
@@ -656,7 +659,7 @@ def attention_begin(
             else:
                 attn_metadata.block_tables[
                     : len(self.prefetch_engine.req_ids_bs)
-                ].copy_(self.model_input["block_tables_mp"][current_layer_id])
+                ] = self.model_input["block_tables_mp"][current_layer_id]
                 attn_metadata.seq_lens.copy_(
                     self.model_input["gsa_seq_len"][current_layer_id]
                 )
@@ -670,9 +673,7 @@ def attention_begin(
                     current_layer_id
                 ][self.decode_index]
             else:
-                attn_metadata.decode.block_table[
-                    : len(self.prefetch_engine.req_ids_bs)
-                ].copy_(
+                attn_metadata.decode.block_table[: len(self.decode_index)] = (
                     self.model_input["block_tables_mp"][current_layer_id][
                         self.decode_index
                     ]
@@ -937,9 +938,9 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
         fn = getattr(self.connector, "load")
         precision = self.element_size
         if self.use_mla:
-            block_data_size = kv_caches[0].numel() * precision
-        else:
             block_data_size = kv_caches[0][0].numel() * precision
+        else:
+            block_data_size = kv_caches[0][0][0].numel() * precision
 
         offsets_k = []
         key_src_tensors = []
@@ -1069,10 +1070,7 @@ def _start_topk_cal(self) -> None:
             if req_meta.is_gsa():
                 cal_topk_id.append(req_meta.index_in_batch)
                 is_decode.append(True)
-                one_topk_len = (
-                    gsa_config.compute_topk_len(len(req_meta.blocks))
-                    + gsa_config.num_prefetch_blocks
-                )
+                one_topk_len = gsa_config.compute_topk_len(len(req_meta.blocks))
                 topk_len_list.append(one_topk_len)
                 if CUDA_TOPK:
                     include_masks.append(req_meta.include_mask)
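A note on the gsa.py hunks above: set_block_hashes computed each hash_value but only threaded it into the next iteration, so the new append is what actually stores the per-block hashes; the MLA branches of copy_q and attention_begin now slice by len(self.decode_index) instead of fancy-indexing with self.decode_index; and cal_topk switches torch.topk to sorted=True. The last change matters because sorted=False only guarantees the top-k set, not its order along dim=-1, so any consumer that treats position 0 of topk_caches as the best block can read a wrong entry. A minimal standalone illustration (toy tensor, shapes invented for the example):

import torch

# Toy block scores: 2 requests x 8 candidate blocks.
scores = torch.tensor(
    [[0.1, 0.9, 0.3, 0.7, 0.5, 0.2, 0.8, 0.4],
     [0.6, 0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.5]]
)

# sorted=False: the indices form the correct top-3 set, but their order
# along dim=-1 is unspecified (notably on CUDA), so column 0 is not
# necessarily the highest-scoring block.
_, unordered = torch.topk(scores, 3, dim=-1, sorted=False)

# sorted=True: values come back in descending order, so column 0 is always
# the best block for each request.
vals, ordered = torch.topk(scores, 3, dim=-1, sorted=True)
print(vals)     # per-row descending scores
print(ordered)  # indices aligned with vals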
device="cpu") - self.topk_len = ( - gsa_config.compute_topk_len(self._get_max_block_len(gsa_metadata)) - + gsa_config.num_prefetch_blocks + self.topk_len = gsa_config.compute_topk_len( + self._get_max_block_len(gsa_metadata) ) topk_buf_tmp = self.use_topk_caches[:, block_table_index, :] topk_buf_tmp = topk_buf_tmp[:, :, : self.topk_len] @@ -190,9 +189,8 @@ def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp): ) self.topk_bs = [] for index, req_id in enumerate(self.req_ids_bs): - one_topk_len = ( - gsa_config.compute_topk_len(len(gsa_metadata.gsa_stats[req_id].blocks)) - + gsa_config.num_prefetch_blocks + one_topk_len = gsa_config.compute_topk_len( + len(gsa_metadata.gsa_stats[req_id].blocks) ) self.topk_bs.append( [ @@ -536,7 +534,8 @@ def _set_req_stat( def _get_max_block_len(self, gsa_metadata) -> int: max_len = 0 for req_id in self.req_ids_bs: - max_len = max(max_len, len(gsa_metadata.gsa_stats[req_id].blocks)) + if self.is_gsa_req_id[req_id]: + max_len = max(max_len, len(gsa_metadata.gsa_stats[req_id].blocks)) return max_len def _no_gsa_input_deal( diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp index 168af244e..0ce6f7531 100644 --- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp +++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp @@ -137,8 +137,9 @@ GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor& mExtraTopkLen = extraTopkLen; mLogger.log(LogLevel::INFO, "GSAPrefetchEngineC Init mLayerNum %d mMaxBs %u, mUseMla %d, mHeadSzie %u, mTPSize " - "%u mBlockSize %u mHeadNum %u\n", - mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum); + "%u mBlockSize %u mHeadNum %u, mIsPythonLoad %d\n", + mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum, + mIsPythonLoad); } size_t GSAPrefetchEngineC::GetOffset(uint32_t layerID, bool isV) @@ -343,7 +344,6 @@ void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, int blockID = mDocsTables[reqID][layerID][item]; hitBlocks.insert(blockID); hitBlocksIdx.insert(std::make_pair(item, blockID)); - if (hitBlocks.size() == (topkLen - mExtraTopkLen)) { break; } } else { missIdxs.push_back(item); } @@ -351,8 +351,7 @@ void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, oss << "------\n"; mLogger.log(LogLevel::DEBUG, oss.str().c_str()); oss.str(""); - if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen && - hitBlocks.size() != (topkLen - mExtraTopkLen)) { + if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen) { mLogger.log(LogLevel::ERROR, "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: " "%lu, miss size: %lu , topkLen: %d, not equal error\n", @@ -368,7 +367,6 @@ void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo, { int layerID = oneBsInfo.layerID; std::string reqID = oneBsInfo.reqID; - uint32_t topkLen = oneBsInfo.topkLen; int bsIndex = oneBsInfo.bsIndex; int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item(); @@ -377,8 +375,7 @@ void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo, uint32_t index = 0; int oneFreeBlockIndex = 0; - while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size() && - hitBlocks.size() < (topkLen - mExtraTopkLen)) { + while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size()) { int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex]; if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) { oneFreeBlockIndex += 1; @@ -415,9 +412,7 @@ void GSAPrefetchEngineC::RunOneBsPrefetch(std::string 
diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
index 168af244e..0ce6f7531 100644
--- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
+++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
@@ -137,8 +137,9 @@ GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor&
     mExtraTopkLen = extraTopkLen;
     mLogger.log(LogLevel::INFO,
                 "GSAPrefetchEngineC Init mLayerNum %d mMaxBs %u, mUseMla %d, mHeadSzie %u, mTPSize "
-                "%u mBlockSize %u mHeadNum %u\n",
-                mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum);
+                "%u mBlockSize %u mHeadNum %u, mIsPythonLoad %d\n",
+                mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum,
+                mIsPythonLoad);
 }
 
 size_t GSAPrefetchEngineC::GetOffset(uint32_t layerID, bool isV)
@@ -343,7 +344,6 @@ void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo,
             int blockID = mDocsTables[reqID][layerID][item];
             hitBlocks.insert(blockID);
            hitBlocksIdx.insert(std::make_pair(item, blockID));
-            if (hitBlocks.size() == (topkLen - mExtraTopkLen)) { break; }
         } else {
             missIdxs.push_back(item);
         }
@@ -351,8 +351,7 @@ void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo,
     oss << "------\n";
     mLogger.log(LogLevel::DEBUG, oss.str().c_str());
     oss.str("");
-    if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen &&
-        hitBlocks.size() != (topkLen - mExtraTopkLen)) {
+    if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen) {
         mLogger.log(LogLevel::ERROR,
                     "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: "
                     "%lu, miss size: %lu , topkLen: %d, not equal error\n",
@@ -368,7 +367,6 @@ void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo,
 {
     int layerID = oneBsInfo.layerID;
     std::string reqID = oneBsInfo.reqID;
-    uint32_t topkLen = oneBsInfo.topkLen;
     int bsIndex = oneBsInfo.bsIndex;
 
     int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item();
@@ -377,8 +375,7 @@ void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo,
 
     uint32_t index = 0;
     int oneFreeBlockIndex = 0;
-    while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size() &&
-           hitBlocks.size() < (topkLen - mExtraTopkLen)) {
+    while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size()) {
         int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex];
         if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) {
             oneFreeBlockIndex += 1;
@@ -415,9 +412,7 @@ void GSAPrefetchEngineC::RunOneBsPrefetch(std::string reqID, int topkLen, int bs
         oneBsInfo.bsIndex = bsIndex;
         oneBsInfo.layerID = i;
         GetHitAndMissBlock(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs);
-        if (missIdxs.size() != 0 && hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) {
-            RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs);
-        }
+        if (missIdxs.size() != 0) { RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); }
         int successIndex = 0;
         for (auto it = hitBlocksIdx.begin(); it != hitBlocksIdx.end(); it++) {
             mLoadSuccessBlocks[i][bsIndex][successIndex] = it->second;
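A note on the kvcache_pre.cpp hunks: every early exit keyed on (topkLen - mExtraTopkLen) is removed, so GetHitAndMissBlock classifies all topkLen entries, RunPrefetchH2D keeps draining misses while free blocks remain, and the error check can enforce the full-partition invariant unconditionally: hits plus misses must equal topkLen. A rough Python sketch of that invariant, with invented container names (the real code walks mDocsTables per request and layer):

from typing import Dict, List, Set, Tuple

def partition_topk(
    topk_items: List[int],
    docs_table: Dict[int, int],  # stand-in for mDocsTables[reqID][layerID]
) -> Tuple[Set[int], Dict[int, int], List[int]]:
    """Split top-k entries into device hits and host misses, no early break."""
    hit_blocks: Set[int] = set()
    hit_blocks_idx: Dict[int, int] = {}
    miss_idxs: List[int] = []
    for item in topk_items:
        if item in docs_table:       # block already resident on device
            block_id = docs_table[item]
            hit_blocks.add(block_id)
            hit_blocks_idx[item] = block_id
        else:                        # needs a host-to-device prefetch
            miss_idxs.append(item)
    # The fixed check: every top-k entry is either a hit or a miss.
    assert len(hit_blocks_idx) + len(miss_idxs) == len(topk_items)
    return hit_blocks, hit_blocks_idx, miss_idxs

(The C++ check compares hitBlocks.size(); the sketch asserts on hit_blocks_idx so duplicate block IDs in the toy input cannot trip it.)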