2 changes: 1 addition & 1 deletion Paddle
Submodule Paddle updated 305 files
22 changes: 20 additions & 2 deletions backends/intel_hpu/custom_ops/llama_infer/fused_sdpa_proj_t.cc
@@ -171,7 +171,13 @@ class FusedSdpaProjBTMH : public HpuFusedOperator {
attn_inputs.push_back(q_r);
attn_inputs.push_back(k_r);
attn_inputs.push_back(v_r);

+    if (!params.sdpa_params.is_causal) {
+      attn_inputs.push_back(createTensor(inputs[3].dims.size(),
+                                         inputs[3].type,
+                                         inputs[3].dims,
+                                         true,
+                                         inputs[3].name));
+    }
if (params.fp8_sdpa) {
attn_inputs.push_back(nullptr); // Mask
attn_inputs.push_back(nullptr); // Seed
@@ -310,6 +316,7 @@ void FusedSdpaProjBTMHKernel(
const Context& dev_ctx,
const phi::DenseTensor& query_states,
const phi::DenseTensor& key_value_states,
+    const phi::DenseTensor& attn_mask,
const phi::DenseTensor& linear_weights,
phi::DenseTensor* out_linear,
const phi::Scalar& scaling_factor,
@@ -329,6 +336,9 @@ void FusedSdpaProjBTMHKernel(
std::vector<DIMS> in_out_dims = ct.GetDims();

ct.Add(linear_weights);
+  if (causal.to<bool>() == false) {
+    ct.Add(attn_mask);
+  }

unsigned int flags = 0;
SDPA_SET_INPUT_AND_FLAGS(d_scale_q.get_ptr(), D_SCALE_Q)
@@ -422,6 +432,12 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
static_cast<const phi::DenseTensor*>(query_states.impl().get());
auto key_value_states_tensor =
static_cast<const phi::DenseTensor*>(key_value_states.impl().get());
+  phi::DenseTensor* attn_mask_tensor = nullptr;
+  if (attn_mask) {
+    auto attn_mask_ptr = *(attn_mask.get_ptr());
+    attn_mask_tensor =
+        static_cast<phi::DenseTensor*>(attn_mask_ptr.impl().get());
+  }
auto linear_weights_tensor =
static_cast<const phi::DenseTensor*>(linear_weights.impl().get());

@@ -503,12 +519,13 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
dev_ctx->Alloc(out_linear.get(), query_states_tensor->dtype());
}

-  if (!attn_mask && !valid_seq_len) {
+  if (!valid_seq_len) {
if (query_states.dtype() == phi::DataType::FLOAT16) {
custom_kernel::FusedSdpaProjBTMHKernel<phi::dtype::float16>(
*dev_ctx,
*query_states_tensor,
*key_value_states_tensor,
+        attn_mask_tensor ? *attn_mask_tensor : phi::DenseTensor(),
*linear_weights_tensor,
out_linear.get(),
phi::Scalar(scaling_factor),
@@ -528,6 +545,7 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
*dev_ctx,
*query_states_tensor,
*key_value_states_tensor,
+        attn_mask_tensor ? *attn_mask_tensor : phi::DenseTensor(),
*linear_weights_tensor,
out_linear.get(),
phi::Scalar(scaling_factor),
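Taken together, the fused_sdpa_proj_t.cc changes thread an optional attention mask from the custom-op boundary down to the SDPA graph inputs: non-causal graphs receive the mask as a fourth tensor input, the kernel signature gains a const reference, and the op unwraps paddle::optional into a raw pointer, passing either the real mask or an empty placeholder tensor. A minimal sketch of that unwrap-or-default shape, using std::optional as a stand-in for paddle::optional (all names here are illustrative, not the Paddle API):

#include <iostream>
#include <optional>

struct Tensor {
  bool initialized = false;  // default-constructed == empty placeholder
};

// Kernel side: always takes a reference; an uninitialized tensor means "no mask".
void kernel(const Tensor& mask) {
  std::cout << (mask.initialized ? "mask given\n" : "no mask\n");
}

// Op side: unwrap the optional into a pointer, then pass value-or-empty.
void run(const std::optional<Tensor>& attn_mask) {
  const Tensor* mask_ptr = attn_mask ? &*attn_mask : nullptr;
  kernel(mask_ptr ? *mask_ptr : Tensor());
}

int main() {
  run(std::nullopt);  // causal path: no mask wired in
  run(Tensor{true});  // non-causal path: explicit mask
}

Keeping the kernel parameter a plain const reference avoids duplicating every kernel for masked and unmasked cases; the causal path simply receives a tensor it never reads.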
123 changes: 98 additions & 25 deletions backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc
@@ -21,19 +21,26 @@
#include "utils/utils.h"

// return: max seq lens (without/with cached context) and where(>0)
-std::pair<int, std::vector<int>> get_max_and_where_nonzero(
-    int* seq_lens_encoder, const int elem_cnt) {
-  int max_seq_len = seq_lens_encoder[0];
+std::tuple<int, int, std::vector<int>> get_max_and_where_nonzero(
+    int* seq_lens_encoder, int* seq_lens_decoder, const int elem_cnt) {
+  int max_seq_len_without_context = 0;
+  int max_seq_len_with_context = 0;
  std::vector<int> valid_batch;
  for (int i = 0; i < elem_cnt; ++i) {
-    if (seq_lens_encoder[i] > max_seq_len) {
-      max_seq_len = seq_lens_encoder[i];
-    }
    if (seq_lens_encoder[i] > 0) {
      valid_batch.push_back(i);
+      if (seq_lens_encoder[i] > max_seq_len_without_context) {
+        max_seq_len_without_context = seq_lens_encoder[i];
+      }
+      // enc + dec also covers dec == 0, so a single comparison tracks the
+      // longest prompt including cached context without order-dependent
+      // resets of the running maximum
+      if (seq_lens_encoder[i] + seq_lens_decoder[i] >
+          max_seq_len_with_context) {
+        max_seq_len_with_context = seq_lens_encoder[i] + seq_lens_decoder[i];
+      }
    }
  }
-  return {max_seq_len, valid_batch};
+  return {max_seq_len_without_context, max_seq_len_with_context, valid_batch};
}
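A self-contained sketch of the reworked helper's semantics, with hypothetical lengths and no Paddle dependencies (the helper name here is illustrative):

#include <algorithm>
#include <cassert>
#include <tuple>
#include <vector>

// Max new-token length, max length including cached context, and the
// indices of batches with active encoder work.
std::tuple<int, int, std::vector<int>> max_and_nonzero(
    const std::vector<int>& enc, const std::vector<int>& dec) {
  int max_wo = 0, max_w = 0;
  std::vector<int> valid;
  for (size_t i = 0; i < enc.size(); ++i) {
    if (enc[i] > 0) {
      valid.push_back(static_cast<int>(i));
      max_wo = std::max(max_wo, enc[i]);
      max_w = std::max(max_w, enc[i] + dec[i]);
    }
  }
  return {max_wo, max_w, valid};
}

int main() {
  // batch 0: 5 new tokens over 12 cached; batch 1 idle; batch 2: 9 new, no cache
  auto [wo, w, valid] = max_and_nonzero({5, 0, 9}, {12, 0, 0});
  assert(wo == 9 && w == 17);
  assert((valid == std::vector<int>{0, 2}));
}

Note that batch 0 drives the with-context maximum (5 + 12 = 17) even though batch 2 has the longest run of new tokens.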

// return: where(>0)
@@ -104,6 +111,25 @@ void pad_fill(const T* input_p,
}
}

+template <typename T>
+void pad_fill(const T* input_p,
+              const T* offsets,
+              T* padded,
+              std::vector<int> valid_batches,
+              int input_linewidth,
+              int padded_linewidth,
+              int block_size) {
+  int copy_len = std::min(input_linewidth, padded_linewidth);
+#pragma omp parallel for num_threads(OMP_THREAD_NUM)
+  for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
+    for (int j = 0; j < copy_len; ++j) {
+      padded[i * padded_linewidth + j] =
+          input_p[valid_batches[i] * input_linewidth + j +
+                  offsets[valid_batches[i]] / block_size];
+    }
+  }
+}
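The new overload shifts each batch's read window right by offsets[b] / block_size entries, skipping the block-table columns already occupied by that batch's cached context. A scalar sketch of the indexing (OpenMP omitted; the function name is illustrative, and each input row is assumed wide enough for j plus the offset):

#include <algorithm>
#include <vector>

std::vector<int> pad_fill_with_offset(const std::vector<int>& input,
                                      const std::vector<int>& offsets,
                                      const std::vector<int>& valid,
                                      int in_w, int out_w, int block_size) {
  std::vector<int> padded(valid.size() * out_w, -1);  // -1 == unused slot
  const int copy_len = std::min(in_w, out_w);
  for (size_t i = 0; i < valid.size(); ++i)
    for (int j = 0; j < copy_len; ++j)
      padded[i * out_w + j] =
          input[valid[i] * in_w + j + offsets[valid[i]] / block_size];
  return padded;
}

With block_size 128 and 256 cached tokens, a batch's copy starts at column 2 of its block-table row, so the padded table holds only the blocks the new tokens occupy.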

template <typename T>
void pad_fill(const T* input_p, T* padded, std::vector<int> valid_batches) {
#pragma omp parallel for num_threads(OMP_THREAD_NUM)
@@ -194,8 +220,13 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
const int max_blocks_each = block_tables.shape()[1];
phi::DataType device_dtype = phi::StringToDataType(dtype);

-  auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero(
-      const_cast<int*>(seq_lens_encoder_cpu.data<int>()), max_batches_in);
+  auto [max_enc_len_without_context,
+        max_enc_len_with_context,
+        valid_batches_enc] =
+      get_max_and_where_nonzero(
+          const_cast<int*>(seq_lens_encoder_cpu.data<int>()),
+          const_cast<int*>(seq_lens_decoder_cpu.data<int>()),
+          max_batches_in);
int enc_count = valid_batches_enc.size();

auto valid_batches_dec = where_nonzero(
@@ -223,34 +254,76 @@

auto input_ids_cpu = input_ids_selected.copy_to(paddle::CPUPlace(), true);

-  int max_buckets = (max_enc_len + block_size - 1) / block_size;
-  int max_prompt_len = max_buckets * block_size;
+  int max_buckets_without_context =
+      (max_enc_len_without_context + block_size - 1) / block_size;
+  int max_prompt_len_without_context =
+      max_buckets_without_context * block_size;
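For reference, this bucket arithmetic rounds the longest prompt up to a whole number of blocks; a compile-time check with a hypothetical block_size of 128:

constexpr int block_size = 128;  // illustrative value
constexpr int max_enc = 300;     // longest prompt, new tokens only
constexpr int buckets = (max_enc + block_size - 1) / block_size;  // ceil div
constexpr int padded_len = buckets * block_size;
static_assert(buckets == 3 && padded_len == 384,
              "300 new tokens round up to 3 blocks / 384 padded positions");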

-  auto src_padded = paddle::full({total_batch * max_prompt_len},
-                                 0,
-                                 paddle::DataType::INT64,
-                                 paddle::CPUPlace());
+  auto src_padded =
+      paddle::full({total_batch * max_prompt_len_without_context},
+                   0,
+                   paddle::DataType::INT64,
+                   paddle::CPUPlace());
pad_fill<int64_t>(const_cast<int64_t*>(input_ids_cpu.data<int64_t>()),
reinterpret_cast<int64_t*>(src_padded.data<int64_t>()),
static_cast<int>(valid_batches_enc.size()),
max_seq_len,
-                    max_prompt_len);
+                    max_prompt_len_without_context);

-  auto blk_padded = paddle::full({total_batch * max_buckets},
+  auto blk_padded = paddle::full({total_batch * max_buckets_without_context},
-1,
paddle::DataType::INT32,
paddle::CPUPlace());
-  pad_fill<int32_t>(const_cast<int32_t*>(block_tables_cpu.data<int32_t>()),
-                    reinterpret_cast<int32_t*>(blk_padded.data<int32_t>()),
-                    valid_batches_enc,
-                    max_blocks_each,
-                    max_buckets);
+  pad_fill<int32_t>(
+      const_cast<int32_t*>(block_tables_cpu.data<int32_t>()),
+      const_cast<int32_t*>(seq_lens_decoder_cpu.data<int32_t>()),
+      reinterpret_cast<int32_t*>(blk_padded.data<int32_t>()),
+      valid_batches_enc,
+      max_blocks_each,
+      max_buckets_without_context,
+      block_size);

auto blk_padded_hpu =
custom_kernel::copy_tensor_wrapper(dev_ctx, blk_padded, hpu_place);

-  auto rope_emb_seg = paddle::experimental::slice(
-      rope_emb, {2}, {0}, {max_prompt_len}, {}, {});
+  int max_buckets_with_context =
+      (max_enc_len_with_context + block_size - 1) / block_size;
+  int max_prompt_len_with_context = max_buckets_with_context * block_size;
+
+  auto block_list_padded =
+      paddle::full({total_batch * max_buckets_with_context},
+                   -1,
+                   paddle::DataType::INT32,
+                   paddle::CPUPlace());
+  pad_fill<int32_t>(
+      const_cast<int32_t*>(block_tables_cpu.data<int32_t>()),
+      reinterpret_cast<int32_t*>(block_list_padded.data<int32_t>()),
+      valid_batches_enc,
+      max_blocks_each,
+      max_buckets_with_context);
+
+  auto block_list_hpu = custom_kernel::copy_tensor_wrapper(
+      dev_ctx, block_list_padded, hpu_place);
+
+  paddle::Tensor rope_emb_seg;
+  if (max_prompt_len_without_context == max_prompt_len_with_context) {
+    rope_emb_seg = paddle::experimental::slice(
+        rope_emb, {2}, {0}, {max_prompt_len_without_context}, {}, {});
+  } else {
+    std::vector<paddle::Tensor> rope_emb_segs;
+    for (auto b : valid_batches_enc) {
+      int start = seq_lens_decoder_cpu.data<int>()[b];
+      auto seg = paddle::experimental::slice(
+          rope_emb,
+          {2},
+          {start},
+          {start + max_prompt_len_without_context},
+          {},
+          {});
+      rope_emb_segs.push_back(seg);
+    }
+    rope_emb_seg = paddle::experimental::concat(rope_emb_segs, 1);
+  }
rope_emb_seg = paddle::experimental::cast(rope_emb_seg, device_dtype);
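When any batch carries cached context, one shared slice of rope_emb no longer fits every batch, because each prompt's new tokens start at a different absolute position; the else-branch above slices a segment per active batch, starting at that batch's decoder length, and concatenates them. A 1-D sketch of the selection (sequence axis only; assumes rope entries are indexed by absolute position):

#include <vector>

std::vector<float> select_rope_segments(
    const std::vector<float>& rope,        // one entry per absolute position
    const std::vector<int>& dec_lens,      // cached-context length per batch
    const std::vector<int>& valid_batches,
    int seg_len) {                         // max_prompt_len_without_context
  std::vector<float> out;
  for (int b : valid_batches) {
    int start = dec_lens[b];               // new tokens begin after the cache
    out.insert(out.end(),
               rope.begin() + start,
               rope.begin() + start + seg_len);
  }
  return out;  // per-batch segments, concatenated as in the code above
}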

auto total_batch_cpu_tensor = paddle::full(
@@ -262,7 +335,7 @@
return {src_padded,
rope_emb_seg,
dummy_tensor,
-          dummy_tensor,
+          block_list_hpu,
blk_padded_hpu,
dummy_tensor,
dummy_tensor,