From 2be2692a96db81a8a050897401e9064e4185ca1b Mon Sep 17 00:00:00 2001
From: fmiao2372
Date: Wed, 22 Oct 2025 05:39:08 +0000
Subject: [PATCH 1/2] support prefix caching

---
 .../llama_infer/fused_sdpa_proj_t.cc          |  22 +++-
 .../llama_infer/prepare_block_metadata.cc     | 123 ++++++++++++++----
 2 files changed, 118 insertions(+), 27 deletions(-)

diff --git a/backends/intel_hpu/custom_ops/llama_infer/fused_sdpa_proj_t.cc b/backends/intel_hpu/custom_ops/llama_infer/fused_sdpa_proj_t.cc
index 49be799d94..1bcfbc4032 100644
--- a/backends/intel_hpu/custom_ops/llama_infer/fused_sdpa_proj_t.cc
+++ b/backends/intel_hpu/custom_ops/llama_infer/fused_sdpa_proj_t.cc
@@ -171,7 +171,13 @@ class FusedSdpaProjBTMH : public HpuFusedOperator {
     attn_inputs.push_back(q_r);
     attn_inputs.push_back(k_r);
     attn_inputs.push_back(v_r);
-
+    if (!params.sdpa_params.is_causal) {
+      attn_inputs.push_back(createTensor(inputs[3].dims.size(),
+                                         inputs[3].type,
+                                         inputs[3].dims,
+                                         true,
+                                         inputs[3].name));
+    }
     if (params.fp8_sdpa) {
       attn_inputs.push_back(nullptr);  // Mask
       attn_inputs.push_back(nullptr);  // Seed
@@ -310,6 +316,7 @@ void FusedSdpaProjBTMHKernel(
     const Context& dev_ctx,
     const phi::DenseTensor& query_states,
     const phi::DenseTensor& key_value_states,
+    const phi::DenseTensor& attn_mask,
     const phi::DenseTensor& linear_weights,
     phi::DenseTensor* out_linear,
     const phi::Scalar& scaling_factor,
@@ -329,6 +336,9 @@ void FusedSdpaProjBTMHKernel(
   std::vector<DIMS> in_out_dims = ct.GetDims();

   ct.Add(linear_weights);
+  if (causal.to<bool>() == false) {
+    ct.Add(attn_mask);
+  }

   unsigned int flags = 0;
   SDPA_SET_INPUT_AND_FLAGS(d_scale_q.get_ptr(), D_SCALE_Q)
@@ -422,6 +432,12 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
       static_cast<phi::DenseTensor*>(query_states.impl().get());
   auto key_value_states_tensor =
       static_cast<phi::DenseTensor*>(key_value_states.impl().get());
+  phi::DenseTensor* attn_mask_tensor = nullptr;
+  if (attn_mask) {
+    auto attn_mask_ptr = *(attn_mask.get_ptr());
+    attn_mask_tensor =
+        static_cast<phi::DenseTensor*>(attn_mask_ptr.impl().get());
+  }
   auto linear_weights_tensor =
       static_cast<phi::DenseTensor*>(linear_weights.impl().get());

@@ -503,12 +519,13 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
     dev_ctx->Alloc(out_linear.get(), query_states_tensor->dtype());
   }

-  if (!attn_mask && !valid_seq_len) {
+  if (!valid_seq_len) {
     if (query_states.dtype() == phi::DataType::FLOAT16) {
       custom_kernel::FusedSdpaProjBTMHKernel<phi::dtype::float16>(
           *dev_ctx,
           *query_states_tensor,
           *key_value_states_tensor,
+          attn_mask_tensor ? *attn_mask_tensor : phi::DenseTensor(),
           *linear_weights_tensor,
           out_linear.get(),
           phi::Scalar(scaling_factor),
@@ -528,6 +545,7 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
           *dev_ctx,
           *query_states_tensor,
           *key_value_states_tensor,
+          attn_mask_tensor ? *attn_mask_tensor : phi::DenseTensor(),
           *linear_weights_tensor,
           out_linear.get(),
           phi::Scalar(scaling_factor),
diff --git a/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc b/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
index 577a1297fe..c0e22cd3d8 100644
--- a/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
+++ b/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
@@ -21,19 +21,26 @@
 #include "utils/utils.h"

 // retrun: amax and where(>0)
-std::pair<int, std::vector<int>> get_max_and_where_nonzero(
-    int* seq_lens_encoder, const int elem_cnt) {
-  int max_seq_len = seq_lens_encoder[0];
+std::tuple<int, int, std::vector<int>> get_max_and_where_nonzero(
+    int* seq_lens_encoder, int* seq_lens_decoder, const int elem_cnt) {
+  int max_seq_len_without_context = 0;
+  int max_seq_len_with_context = 0;
   std::vector<int> valid_batch;
   for (int i = 0; i < elem_cnt; ++i) {
-    if (seq_lens_encoder[i] > max_seq_len) {
-      max_seq_len = seq_lens_encoder[i];
-    }
     if (seq_lens_encoder[i] > 0) {
       valid_batch.push_back(i);
+      if (seq_lens_encoder[i] > max_seq_len_without_context) {
+        max_seq_len_without_context = seq_lens_encoder[i];
+        max_seq_len_with_context = seq_lens_encoder[i];
+      }
+
+      if (seq_lens_decoder[i] > 0 && seq_lens_encoder[i] + seq_lens_decoder[i] >
+                                         max_seq_len_with_context) {
+        max_seq_len_with_context = seq_lens_encoder[i] + seq_lens_decoder[i];
+      }
     }
   }
-  return {max_seq_len, valid_batch};
+  return {max_seq_len_without_context, max_seq_len_with_context, valid_batch};
 }

 // return: where(>0)
@@ -104,6 +111,25 @@ void pad_fill(const T* input_p,
   }
 }

+template <typename T>
+void pad_fill(const T* input_p,
+              const T* offsets,
+              T* padded,
+              std::vector<int> valid_batches,
+              int input_linewidth,
+              int padded_linewidth,
+              int block_size) {
+  int copy_len = std::min(input_linewidth, padded_linewidth);
+#pragma omp parallel for num_threads(OMP_THREAD_NUM)
+  for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
+    for (int j = 0; j < copy_len; ++j) {
+      padded[i * padded_linewidth + j] =
+          input_p[valid_batches[i] * input_linewidth + j +
+                  offsets[valid_batches[i]] / block_size];
+    }
+  }
+}
+
 template <typename T>
 void pad_fill(const T* input_p, T* padded, std::vector<int> valid_batches) {
 #pragma omp parallel for num_threads(OMP_THREAD_NUM)
@@ -194,8 +220,13 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
   const int max_blocks_each = block_tables.shape()[1];
   phi::DataType device_dtype = phi::StringToDataType(dtype);

-  auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero(
-      const_cast<int*>(seq_lens_encoder_cpu.data<int>()), max_batches_in);
+  auto [max_enc_len_without_context,
+        max_enc_len_with_context,
+        valid_batches_enc] =
+      get_max_and_where_nonzero(
+          const_cast<int*>(seq_lens_encoder_cpu.data<int>()),
+          const_cast<int*>(seq_lens_decoder_cpu.data<int>()),
+          max_batches_in);
   int enc_count = valid_batches_enc.size();

   auto valid_batches_dec = where_nonzero(
@@ -223,34 +254,76 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
   auto input_ids_cpu = input_ids_selected.copy_to(paddle::CPUPlace(), true);

-  int max_buckets = (max_enc_len + block_size - 1) / block_size;
-  int max_prompt_len = max_buckets * block_size;
+  int max_buckets_without_context =
+      (max_enc_len_without_context + block_size - 1) / block_size;
+  int max_prompt_len_without_context =
+      max_buckets_without_context * block_size;

-  auto src_padded = paddle::full({total_batch * max_prompt_len},
-                                 0,
-                                 paddle::DataType::INT64,
-                                 paddle::CPUPlace());
+  auto src_padded =
+      paddle::full({total_batch * max_prompt_len_without_context},
+                   0,
+                   paddle::DataType::INT64,
+                   paddle::CPUPlace());
   pad_fill(const_cast<int64_t*>(input_ids_cpu.data<int64_t>()),
            reinterpret_cast<int64_t*>(src_padded.data<int64_t>()),
            static_cast<int>(valid_batches_enc.size()),
            max_seq_len,
-           max_prompt_len);
+           max_prompt_len_without_context);

-  auto blk_padded = paddle::full({total_batch * max_buckets},
+  auto blk_padded = paddle::full({total_batch * max_buckets_without_context},
                                  -1,
                                  paddle::DataType::INT32,
                                  paddle::CPUPlace());
-  pad_fill(const_cast<int*>(block_tables_cpu.data<int>()),
-           reinterpret_cast<int*>(blk_padded.data<int>()),
-           valid_batches_enc,
-           max_blocks_each,
-           max_buckets);
+  pad_fill(
+      const_cast<int*>(block_tables_cpu.data<int>()),
+      const_cast<int*>(seq_lens_decoder_cpu.data<int>()),
+      reinterpret_cast<int*>(blk_padded.data<int>()),
+      valid_batches_enc,
+      max_blocks_each,
+      max_buckets_without_context,
+      block_size);

   auto blk_padded_hpu =
       custom_kernel::copy_tensor_wrapper(dev_ctx, blk_padded, hpu_place);

-  auto rope_emb_seg = paddle::experimental::slice(
-      rope_emb, {2}, {0}, {max_prompt_len}, {}, {});
+  int max_buckets_with_context =
+      (max_enc_len_with_context + block_size - 1) / block_size;
+  int max_prompt_len_with_context = max_buckets_with_context * block_size;
+
+  auto block_list_padded =
+      paddle::full({total_batch * max_buckets_with_context},
+                   -1,
+                   paddle::DataType::INT32,
+                   paddle::CPUPlace());
+  pad_fill(
+      const_cast<int*>(block_tables_cpu.data<int>()),
+      reinterpret_cast<int*>(block_list_padded.data<int>()),
+      valid_batches_enc,
+      max_blocks_each,
+      max_buckets_with_context);
+
+  auto block_list_hpu = custom_kernel::copy_tensor_wrapper(
+      dev_ctx, block_list_padded, hpu_place);
+
+  paddle::Tensor rope_emb_seg;
+  if (max_prompt_len_without_context == max_prompt_len_with_context) {
+    rope_emb_seg = paddle::experimental::slice(
+        rope_emb, {2}, {0}, {max_prompt_len_without_context}, {}, {});
+  } else {
+    std::vector<paddle::Tensor> rope_emb_segs;
+    for (auto b : valid_batches_enc) {
+      int start = seq_lens_decoder_cpu.data<int>()[b];
+      auto seg = paddle::experimental::slice(
+          rope_emb,
+          {2},
+          {start},
+          {start + max_prompt_len_without_context},
+          {},
+          {});
+      rope_emb_segs.push_back(seg);
+    }
+    rope_emb_seg = paddle::experimental::concat(rope_emb_segs, 1);
+  }
   rope_emb_seg = paddle::experimental::cast(rope_emb_seg, device_dtype);

   auto total_batch_cpu_tensor = paddle::full(
@@ -262,7 +335,7 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
   return {src_padded,
           rope_emb_seg,
           dummy_tensor,
-          dummy_tensor,
+          block_list_hpu,
           blk_padded_hpu,
           dummy_tensor,
           dummy_tensor,

From 9c603059f3d2d1a4d8903099c43489b83441e8e6 Mon Sep 17 00:00:00 2001
From: tianshuo78520a
Date: Wed, 22 Oct 2025 16:18:28 +0000
Subject: [PATCH 2/2] Update Paddle submodule to latest develop

---
 Paddle | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Paddle b/Paddle
index 5dbecdcb0e..2b9ba85d9c 160000
--- a/Paddle
+++ b/Paddle
@@ -1 +1 @@
-Subproject commit 5dbecdcb0e4ddd3488927f49082dfb66c794f9e7
+Subproject commit 2b9ba85d9c512c05e20b38ea822dc808e410609f
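
Note (not part of the patch series): the core of the prefix-caching change in PATCH 1/2 is the new offset-aware pad_fill overload in prepare_block_metadata.cc, which starts copying each prefill batch's block-table row at offsets[b] / block_size entries in, so the blocks already holding the cached prefix (tracked by seq_lens_decoder) are skipped. Below is a minimal, self-contained sketch of that behaviour; the toy block table, sequence lengths, and batch layout are hypothetical, and the OpenMP pragma from the patch is omitted so the snippet compiles with a plain C++ compiler.

// sketch_pad_fill_offsets.cc -- illustrative only, hypothetical data.
#include <algorithm>
#include <iostream>
#include <vector>

// Same logic as the offset-aware overload added by the patch,
// minus the OpenMP parallel-for pragma.
template <typename T>
void pad_fill(const T* input_p,
              const T* offsets,
              T* padded,
              std::vector<int> valid_batches,
              int input_linewidth,
              int padded_linewidth,
              int block_size) {
  int copy_len = std::min(input_linewidth, padded_linewidth);
  for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
    for (int j = 0; j < copy_len; ++j) {
      // Skip offsets[b] / block_size block-table entries: those blocks
      // already hold the cached prefix tokens for this batch.
      padded[i * padded_linewidth + j] =
          input_p[valid_batches[i] * input_linewidth + j +
                  offsets[valid_batches[i]] / block_size];
    }
  }
}

int main() {
  const int block_size = 4;
  const int max_blocks_each = 6;  // width of each block-table row (toy value)
  const int max_buckets = 3;      // padded width for the current prefill step

  // Hypothetical block table for 2 batches; batch 1 is the only prefill batch.
  std::vector<int> block_tables = {10, 11, 12, 13, 14, 15,   // batch 0
                                   20, 21, 22, 23, 24, 25};  // batch 1
  // Batch 1 already has 8 cached (prefix) tokens -> 8 / 4 = 2 blocks skipped.
  std::vector<int> seq_lens_decoder = {0, 8};
  std::vector<int> valid_batches_enc = {1};

  std::vector<int> padded(valid_batches_enc.size() * max_buckets, -1);
  pad_fill(block_tables.data(),
           seq_lens_decoder.data(),
           padded.data(),
           valid_batches_enc,
           max_blocks_each,
           max_buckets,
           block_size);

  // Prints "22 23 24": blocks 20 and 21 hold the cached prefix and are skipped.
  for (int v : padded) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}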