Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
curl -sSL https://install.python-poetry.org | python3
poetry install --all-extras
poetry install --all-extras -vvv
- name: Type-checking package with mypy
run: |
# Run this mypy instance against our main package.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
curl -sSL https://install.python-poetry.org | python3
poetry install --all-extras
poetry install --all-extras -vvv
- name: Test with pytest
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand Down
189 changes: 189 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
1 change: 1 addition & 0 deletions 3rdparty/googletest
Submodule googletest added at 9ff245
34 changes: 34 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
cmake_minimum_required(VERSION 3.23.1)
project(deft
  VERSION 2024
  DESCRIPTION "An IO-aware fast attention kernel for efficient tree-structured interactions with LLMs"
  LANGUAGES CUDA CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

# REQUIRED already makes find_package fatal on failure; the explicit
# checks below are kept as extra diagnostics for clarity.
find_package(Python3 REQUIRED)
find_package(CUDAToolkit REQUIRED)
if(NOT Python3_FOUND)
message(FATAL_ERROR "Python3 not found.")
endif()
if(NOT CUDAToolkit_FOUND)
message(FATAL_ERROR "CUDA not found.")
endif()
message(STATUS "Python3 found at ${Python3_EXECUTABLE}")
message(STATUS "CUDA version is ${CUDAToolkit_VERSION}")
set(DEFT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc)
include_directories(${DEFT_SOURCE_DIR})
include_directories(${CUDAToolkit_INCLUDE_DIRS})

# BUG FIX: the original `set(DEFT_ENABLE_TESTS CACHE BOOL "Enable tests for DEFT" ON)`
# has the arguments in the wrong order — set() requires the value *before*
# CACHE (`set(<var> <value> CACHE <type> <doc>)`), so the original line is a
# CMake error. option() is the idiomatic way to declare a boolean cache switch.
option(DEFT_ENABLE_TESTS "Enable tests for DEFT" ON)

if(DEFT_ENABLE_TESTS)
enable_testing()
# googletest is only needed when the tests are built, so it is pulled in
# here rather than unconditionally at the top level.
add_subdirectory(3rdparty/googletest)
set(DEFT_TEST_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tests/csrc)
file(GLOB_RECURSE TEST_DUMMY_SRC ${DEFT_TEST_SOURCE_DIR}/test_dummy.cu)
add_executable(test_dummy ${TEST_DUMMY_SRC})
target_link_libraries(test_dummy gtest_main gtest)
add_test(NAME test_dummy COMMAND test_dummy)
endif()
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ We propose DeFT, an IO-aware attention algorithm for efficient tree-structured i

- [2024/05] We update the second version of DeFT paper with a better algorithm for general tree-structured LLM inference: [DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference](https://arxiv.org/abs/2404.00242)!
- [2024/03] [DeFT: Flash Tree-Attention With IO-Awareness for Efficient Tree-Search-Based LLM Inference](https://openreview.net/pdf?id=HqfLHoX8bR) has been accepted as Oral presentation in [ICLR'24 AGI Workshop](https://iclr.cc/virtual/2024/23126)!

****

## Abstract
Given the increasing demand for tree-structured interactions with LLMs, we introduce DeFT (Decoding with Flash Tree-Attention), an IO-aware tree attention algorithm tailored for tree-structured inference. Unlike traditional sequence-based decoding, tree-structured decoding better accommodates modern task requirements, including self-consistency, few-shot prompting, multi-step reasoning, and multi-model/head coordination. However, existing sequence-based inference systems are ill-suited for tree-structured decoding, resulting in redundancy in computation, memory footprints, and memory access, thereby undermining inference efficiency. To address this challenge, DeFT maintains memory-efficient attention calculation with low memory footprints through two key stages: (1) QKV Preparation: We propose a KV-Guided Grouping Strategy with Tree Split to intelligently group QKV, optimizing GPU resource utilization while minimizing memory reads/writes for KV cache between GPU global memory and on-chip shared memory; (2) Attention Calculation: We compute partial attention of each QKV group in a fused kernel and employ a Tree-topology-aware Global Reduction strategy to obtain final attention. By reducing 73-99% KV cache IO and nearly 100% IO for partial results during attention calculation (e.g., Softmax), DeFT achieves up to 2.52/3.82x speedup in the end-to-end/attention latency across three practical tree-based workloads: namely, few-shot prompting, multi-step reasoning, and speculative decoding, over state-of-the-art attention algorithms.
Expand All @@ -42,10 +42,22 @@ poetry install
CUDA_VISIBLE_DEVICES=0 python examples/
```


### Run Tests

<!-- We profile DeFT kernel performance with [nvbench](https://github.com/NVIDIA/nvbench) and you can compile and run the benchmarks with the following commands: -->

```bash
cmake -B build
cmake --build build
cd build
ctest
```

## FAQ

1. **What is the difference between the two versions of the DeFT paper on arXiv?**

DeFT-v1


Expand All @@ -72,4 +84,4 @@ If you find DeFT useful or relevant to your project and research, please kindly
journal={arXiv preprint arXiv:2404.00242},
year={2024}
}
```
```
29 changes: 29 additions & 0 deletions build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging
from typing import Any, Dict

from torch.utils.cpp_extension import BuildExtension, CUDAExtension

logger = logging.getLogger(__name__)


# CUDA extension modules compiled when the package is built.
ext_modules = [
    CUDAExtension(
        name='deft._kernels',
        sources=['csrc/deft_api.cpp', 'csrc/deft/attention.cu'],
        include_dirs=['csrc'],
    ),
]


def build(setup_kwargs: Dict[str, Any]) -> None:
    """Poetry build hook: inject the CUDA extension modules into setup().

    Mutates ``setup_kwargs`` in place, registering ``ext_modules`` and a
    ``build_ext`` command class that compiles them (with ninja disabled).
    """
    extra = {
        'ext_modules': ext_modules,
        'cmdclass': {'build_ext': BuildExtension.with_options(use_ninja=False)},
    }
    setup_kwargs.update(extra)
18 changes: 18 additions & 0 deletions csrc/deft/attention.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <cuda_runtime.h>
#include "deft/ops.h"
#include "attention.cuh"
#include <torch/all.h>

namespace deft {

/// Dummy placeholder op: element-wise copy of `input` into the preallocated
/// `output` tensor via the CUDA kernel launcher declared in attention.cuh.
///
/// \param output Preallocated float32 CUDA tensor receiving the result.
/// \param input  Float32 CUDA tensor to copy; must have the same number of
///               elements as `output`.
/// \return `output`, for call-chaining convenience on the Python side.
torch::Tensor dummy(
    torch::Tensor output,
    torch::Tensor input
) {
  // Validate the contract before handing raw pointers to the kernel:
  // data_ptr<float>() already throws on a non-float tensor, but a size
  // mismatch or a non-contiguous layout would silently read/write the
  // wrong memory inside the kernel.
  TORCH_CHECK(input.numel() == output.numel(),
              "dummy: input and output must have the same number of elements");
  TORCH_CHECK(input.is_contiguous() && output.is_contiguous(),
              "dummy: input and output must be contiguous");
  // numel() is int64_t; the launcher takes int. Sizes beyond INT_MAX are
  // not expected for this placeholder op.
  dummy(
      output.data_ptr<float>(),
      input.data_ptr<float>(),
      static_cast<int>(input.numel())
  );
  return output;
}

}  // namespace deft
25 changes: 25 additions & 0 deletions csrc/deft/attention.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once
// BUG FIX: this header lacked an include guard (sibling ops.h has
// `#pragma once`); double inclusion would redefine both functions below.

#include <cuda_runtime.h>

namespace deft {

/// Element-wise copy kernel: output[idx] = input[idx] for idx in [0, size).
/// `static` so each translation unit including this header gets its own
/// copy instead of colliding at link time.
static __global__ void dummy_kernel(
    float* __restrict__ output,
    const float* __restrict__ input,
    const int size
) {
  // One thread per element; threads past `size` do nothing.
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < size) {
    output[idx] = input[idx];
  }
}

/// Host-side launcher for dummy_kernel.
/// BUG FIX: marked `inline` — the original was a non-inline function
/// definition in a header, an ODR violation as soon as the header is
/// included from more than one translation unit.
inline void dummy(
    float* output,
    const float* input,
    int size
) {
  dim3 block(256);
  // Round the grid size up so every element is covered.
  dim3 grid((size + block.x - 1) / block.x);
  dummy_kernel<<<grid, block>>>(output, input, size);
}

}  // namespace deft
12 changes: 12 additions & 0 deletions csrc/deft/ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <optional>
#include <torch/library.h>
#include <torch/all.h>

namespace deft {
// Dummy placeholder op: copies `input` into the preallocated `output`
// tensor and returns `output`. Implemented in csrc/deft/attention.cu.
// NOTE(review): presumably both tensors must be float32 with matching
// element counts (the kernel copies element-wise) — confirm with the
// definition before relying on other dtypes.
torch::Tensor dummy(
torch::Tensor output,
torch::Tensor input
);
}
8 changes: 8 additions & 0 deletions csrc/deft_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <torch/python.h>
#include "deft/ops.h"


// Python bindings for the DeFT kernels; built as the `deft._kernels`
// extension (see build.py). TORCH_EXTENSION_NAME is injected by torch's
// extension build machinery, so the module name stays in one place.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeFT";
// dummy(output, input) -> output: element-wise GPU copy placeholder.
m.def("dummy", &deft::dummy, "Dummy");
}
Loading