-
Notifications
You must be signed in to change notification settings - Fork 2
Add Fused Multi-Head Attention example #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Seeing some weird erroring when branching (being fixed in #53): Click to see snippetsThis works: qk = if !EVEN_K[] && j >= mask_start
offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
mask = mask .& (offs_n .<= k_seqlen)
mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
qk .+ mask
else
qk
endbut this doesn't: if !EVEN_K[] && j >= mask_start
offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
mask = mask .& (offs_n .<= k_seqlen)
mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
qk = qk .+ mask
endnor does this: qk = if !EVEN_K[] && j >= mask_start
offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
if !EVEN_K[]
mask .& (offs_n .<= k_seqlen)
end
mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
qk .+ mask
else
qk
endIn the second and third block, I get "ERROR: SSAValue %___ not found in context" after removing the second condition, I can suddenly have a nested if block, and I don't need the outer else block: if !EVEN_K[]
offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
if !EVEN_K[]
mask = mask .& (offs_n .<= k_seqlen)
end
mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
qk = qk .+ mask
endDoes the if block need to depend on compile time constants? I'd need this to make the padding and causal mask properly. |
That's an IRStructurizer error. Can you provide an MWE? |
See also #15
Seems to fall slightly short of my NNop / ONIONop baseline (no WMMA), although I haven't compared it to the Python version. On my GPU, it compiles and runs fastest with tile_n=32 and tile_m=32:EDIT: this is without tensor cores. simply switching the compute type to TFloat32 / BFloat16 and exploring the optimization and entry hint landscape makes forward and backward passes ~10x faster.
Notably, cutile-python has aEDIT: fixed in #32 and #27.latencyargument forct.load, as well asnum_ctasandoccupancyarguments for the kernel, which might affect performance. The python version also does a kernel config autotune by searching a space of hand-picked configurations.Another thing that might be important for correctness or covering edge cases is exposing flush_to_zero? Used in e.g.
exp2.