diff --git a/bitblas/base/arch/__init__.py b/bitblas/base/arch/__init__.py
index dd931f617..36353b1c4 100644
--- a/bitblas/base/arch/__init__.py
+++ b/bitblas/base/arch/__init__.py
@@ -35,6 +35,7 @@ def auto_infer_current_arch() -> TileDevice:
     is_ampere_arch,  # noqa: F401
     is_ada_arch,  # noqa: F401
     is_hopper_arch,  # noqa: F401
+    is_blackwell_arch,  # noqa: F401
     is_tensorcore_supported_precision,  # noqa: F401
     has_mma_support,  # noqa: F401
 )
diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py
index 25c83bff1..9d9220036 100644
--- a/bitblas/base/arch/cuda.py
+++ b/bitblas/base/arch/cuda.py
@@ -45,6 +45,14 @@ def is_hopper_arch(arch: TileDevice) -> bool:
     return all(conditions)
 
 
+def is_blackwell_arch(arch: TileDevice) -> bool:
+    conditions = [True]
+    conditions.append(is_cuda_arch(arch))
+    # Treat sm_100+ (Blackwell / future) as the Blackwell family for dispatch purposes
+    conditions.append(arch.sm_version >= 100)
+    return all(conditions)
+
+
 def has_mma_support(arch: TileDevice) -> bool:
     conditions = [True]
     conditions.append(is_cuda_arch(arch))
@@ -87,7 +95,7 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
         return (in_dtype, accum_dtype) in ampere_tensorcore_supported
     elif is_ada_arch(arch):
         return (in_dtype, accum_dtype) in ada_tensorcore_supported
-    elif is_hopper_arch(arch):
+    elif is_hopper_arch(arch) or is_blackwell_arch(arch):
         return (in_dtype, accum_dtype) in hopper_tensorcore_supported
     else:
         raise ValueError(f"Unsupported architecture: {arch}")
diff --git a/bitblas/utils/target_detector.py b/bitblas/utils/target_detector.py
index 9e3427817..48a8b2779 100644
--- a/bitblas/utils/target_detector.py
+++ b/bitblas/utils/target_detector.py
@@ -86,9 +86,9 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
     Returns:
         str: The detected TVM target architecture.
     """
-    # Return a predefined target if specified in the environment variable
-    # if "TVM_TARGET" in os.environ:
-    #     return os.environ["TVM_TARGET"]
+    # Honor an explicit override first
+    if "TVM_TARGET" in os.environ:
+        return os.environ["TVM_TARGET"]
 
     # Fetch all available tags and filter for NVIDIA tags
     all_tags = list_tags()
@@ -101,5 +101,26 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
     if gpu_model in NVIDIA_GPU_REMAP:
         gpu_model = NVIDIA_GPU_REMAP[gpu_model]
 
+    # If we can read the compute capability, prefer constructing an explicit arch target
+    try:
+        output = subprocess.check_output(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            encoding="utf-8",
+        ).strip()
+        cap = output.split("\n")[gpu_id].strip()
+        # compute_cap is of the form "11.0"
+        if cap and cap.replace(".", "").isdigit():
+            cap_int = int(cap.split(".")[0]) * 10 + int(cap.split(".")[1])
+            if cap_int >= 110:
+                return "cuda -arch=sm_110"
+            elif cap_int >= 90:
+                return "cuda -arch=sm_90"
+            elif cap_int >= 89:
+                return "cuda -arch=sm_89"
+            elif cap_int >= 80:
+                return "cuda -arch=sm_80"
+    except Exception:
+        pass
+
     target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
     return target
diff --git a/docs/bitblas-aarch64-notes.md b/docs/bitblas-aarch64-notes.md
new file mode 100644
index 000000000..6d31f792d
--- /dev/null
+++ b/docs/bitblas-aarch64-notes.md
@@ -0,0 +1,104 @@
+# BitBLAS aarch64 Adaptation and Blackwell Support Notes
+
+Notes on the key changes and steps used to fix the TVM build on aarch64 (Jetson Thor / CUDA 13.0 / Ubuntu 24.04) and get BitBLAS running on Blackwell (sm_110).
+
+## Key changes
+
+1) **Fix TVM's LLVM dependency on aarch64**
+   - File: `setup.py`
+   - Change: bump `LLVM_VERSION` from `10.0.1` to `18.1.8` and switch extraction to `tar --no-same-owner`, avoiding both the legacy `libtinfo5` dependency and the `Permission denied` failures caused by ownership-preserving extraction.
+   - Effect: the TVM build no longer aborts because `/lib/aarch64-linux-gnu/libtinfo.so.5` is missing or because extraction fails on permissions.
+
+2) **Explicit Blackwell detection and dispatch**
+   - Files: `bitblas/base/arch/cuda.py`, `bitblas/base/arch/__init__.py`
+   - Add `is_blackwell_arch`, and make the TensorCore precision check treat Blackwell as supporting the Hopper precision set.
+   - In the `Matmul`/`MatmulDequantize` scheduling logic, Blackwell no longer raises "Unsupported architecture" and takes the Ampere/Ada path directly (tuning can refine this later).
+
+3) **Honor TVM_TARGET and auto-select sm_110**
+   - File: `bitblas/utils/target_detector.py`
+   - The `TVM_TARGET` environment variable now takes precedence; if it is unset, `nvidia-smi --query-gpu=compute_cap` is used to construct an explicit arch (>=110 -> sm_110, >=90 -> sm_90, >=89 -> sm_89, >=80 -> sm_80).
+   - Goal: no more "TVM target not found" on Blackwell, and the target can be pinned explicitly via `TVM_TARGET="cuda -arch=sm_110"` (see the sanity-check sketch below).
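+
+A quick sanity check for changes 2) and 3) (a minimal sketch; it assumes the
+patched package is installed and a Blackwell GPU is visible):
+
+```python
+import os
+from bitblas.utils.target_detector import auto_detect_nvidia_target
+from bitblas.base.arch import auto_infer_current_arch, is_blackwell_arch
+
+# The override is returned verbatim.
+os.environ["TVM_TARGET"] = "cuda -arch=sm_110"
+assert auto_detect_nvidia_target() == "cuda -arch=sm_110"
+
+# Without it, the compute_cap fallback decides (11.0 -> sm_110 on Thor).
+del os.environ["TVM_TARGET"]
+print(auto_detect_nvidia_target())
+
+# The new predicate fires on sm_100+ devices.
+print(is_blackwell_arch(auto_infer_current_arch()))
+```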
""" - # Return a predefined target if specified in the environment variable - # if "TVM_TARGET" in os.environ: - # return os.environ["TVM_TARGET"] + # Honor explicit override first + if "TVM_TARGET" in os.environ: + return os.environ["TVM_TARGET"] # Fetch all available tags and filter for NVIDIA tags all_tags = list_tags() @@ -101,5 +101,26 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str: if gpu_model in NVIDIA_GPU_REMAP: gpu_model = NVIDIA_GPU_REMAP[gpu_model] + # If we can get compute capability, prefer constructing an explicit arch target + try: + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"], + encoding="utf-8", + ).strip() + cap = output.split("\n")[gpu_id].strip() + # compute_cap is like "11.0" + if cap and cap.replace(".", "").isdigit(): + cap_int = int(cap.split(".")[0]) * 10 + int(cap.split(".")[1]) + if cap_int >= 110: + return "cuda -arch=sm_110" + elif cap_int >= 90: + return "cuda -arch=sm_90" + elif cap_int >= 89: + return "cuda -arch=sm_89" + elif cap_int >= 80: + return "cuda -arch=sm_80" + except Exception: + pass + target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" return target diff --git a/docs/bitblas-aarch64-notes.md b/docs/bitblas-aarch64-notes.md new file mode 100644 index 000000000..6d31f792d --- /dev/null +++ b/docs/bitblas-aarch64-notes.md @@ -0,0 +1,64 @@ +# BitBLAS aarch64 适配与 Blackwell 支持记录 + +记录本次在 aarch64 (Jetson Thor / CUDA 13.0 / Ubuntu 24.04) 上修复 TVM 构建问题并让 Blackwell (sm_110) 运行 BitBLAS 的关键改动和操作。 + +## 主要改动 + +1) **修复 aarch64 上 TVM 的 LLVM 依赖问题** + - 文件:`setup.py` + - 变更:将 `LLVM_VERSION` 从 `10.0.1` 升级到 `18.1.8`,并改用 `tar --no-same-owner` 解压,避免旧版 `libtinfo5` 依赖以及解压权限导致的 `Permission denied`。 + - 影响:TVM 构建不再因为 `/lib/aarch64-linux-gnu/libtinfo.so.5` 缺失或解压权限报错而中断。 + +2) **显式支持 Blackwell 识别与调度** + - 文件:`bitblas/base/arch/cuda.py`,`bitblas/base/arch/__init__.py` + - 新增 `is_blackwell_arch`,并将 TensorCore 支持判定中把 Blackwell 视为 Hopper/Ada 支持的精度集合。 + - `Matmul`/`MatmulDequantize` 调度逻辑中,Blackwell 不再抛 “Unsupported architecture”,直接走 Ampere/Ada 路径(可再调优)。 + +3) **尊重 TVM_TARGET 并自动选择 sm_110** + - 文件:`bitblas/utils/target_detector.py` + - 现在优先读取环境变量 `TVM_TARGET`;若未设置,会用 `nvidia-smi --query-gpu=compute_cap` 自动构造 arch(>=110 -> sm_110,>=90 -> sm_90,>=89 -> sm_89,>=80 -> sm_80)。 + - 目的:在 Blackwell 上不再提示 “TVM target not found”,且可显式锁定 `TVM_TARGET="cuda -arch=sm_110"`。 + +## 操作步骤(本机已验证) + +1) 环境与 Torch(PyTorch 2.4.0 Jetson 轮子) + ```bash + conda create -y -n bitblas311 python=3.11 + conda run -n bitblas311 python -m pip install ~/jetson-pytorch-builder/wheels/py311/torch-*.whl + ``` + +2) 安装 BitBLAS(含 TVM/tilelang 构建) + ```bash + cd ~/BitBLAS + conda run -n bitblas311 python -m pip install -e . 
+
+## Changed files
+- `setup.py`
+- `bitblas/base/arch/cuda.py`
+- `bitblas/base/arch/__init__.py`
+- `bitblas/utils/target_detector.py`
diff --git a/setup.py b/setup.py
index 5244cc187..34d0d0d27 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@
 import re
 import tarfile
 from io import BytesIO
+import tempfile
 import os
 import sys
 import urllib.request
@@ -142,8 +143,14 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
 
     # Extract the file
     print(f"Extracting {file_name} to {extract_path}")
-    with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
-        tar.extractall(path=extract_path)
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".tar.xz") as tmp_tar:
+        tmp_tar.write(file_content)
+        tmp_tar_path = tmp_tar.name
+    try:
+        # Avoid ownership preservation to prevent permission errors for non-root users
+        subprocess.check_call(["tar", "--no-same-owner", "-xJf", tmp_tar_path, "-C", extract_path])
+    finally:
+        os.remove(tmp_tar_path)
 
     print("Download and extraction completed successfully.")
     return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", "")))
@@ -153,8 +160,10 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
     "bitblas": ["py.typed"],
 }
 
-LLVM_VERSION = "10.0.1"
-IS_AARCH64 = False  # Set to True if on an aarch64 platform
+# Prefer a modern LLVM build that links against libtinfo6 (available on Ubuntu 24.04)
+# instead of the legacy 10.x release that required libtinfo5.
+LLVM_VERSION = "18.1.8"
+IS_AARCH64 = True  # Set to True if on an aarch64 platform
 EXTRACT_PATH = "3rdparty"  # Default extraction path
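
The last hunk hardcodes `IS_AARCH64 = True`, which is correct for the Thor box but would make an x86_64 build from the same tree fetch the wrong LLVM tarball. A platform probe keeps both working; the snippet below is a sketch, not part of the patch above:

```python
# Sketch: derive the flag from the running machine instead of hardcoding it,
# so the same setup.py also builds on x86_64 hosts.
import platform

IS_AARCH64 = platform.machine() == "aarch64"
```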