From 511f5e9ebd0f2c63d996e7361bb8408e323abe1d Mon Sep 17 00:00:00 2001
From: fmiao2372
Date: Wed, 15 Oct 2025 05:29:24 +0000
Subject: [PATCH 1/2] support prefix caching

---
 .../llama_infer/prepare_block_metadata.cc | 95 ++++++++++++++++----
 1 file changed, 76 insertions(+), 19 deletions(-)

diff --git a/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc b/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
index 577a1297fe..2d23b38513 100644
--- a/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
+++ b/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
@@ -21,19 +21,25 @@
 #include "utils/utils.h"
 
-// retrun: amax and where(>0)
-std::pair<int, std::vector<int>> get_max_and_where_nonzero(
-    int* seq_lens_encoder, const int elem_cnt) {
-  int max_seq_len = seq_lens_encoder[0];
+// return: max encoder lens (without/with cached context) and where(>0)
+std::tuple<int, int, std::vector<int>> get_max_and_where_nonzero(
+    int* seq_lens_encoder, int* seq_lens_decoder, const int elem_cnt) {
+  int max_seq_len_without_context = 0;
+  int max_seq_len_with_context = 0;
   std::vector<int> valid_batch;
   for (int i = 0; i < elem_cnt; ++i) {
-    if (seq_lens_encoder[i] > max_seq_len) {
-      max_seq_len = seq_lens_encoder[i];
-    }
     if (seq_lens_encoder[i] > 0) {
       valid_batch.push_back(i);
+      if (seq_lens_encoder[i] > max_seq_len_without_context) {
+        max_seq_len_without_context = seq_lens_encoder[i];
+        max_seq_len_with_context = seq_lens_encoder[i];
+      }
+
+      if (seq_lens_decoder[i] > 0 && seq_lens_encoder[i] + seq_lens_decoder[i] > max_seq_len_with_context) {
+        max_seq_len_with_context = seq_lens_encoder[i] + seq_lens_decoder[i];
+      }
     }
   }
-  return {max_seq_len, valid_batch};
+  return {max_seq_len_without_context, max_seq_len_with_context, valid_batch};
 }
 
 // return: where(>0)
@@ -104,6 +110,27 @@ void pad_fill(const T* input_p,
   }
 }
 
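+// Prefix-caching overload: `offsets` holds each batch's cached context
+// length (its seq_lens_decoder entry), so the first offsets[b] / block_size
+// block-table entries map the already-cached prefix and are skipped here.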
+template <typename T>
+void pad_fill(const T* input_p,
+              const T* offsets,
+              T* padded,
+              std::vector<int> valid_batches,
+              int input_linewidth,
+              int padded_linewidth,
+              int block_size) {
+  int copy_len = std::min(input_linewidth, padded_linewidth);
+#pragma omp parallel for num_threads(OMP_THREAD_NUM)
+  for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
+    for (int j = 0; j < copy_len; ++j) {
+      padded[i * padded_linewidth + j] =
+          input_p[valid_batches[i] * input_linewidth + j + offsets[valid_batches[i]] / block_size];
+    }
+  }
+}
+
 template <typename T>
 void pad_fill(const T* input_p, T* padded, std::vector<int> valid_batches) {
 #pragma omp parallel for num_threads(OMP_THREAD_NUM)
@@ -194,8 +218,8 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
   const int max_blocks_each = block_tables.shape()[1];
   phi::DataType device_dtype = phi::StringToDataType(dtype);
 
-  auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero(
-      const_cast<int*>(seq_lens_encoder_cpu.data<int>()), max_batches_in);
+  auto [max_enc_len_without_context, max_enc_len_with_context, valid_batches_enc] = get_max_and_where_nonzero(
+      const_cast<int*>(seq_lens_encoder_cpu.data<int>()), const_cast<int*>(seq_lens_decoder_cpu.data<int>()), max_batches_in);
   int enc_count = valid_batches_enc.size();
 
   auto valid_batches_dec = where_nonzero(
@@ -223,10 +247,10 @@
 
     auto input_ids_cpu = input_ids_selected.copy_to(paddle::CPUPlace(), true);
 
-    int max_buckets = (max_enc_len + block_size - 1) / block_size;
-    int max_prompt_len = max_buckets * block_size;
+    int max_buckets_without_context = (max_enc_len_without_context + block_size - 1) / block_size;
+    int max_prompt_len_without_context = max_buckets_without_context * block_size;
 
-    auto src_padded = paddle::full({total_batch * max_prompt_len},
+    auto src_padded = paddle::full({total_batch * max_prompt_len_without_context},
                                    0,
                                    paddle::DataType::INT64,
                                    paddle::CPUPlace());
@@ -234,23 +258,53 @@
              reinterpret_cast<int64_t*>(src_padded.data<int64_t>()),
              static_cast<int>(valid_batches_enc.size()),
              max_seq_len,
-             max_prompt_len);
+             max_prompt_len_without_context);
 
-    auto blk_padded = paddle::full({total_batch * max_buckets},
+    auto blk_padded = paddle::full({total_batch * max_buckets_without_context},
                                    -1,
                                    paddle::DataType::INT32,
                                    paddle::CPUPlace());
     pad_fill(const_cast<int*>(block_tables_cpu.data<int>()),
+             const_cast<int*>(seq_lens_decoder_cpu.data<int>()),
              reinterpret_cast<int*>(blk_padded.data<int>()),
              valid_batches_enc,
              max_blocks_each,
-             max_buckets);
+             max_buckets_without_context,
+             block_size);
 
     auto blk_padded_hpu =
         custom_kernel::copy_tensor_wrapper(dev_ctx, blk_padded, hpu_place);
 
-    auto rope_emb_seg = paddle::experimental::slice(
-        rope_emb, {2}, {0}, {max_prompt_len}, {}, {});
+    int max_buckets_with_context = (max_enc_len_with_context + block_size - 1) / block_size;
+    int max_prompt_len_with_context = max_buckets_with_context * block_size;
+
+    auto block_list_padded = paddle::full({total_batch * max_buckets_with_context},
+                                          -1,
+                                          paddle::DataType::INT32,
+                                          paddle::CPUPlace());
+    pad_fill(const_cast<int*>(block_tables_cpu.data<int>()),
+             reinterpret_cast<int*>(block_list_padded.data<int>()),
+             valid_batches_enc,
+             max_blocks_each,
+             max_buckets_with_context);
+
+    auto block_list_hpu =
+        custom_kernel::copy_tensor_wrapper(dev_ctx, block_list_padded, hpu_place);
+
+    paddle::Tensor rope_emb_seg;
+    if (max_prompt_len_without_context == max_prompt_len_with_context) {
+      rope_emb_seg = paddle::experimental::slice(
+          rope_emb, {2}, {0}, {max_prompt_len_without_context}, {}, {});
+    } else {
+      std::vector<paddle::Tensor> rope_emb_segs;
+      for (auto b : valid_batches_enc) {
+        int start = seq_lens_decoder_cpu.data<int>()[b];
+        auto seg = paddle::experimental::slice(
+            rope_emb, {2}, {start}, {start + max_prompt_len_without_context}, {}, {});
+        rope_emb_segs.push_back(seg);
+      }
+      rope_emb_seg = paddle::experimental::concat(rope_emb_segs, 1);
+    }
     rope_emb_seg = paddle::experimental::cast(rope_emb_seg, device_dtype);
 
     auto total_batch_cpu_tensor = paddle::full(
@@ -262,7 +316,7 @@
     return {src_padded,
             rope_emb_seg,
             dummy_tensor,
-            dummy_tensor,
+            block_list_hpu,
             blk_padded_hpu,
             dummy_tensor,
             dummy_tensor,

From f5919e3ca43adf21ca7c997ce498a40dd7d04724 Mon Sep 17 00:00:00 2001
From: tianshuo78520a
Date: Thu, 16 Oct 2025 16:18:17 +0000
Subject: [PATCH 2/2] Update Paddle submodule to latest develop

---
 Paddle | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Paddle b/Paddle
index cc367e8767..5dbecdcb0e 160000
--- a/Paddle
+++ b/Paddle
@@ -1 +1 @@
-Subproject commit cc367e8767d49819b5100f22e279cd62a1587670
+Subproject commit 5dbecdcb0e4ddd3488927f49082dfb66c794f9e7