@@ -38,6 +38,7 @@
from fastdeploy.model_executor.forward_meta import HPUForwardMeta

from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from fastdeploy.model_executor.layers.normalization import RMSNorm


def get_attention_mask(seq_lens_encoder, seq_lens_decoder, batch_size, query_len):
@@ -80,6 +81,8 @@ def forward(
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
q_norm: RMSNorm = None,
k_norm: RMSNorm = None,
):
"""
Run a forward.
@@ -96,6 +99,8 @@
o_proj,
layer,
forward_meta,
q_norm,
k_norm,
)
elif forward_meta.forward_mode.is_decode():
return self.forward_decode(
@@ -104,6 +109,8 @@
o_proj,
layer,
forward_meta,
q_norm,
k_norm,
)
else:
return self.forward_extend(
@@ -112,6 +119,8 @@
o_proj,
layer,
forward_meta,
q_norm,
k_norm,
)

def forward_mixed(
@@ -121,6 +130,8 @@ def forward_mixed(
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
q_norm: RMSNorm = None,
k_norm: RMSNorm = None,
):
"""Run a forward for mix."""
raise NotImplementedError()
@@ -132,6 +143,8 @@ def forward_decode(
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
q_norm: RMSNorm = None,
k_norm: RMSNorm = None,
):
"""Run a forward for decode."""
raise NotImplementedError()
@@ -143,6 +156,8 @@ def forward_extend(
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
q_norm: RMSNorm = None,
k_norm: RMSNorm = None,
):
"""Run a forward for extend."""
raise NotImplementedError()
@@ -249,7 +264,7 @@ def get_kv_cache_shape(
return (max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim)

def forward_extend(
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta, q_norm: RMSNorm = None, k_norm: RMSNorm = None
):
"""
forward_extend
@@ -280,21 +295,26 @@ def forward_extend(
qkv_act_scale_key=qkv_proj_act_scale_key,
)
else:
new_qkv_weight = paddle.concat([qkv_proj.weight_q, qkv_proj.weight_k, qkv_proj.weight_v], axis=-1)
query_states, key_value_states = fused_qkv_rope(
src,
qkv_proj.weight,
#qkv_proj.weight,
new_qkv_weight,
qkv_proj.bias,
forward_meta.rotary_embs,
getattr(qkv_proj, "act_scale", None),
getattr(qkv_proj, "weight_scale", None),
getattr(layer, "q_scale", None),
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
q_norm.weight if q_norm is not None else None,
k_norm.weight if k_norm is not None else None,
self.head_dim,
self.num_heads,
forward_meta.total_batch,
transpose=False,
use_neox_style=layer.use_neox_rotary_style,
epsilon=1e-6,
)

kv, B, BP_BS, M, H = key_value_states.shape
@@ -381,7 +401,7 @@ def forward_extend(
return out_linear_out

def forward_decode(
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta, q_norm: RMSNorm = None, k_norm: RMSNorm = None
):
"""
forward_decode
@@ -419,6 +439,7 @@ def forward_decode(
o_act_scale_key=o_proj_act_scale_key,
)
else:
new_qkv_weight = paddle.concat([qkv_proj.weight_q, qkv_proj.weight_k, qkv_proj.weight_v], axis=-1)
out_linear_out = fused_block_attention(
src,
forward_meta.rotary_embs,
@@ -430,11 +451,12 @@
forward_meta.attention_mask,
forward_meta.block_indices,
forward_meta.block_offsets,
qkv_proj.weight,
#qkv_proj.weight,
new_qkv_weight,
qkv_proj.bias,
o_proj.weight,
None,
None,
q_norm.weight if q_norm is not None else None,
k_norm.weight if k_norm is not None else None,
getattr(qkv_proj, "act_scale", None),
getattr(qkv_proj, "weight_scale", None),
getattr(layer, "q_scaling_scale", None),
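
For context on the new q_norm / k_norm arguments threaded through the backend above: Qwen3 applies an RMSNorm over each attention head's query and key vectors before rotary embedding, and the fused HPU kernels now receive the norm weights plus epsilon=1e-6 directly. Below is only a minimal eager-mode sketch of that step under the assumption that the fused kernel normalizes per head; the helper names, shapes, and the standalone rms_norm function are illustrative, not FastDeploy APIs.

```python
# Minimal sketch (assumption: mirrors what the fused kernel does internally
# with q_norm.weight / k_norm.weight); not a FastDeploy API.
import paddle


def rms_norm(x, weight, epsilon=1e-6):
    # x: [..., head_dim], weight: [head_dim]; same epsilon the diff passes to fused_qkv_rope.
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    x = x.astype("float32") * paddle.rsqrt(variance + epsilon)
    return x.astype(weight.dtype) * weight


def apply_qk_norm(q, k, q_norm_weight, k_norm_weight, num_heads, kv_num_heads, head_dim):
    # Per-head normalization of Q and K before RoPE, as in Qwen3's q_norm / k_norm.
    q = rms_norm(q.reshape([-1, num_heads, head_dim]), q_norm_weight)
    k = rms_norm(k.reshape([-1, kv_num_heads, head_dim]), k_norm_weight)
    return q.reshape([-1, num_heads * head_dim]), k.reshape([-1, kv_num_heads * head_dim])
```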
58 changes: 58 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
@@ -49,6 +49,43 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

layer.weight_k = layer.create_parameter(
shape= [4096, 1024], # Qwen3.
#shape = [512], # Qwen2.
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

layer.weight_v = layer.create_parameter(
shape= [4096, 1024],
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

layer.weight_q = layer.create_parameter(
shape= [4096, 4096],
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

layer.weight_up = layer.create_parameter(
shape= [4096, 12288], # Qwen3
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

layer.weight_gate = layer.create_parameter(
shape= [4096, 12288], # Qwen3
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

split_axis = extra_weight_attrs.get("split_axis")
if hasattr(layer, "nranks") and layer.nranks > 0:
_set_var_distributed(layer.weight, split_axis=split_axis)
@@ -530,6 +567,15 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
shard_size = (self.local_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "gate":
print(">> loaded_weight gate:", loaded_weight)
self.weight_gate.set_value(loaded_weight)

if loaded_shard_id == "up":
print(">> loaded_weight up:", loaded_weight)
self.weight_up.set_value(loaded_weight)

if not param._is_initialized():
param.initialize()
param_shard_size = output_size // 2
@@ -670,6 +716,18 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N

loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "k":
print(">> loaded_weight k:", loaded_weight)
self.weight_k.set_value(loaded_weight)

if loaded_shard_id == "q":
print(">> loaded_weight q:", loaded_weight)
self.weight_q.set_value(loaded_weight)

if loaded_shard_id == "v":
print(">> loaded_weight v:", loaded_weight)
self.weight_v.set_value(loaded_weight)

if not param._is_initialized():
param.initialize()

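
The loader changes above keep separate weight_q / weight_k / weight_v (and weight_gate / weight_up) parameters with hard-coded Qwen3 shapes, and the HPU backend re-fuses them at forward time with paddle.concat. A self-contained sketch of why that concatenation is equivalent to the original fused weight; the tensors here are created locally purely for illustration, not the real parameters.

```python
import paddle

hidden, q_out, kv_out = 4096, 4096, 1024  # the hard-coded Qwen3 shapes from create_weights above
weight_q = paddle.randn([hidden, q_out])
weight_k = paddle.randn([hidden, kv_out])
weight_v = paddle.randn([hidden, kv_out])

# What the backend does on each forward pass:
new_qkv_weight = paddle.concat([weight_q, weight_k, weight_v], axis=-1)  # [4096, 6144]

x = paddle.randn([2, hidden])
q, k, v = paddle.split(paddle.matmul(x, new_qkv_weight), [q_out, kv_out, kv_out], axis=-1)
# q, k, v equal x @ weight_q, x @ weight_k, x @ weight_v, so one fused GEMM suffices.
```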
15 changes: 14 additions & 1 deletion fastdeploy/worker/hpu_model_runner.py
@@ -36,6 +36,7 @@
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
)
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
@@ -222,6 +223,8 @@ def fused_attention_forward(
qkv_proj: QKVParallelLinear = None,
o_proj: RowParallelLinear = None,
forward_meta: HPUForwardMeta = None,
q_norm: RMSNorm = None,
k_norm: RMSNorm = None,
):
"""
The forward function of attention layer.
@@ -237,6 +240,8 @@
o_proj,
self,
forward_meta,
q_norm,
k_norm,
)


@@ -251,6 +256,8 @@
qkv_proj=self.qkv_proj,
o_proj=self.o_proj,
forward_meta=forward_meta,
q_norm=self.q_norm if hasattr(self, "q_norm") else None,
k_norm=self.k_norm if hasattr(self, "k_norm") else None,
)

return atten_out
@@ -273,9 +280,12 @@ def fused_mlp_forward(self, x):
down_act_scale_key=down_proj_act_scale_key,
)
else:
new_up_gate_weight = paddle.concat([self.up_gate_proj.weight_gate, self.up_gate_proj.weight_up], axis=-1)

out = fused_mlp(
x,
self.up_gate_proj.weight,
#self.up_gate_proj.weight,
new_up_gate_weight,
None,
self.down_proj.weight,
getattr(self.up_gate_proj, "act_scale", None),
@@ -306,6 +316,7 @@ def fused_mlp_forward(self, x):
Ernie4_5_MLP,
)
from fastdeploy.model_executor.models.qwen2 import Qwen2Attention, Qwen2MLP
from fastdeploy.model_executor.models.qwen3 import Qwen3Attention


def convert_model(model, measurement_mode=False):
Expand All @@ -320,6 +331,8 @@ def convert_model(model, measurement_mode=False):
module.forward = types.MethodType(fused_self_atten_forward, module)
if isinstance(module, Qwen2Attention):
module.forward = types.MethodType(fused_self_atten_forward, module)
if isinstance(module, Qwen3Attention):
module.forward = types.MethodType(fused_self_atten_forward, module)
if isinstance(module, Ernie4_5_MLP):
module.measurement_mode = measurement_mode
module.forward = types.MethodType(fused_mlp_forward, module)
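
convert_model above swaps in the fused forwards by rebinding each matched module's forward with types.MethodType; the Qwen3Attention branch is the new case. A stripped-down illustration of that rebinding pattern, using stand-in classes rather than the real FastDeploy modules:

```python
# Stand-in classes only; the real code patches Qwen2Attention / Qwen3Attention /
# Ernie4_5_MLP instances found while walking the model.
import types


class DummyAttention:
    def forward(self, x):
        return f"eager({x})"


def fused_self_atten_forward(self, x):
    return f"fused({x})"


def convert_model(model_modules):
    for module in model_modules:
        if isinstance(module, DummyAttention):
            # Rebind: the fused function becomes a bound method of this instance only.
            module.forward = types.MethodType(fused_self_atten_forward, module)


modules = [DummyAttention()]
convert_model(modules)
print(modules[0].forward("x"))  # -> fused(x)
```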
16 changes: 16 additions & 0 deletions start_qwen.sh
@@ -0,0 +1,16 @@
export HF_ENDPOINT=https://hf-mirror.com
export FD_MODEL_SOURCE=HUGGINGFACE

export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=6

/workspace/kill_python.sh
rm -rf log

HPU_WARMUP_BUCKET=0 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model /workspace/models/Qwen3-30B-A3B --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128 --load-choices 'default_v1'
#HPU_WARMUP_BUCKET=0 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model Qwen/Qwen3-8B --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128 --load-choices 'default_v1'
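
Once the script above has the server running, a quick smoke test against the OpenAI-compatible endpoint might look like the following; the port (8000 is a common default for the api_server) and the model field are assumptions, so adjust both to the actual launch configuration.

```python
# Assumptions: default api_server port and the launch-time model path as the model name.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="/workspace/models/Qwen3-30B-A3B",
    messages=[{"role": "user", "content": "Say hello from the HPU backend."}],
    max_tokens=32,
)
print(resp.choices[0].message.content)
```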