Skip to content

Commit 86aa561

Browse files
authored
Merge pull request #13 from bargav25/new-llm
added flash attention forward
2 parents b0275e2 + d2481a4 commit 86aa561

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

llm/flash-attention.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.11.11","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":31041,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"## Flash Attention Forward Kernel Implementation\n\n### Task\n\nImplement a **Flash Attention (v2) Forward kernel** using Triton. Your kernel should take the following inputs:\n\n- `Q` (Query)\n- `K` (Key)\n- `V` (Value)\n\nAnd produce the following outputs:\n\n- `O` (Output)\n- `L` (Logsumexp values)\n\n### Requirements\n\n- Your Triton kernel must be launched with a grid configuration of **`(T_q, batch_size)`**, where:\n - Each Triton program instance handles **one tile of the `Q` tensor**,\n - and accesses data for a **single batch index**.\n\n- Within each program instance:\n - Load only the relevant tile from `Q`, and the corresponding batch slice from `K` and `V`,\n - Compute the attention scores and apply softmax using the logsumexp trick,\n - Store the result in the appropriate section of the output tensor `O`,\n - Store the logsumexp values in tensor `L`.\n\n### Notes\n\n- We will test with powers of 2 and at least 16, so you don’t need to worry about\nout-of-bounds accesses.\n","metadata":{}},{"cell_type":"code","source":"import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef flash_fwd_kernel(Q_ptr, K_ptr, V_ptr, \n O_ptr, L_ptr,\n stride_qb, stride_qq, stride_qd,\n stride_kb, stride_kk, stride_kd,\n stride_vb, stride_vk, stride_vd,\n stride_ob, stride_ok, stride_od,\n stride_lb, stride_lq,\n N_q, N_k,\n scale,\n D: tl.constexpr,\n BLOCK_SIZE_Q: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr):\n \n # Program Indices\n query_tile_index = tl.program_id(0)\n batch_index = tl.program_id(1)\n\n # Block pointers\n Q_block_ptr = tl.make_block_ptr(Q_ptr + batch_index * stride_qb,\n shape=(N_q, D),\n strides=(stride_qq, stride_qd),\n offsets=(query_tile_index * BLOCK_SIZE_Q, 0),\n block_shape=(BLOCK_SIZE_Q, D),\n order=(1,0))\n ####### Your Code goes here ############\n \n pass\n ","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Solution","metadata":{}},{"cell_type":"code","source":"import torch\nimport triton\nimport triton.language as tl\nimport math\n\n\n@triton.jit\ndef flash_fwd_kernel(Q_ptr, K_ptr, V_ptr, \n O_ptr, L_ptr,\n stride_qb, stride_qq, stride_qd,\n stride_kb, stride_kk, stride_kd,\n stride_vb, stride_vk, stride_vd,\n stride_ob, stride_ok, stride_od,\n stride_lb, stride_lq,\n N_q, N_k,\n scale,\n D: tl.constexpr,\n BLOCK_SIZE_Q: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr):\n \n # Program Indices\n query_tile_index = tl.program_id(0)\n batch_index = tl.program_id(1)\n\n # Block pointers\n Q_block_ptr = tl.make_block_ptr(Q_ptr + batch_index * stride_qb,\n shape=(N_q, D),\n strides=(stride_qq, stride_qd),\n offsets=(query_tile_index * BLOCK_SIZE_Q, 0),\n block_shape=(BLOCK_SIZE_Q, D),\n order=(1,0))\n \n K_block_ptr = tl.make_block_ptr(K_ptr + batch_index * stride_kb,\n shape=(D, N_k),\n strides=(stride_kd, stride_kk),\n offsets=(0, 0),\n block_shape=(D, BLOCK_SIZE_K),\n order=(0,1)) # Note: K is transposed in the kernel\n \n V_block_ptr = tl.make_block_ptr(V_ptr + batch_index * stride_vb,\n shape=(N_k, D),\n strides=(stride_vk, stride_vd),\n offsets=(0, 0),\n block_shape=(BLOCK_SIZE_K, D),\n order=(1,0))\n \n O_block_ptr = tl.make_block_ptr(O_ptr + batch_index * stride_ob,\n shape=(N_q, D),\n strides=(stride_ok, stride_od),\n offsets=(query_tile_index * BLOCK_SIZE_Q, 0),\n block_shape=(BLOCK_SIZE_Q, D),\n order=(1,0))\n \n L_block_ptr = tl.make_block_ptr(L_ptr + batch_index * stride_lb,\n shape=(N_q,),\n strides=(stride_lq,),\n offsets=(query_tile_index * BLOCK_SIZE_Q,),\n block_shape=(BLOCK_SIZE_Q,),\n order=(0,))\n \n l = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0 # Initialize l to 1.0\n out = tl.zeros([BLOCK_SIZE_Q, D], dtype=tl.float32)\n\n prev_max = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float('inf') # Initialize s_max to negative infinity\n\n # Load query\n q = tl.load(Q_block_ptr).to(tl.float32)\n\n for i in range(0, N_k, BLOCK_SIZE_K):\n\n # Load keys and values\n k = tl.load(K_block_ptr).to(tl.float32)\n v = tl.load(V_block_ptr).to(tl.float32)\n\n # Compute the attention scores\n s = tl.dot(q, k) * scale\n curr_max = tl.maximum(prev_max, tl.max(s, axis=1))\n p = tl.math.exp(s - curr_max[:, None])\n\n\n # Compute the output\n alpha = tl.math.exp(prev_max - curr_max)\n out = out * alpha[:, None] + tl.dot(p, v)\n\n # To store the logsumexp for backward pass\n curr_l = tl.sum(p, axis=1)\n l = l * alpha + curr_l\n\n prev_max = curr_max\n\n # Advance block pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_K))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_K, 0))\n\n out = out / l[:, None] # Normalize the output\n tl.store(O_block_ptr, out.to(O_ptr.dtype.element_ty))\n\n # Store the logsumexp\n log_l = prev_max + tl.log(l)\n tl.store(L_block_ptr, log_l.to(L_ptr.dtype.element_ty))\n\n ","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true,"execution":{"iopub.status.busy":"2025-06-06T19:03:37.207225Z","iopub.execute_input":"2025-06-06T19:03:37.208005Z","iopub.status.idle":"2025-06-06T19:03:37.223005Z","shell.execute_reply.started":"2025-06-06T19:03:37.207975Z","shell.execute_reply":"2025-06-06T19:03:37.222284Z"}},"outputs":[],"execution_count":15},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## TESTS","metadata":{}},{"cell_type":"code","source":" \n# Define problem size\nB, N_q, N_k, D = 1, 64, 128, 256 # Batch size, query len, key len, hidden dim\nBLOCK_SIZE_Q = 16\nBLOCK_SIZE_K = 16\n\n# Initialize inputs\nQ = torch.randn((B, N_q, D), dtype=torch.float16, device='cuda')\nK = torch.randn((B, N_k, D), dtype=torch.float16, device='cuda')\nV = torch.randn((B, N_k, D), dtype=torch.float16, device='cuda')\n\n# Outputs\nO = torch.empty((B, N_q, D), dtype=torch.float16, device='cuda')\nL = torch.empty((B, N_q), dtype=torch.float32, device='cuda')\n\n# Compute strides\nstride_qb, stride_qq, stride_qd = Q.stride()\nstride_kb, stride_kk, stride_kd = K.stride()\nstride_vb, stride_vk, stride_vd = V.stride()\nstride_ob, stride_ok, stride_od = O.stride()\nstride_lb, stride_lq = L.stride()\n\n# Call Triton kernel\ngrid = (triton.cdiv(N_q, BLOCK_SIZE_Q), B)\n\nflash_fwd_kernel[grid](\n Q, K, V, O, L,\n stride_qb, stride_qq, stride_qd,\n stride_kb, stride_kk, stride_kd,\n stride_vb, stride_vk, stride_vd,\n stride_ob, stride_ok, stride_od,\n stride_lb, stride_lq,\n N_q, N_k,\n scale=1.0 / math.sqrt(D),\n D=D,\n BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n BLOCK_SIZE_K=BLOCK_SIZE_K,\n)\n\n# Print result\nprint(\"Output O:\", O[0, :5]) # Print first 5 query results\nprint(\"Logsumexp L:\", L[0, :5])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-06-06T19:03:49.722693Z","iopub.execute_input":"2025-06-06T19:03:49.723399Z","iopub.status.idle":"2025-06-06T19:03:50.790921Z","shell.execute_reply.started":"2025-06-06T19:03:49.723376Z","shell.execute_reply":"2025-06-06T19:03:50.790068Z"}},"outputs":[{"name":"stdout","text":"Output O: tensor([[-0.0423, 0.0367, 0.0788, ..., 0.1219, -0.1204, 0.1300],\n [-0.0228, -0.0631, -0.1155, ..., 0.0247, 0.0787, 0.0546],\n [-0.0792, 0.0242, -0.0278, ..., 0.2046, 0.0991, 0.0636],\n [-0.0681, -0.2094, 0.0441, ..., 0.1156, -0.0679, 0.1702],\n [-0.0521, -0.1407, -0.1186, ..., 0.1582, 0.0699, -0.0427]],\n device='cuda:0', dtype=torch.float16)\nLogsumexp L: tensor([5.4945, 5.4222, 5.2435, 5.4245, 5.3262], device='cuda:0')\n","output_type":"stream"}],"execution_count":16},{"cell_type":"code","source":"# Doing the same operation using PyTorch matmul operations\nQ_ref = Q.to(torch.float32)\nK_ref = K.to(torch.float32)\nV_ref = V.to(torch.float32)\n\nscale = 1.0 / math.sqrt(D)\nscores = torch.matmul(Q_ref, K_ref.transpose(-2, -1)) * scale # (B, N_q, N_k)\nattn = torch.nn.functional.softmax(scores, dim=-1) # (B, N_q, N_k)\nO_ref = torch.matmul(attn, V_ref) # (B, N_q, D)\n\n# Comparing Both\n\n# Convert Triton output to float32 for comparison\nO_triton = O.to(torch.float32)\n\nprint(torch.allclose(O_triton, O_ref, atol=1e-1, rtol=1e-2)) # Should be True\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-06-06T19:05:23.841987Z","iopub.execute_input":"2025-06-06T19:05:23.842471Z","iopub.status.idle":"2025-06-06T19:05:23.849223Z","shell.execute_reply.started":"2025-06-06T19:05:23.842450Z","shell.execute_reply":"2025-06-06T19:05:23.848594Z"}},"outputs":[{"name":"stdout","text":"True\n","output_type":"stream"}],"execution_count":21},{"cell_type":"code","source":"L_ref = torch.log(torch.sum(torch.exp(scores - scores.max(dim=-1, keepdim=True).values), dim=-1)) + scores.max(dim=-1).values\nL_triton = L.to(torch.float32)\n\nprint(torch.allclose(L_triton, L_ref, atol=1e-1, rtol=1e-2)) # Should be True\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-06-06T19:05:59.357044Z","iopub.execute_input":"2025-06-06T19:05:59.357336Z","iopub.status.idle":"2025-06-06T19:05:59.363363Z","shell.execute_reply.started":"2025-06-06T19:05:59.357315Z","shell.execute_reply":"2025-06-06T19:05:59.362815Z"}},"outputs":[{"name":"stdout","text":"True\n","output_type":"stream"}],"execution_count":22},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}

0 commit comments

Comments
 (0)