diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index fc573cf7cb..b79cd82b21 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -145,10 +145,11 @@ def _build_prefill_graph( else: raise ValueError(f"Invalid query tensor shape: {q.shape}") + s_stride, h_stride, d_stride = q.stride() cudnn_q = g.tensor( name="q", dim=(graph_b, h_qo, graph_s_qo, d_qk), - stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1), + stride=(h_qo * d_qk, h_stride, s_stride, d_stride), data_type=cudnn.data_type.BFLOAT16, ) @@ -171,10 +172,11 @@ def _build_prefill_graph( raise ValueError(f"Invalid kv cache tensor shape: {k_cache.shape}") if k_cache.dim() == 3: + s_stride, h_stride, d_stride = k_cache.stride() cudnn_k_cache = g.tensor( name="k_cache", dim=(graph_b, h_kv, graph_s_kv, d_qk), - stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1), + stride=(h_kv * d_qk * graph_s_kv, h_stride, s_stride, d_stride), data_type=cudnn.data_type.BFLOAT16, ) @@ -183,10 +185,11 @@ def _build_prefill_graph( ragged_k.set_uid(UIDs.RAGGED_K_UID.value) cudnn_k_cache.set_ragged_offset(ragged_k) + s_stride, h_stride, d_stride = v_cache.stride() cudnn_v_cache = g.tensor( name="v_cache", dim=(graph_b, h_kv, graph_s_kv, d_vo), - stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1), + stride=(h_kv * d_vo * graph_s_kv, h_stride, s_stride, d_stride), data_type=cudnn.data_type.BFLOAT16, )