413 changes: 288 additions & 125 deletions tests/pytorch/attention/run_attention_with_cp.py

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions tests/pytorch/attention/test_attention.py
@@ -162,7 +162,16 @@ def test_dot_product_attention(
)

# Get backends
# For 111s, dbias calculation is not supported as of cuDNN 9.18, so test fwd only for 111s.
# For all other shapes, test fwd+bwd.
is_training = True
# TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s.
if config.bias_shape == "111s":
is_training = False
logging.info(
"Setting is_training to False as cuDNN does not support dbias for"
f" {config.bias_shape=} "
)
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
@@ -636,7 +645,8 @@ def test_dpa_bias(dtype, model_configs, model):
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
"bias_1_4": ModelConfig(
"bias_1_4": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="111s"),
"bias_1_5": ModelConfig(
4,
2048,
24,
@@ -646,7 +656,7 @@ def test_dpa_bias(dtype, model_configs, model):
bias_shape="1hss",
alibi_type="custom",
),
"bias_1_5": ModelConfig(
"bias_1_6": ModelConfig(
2,
2048,
24,
@@ -1143,10 +1153,16 @@ def _run_dot_product_attention(
bias = None
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
# For 1hss, 11ss, b1ss, bhss
shape_cache = shape
shape = shape.replace("_s_s", "_sq_skv")
# For 111s
if shape == shape_cache:
shape = shape.replace("_1_s", "_1_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != "1hss":
# For 111s, dbias calculation is not supported as of cuDNN 9.18
if config.bias_shape == "111s":
bias.requires_grad = False

# Create RNG
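As a side note on the bias_shape handling above: the string replacement maps each shape spec onto concrete tensor dimensions. A minimal standalone sketch of that mapping (hypothetical dim_to_num values; the real test derives them from the ModelConfig):

def bias_tensor_shape(bias_shape: str, dim_to_num: dict) -> list:
    # "b1ss" -> "b_1_s_s", "111s" -> "1_1_1_s"
    shape = "_".join(bias_shape)
    shape_cache = shape
    # For 1hss, 11ss, b1ss, bhss the trailing "_s_s" becomes "_sq_skv"
    shape = shape.replace("_s_s", "_sq_skv")
    # For 111s nothing was replaced, so rewrite the trailing "_1_s" instead
    if shape == shape_cache:
        shape = shape.replace("_1_s", "_1_skv")
    return [dim_to_num[tok] for tok in shape.split("_")]

# Hypothetical sizes: batch 2, 16 heads, seqlen 128
dim_to_num = {"b": 2, "h": 16, "sq": 128, "skv": 128, "1": 1}
print(bias_tensor_shape("bhss", dim_to_num))  # [2, 16, 128, 128]
print(bias_tensor_shape("111s", dim_to_num))  # [1, 1, 1, 128]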
38 changes: 36 additions & 2 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -147,7 +147,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_1_4": ModelConfig(
2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"
), # MHA
"cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
@@ -160,9 +163,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
attn_bias_type="post_scale_bias",
), # GQA
"cp_2_3": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
bias_shape="11ss",
), # GQA
"cp_2_4": ModelConfig(
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
bias_shape="111s",
return_max_logit=True,
), # GQA
"cp_2_5": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_6": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
), # GQA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
@@ -171,6 +195,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
"cp_3_4": ModelConfig(
2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64
), # MLA
"cp_4_0": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
), # GQA
@@ -191,10 +218,13 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_0",
"cp_1_1",
"cp_1_4",
"cp_1_5",
"cp_2_0",
"cp_2_2",
"cp_2_3",
"cp_2_4",
"cp_3_2",
"cp_3_4",
"cp_4_2",
]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
@@ -324,12 +354,15 @@ def test_cp_with_fused_attention(
Float8CurrentScaling(fp8_dpa=True),
DelayedScaling(fp8_dpa=True),
]
# For 111s, dbias calculation is not supported as of cuDNN 9.18, so test fwd only for 111s.
is_training = False if config.bias_shape == "111s" else True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3),
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
Expand All @@ -348,6 +381,7 @@ def test_cp_with_fused_attention(
fp8_mha=fp8_mha,
scaling_mode=scaling_mode,
f16_O=f16_O,
is_training=is_training,
log_level=pytest_logging_level,
),
check=True,
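The new cp_* configs exercise post_scale_bias tensors with shapes that broadcast against the [b, h, sq, skv] attention scores. A small illustrative PyTorch snippet (plain broadcasting with made-up small sizes, not the Transformer Engine or cuDNN kernel) shows why bhss, b1ss, 11ss, 1hss, and 111s are all valid forward bias shapes:

import torch

b, h, sq, skv = 2, 12, 64, 64
scores = torch.randn(b, h, sq, skv)  # pre-softmax attention scores

# Each bias shape broadcasts cleanly onto [b, h, sq, skv]
for shape in [(b, h, sq, skv),   # bhss
              (b, 1, sq, skv),   # b1ss
              (1, 1, sq, skv),   # 11ss
              (1, h, sq, skv),   # 1hss
              (1, 1, 1, skv)]:   # 111s
    bias = torch.randn(shape)
    probs = torch.softmax(scores + bias, dim=-1)  # post-scale bias, then softmax
    assert probs.shape == (b, h, sq, skv)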
5 changes: 3 additions & 2 deletions tests/pytorch/utils.py
@@ -271,7 +271,6 @@ def get_available_attention_backends(
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
@@ -289,7 +288,9 @@ def get_available_attention_backends(
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
# TODO(KshitijLakhani): Remove this guard once cuDNN supports dbias calculation for bias shape 111s
if core_attention_bias_shape != "111s":
core_attention_bias_requires_grad = True

fused_attn_backends = []
available_backends = None
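Condensed, the guard above amounts to: request dbias from the fused backend only when the dtype/head-dim conditions hold and the bias shape is not 111s. A hypothetical predicate capturing just the checks visible in this hunk (the leading clauses of the original condition are truncated above):

def core_bias_requires_grad(config) -> bool:
    # Sketch of the visible conditions only
    return (
        config.head_dim_qk <= 128
        and config.head_dim_v <= 128
        and config.bias_shape != "111s"  # no cuDNN dbias for 111s as of 9.18
    )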
@@ -52,13 +52,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1,
void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -121,6 +122,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
max_pages_per_seq_v,
bias_b,
bias_h,
bias_sq,
bias_skv,
scaling_factor,
is_training,
dropout_probability,
@@ -269,10 +272,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_alibi_mask(is_alibi);

if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
bias = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
sdpa_options.set_bias(bias);
}

@@ -548,16 +552,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO,
void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;

bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -622,6 +626,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
0,
bias_b,
bias_h,
bias_sq,
bias_skv,
scaling_factor,
true,
dropout_probability,
@@ -811,19 +817,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_alibi_mask(is_alibi);

if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dBias")
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
bias = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
sdpa_backward_options.set_bias(bias);
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// are not supported for dbias calculation but they are
// supported for forward bias calculation
if ((bias_b == 1) && (bias_h == h)) {
// bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation
// bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18
if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) {
dBias = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("dBias")
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
sdpa_backward_options.set_dbias(dBias);
}
}
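The dBias guard above reduces to a small predicate: every bias shape gets a dbias tensor except the fully broadcast 111s case. An illustrative restatement in Python:

def supports_dbias(bias_b, bias_h, bias_sq):
    # Matches the condition above: only [1, 1, 1, s] (111s) is excluded,
    # since cuDNN 9.18 cannot compute dbias for that shape.
    return not (bias_b == 1 and bias_h == 1 and bias_sq == 1)

# 1hss, 11ss, b1ss, bhss -> True; 111s -> False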
@@ -974,10 +981,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

if (is_bias) {
variant_pack[bias] = devPtrBias;
if ((bias_b == 1) && (bias_h == h)) {
if (dBias != nullptr) {
variant_pack[dBias] = devPtrdBias;
} else {
variant_pack[dBias] = nullptr;
}
}

@@ -1083,10 +1088,14 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
size_t bias_sq = 0;
size_t bias_skv = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
bias_sq = input_Bias->data.shape[2];
bias_skv = input_Bias->data.shape[3];
}
void *devPtrSoftmaxOffset = nullptr;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1152,7 +1161,7 @@ void fused_attn_arbitrary_seqlen_fwd(
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
output_bias->data.dtype = QKV_type;
}

@@ -1197,10 +1206,10 @@ void fused_attn_arbitrary_seqlen_fwd(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
@@ -1244,11 +1253,15 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
size_t bias_sq = 0;
size_t bias_skv = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
bias_sq = output_dBias->data.shape[2];
bias_skv = output_dBias->data.shape[3];
}

size_t max_batch_size = 0;
@@ -1291,11 +1304,11 @@ void fused_attn_arbitrary_seqlen_bwd(

fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
