Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions docs/packaging/python_packaging.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,28 +112,32 @@ try:
from . import _rocm_init
except ModuleNotFoundError:
pass
else:
_rocm_init.initialize()
del _rocm_init
```

Generate a `_rocm_init.py` file like this (using any suitable scripting):

```bash
echo "
import rocm_sdk
rocm_sdk.initialize_process(library_shortnames=[
'amd_comgr',
'amdhip64',
'roctx64',
'hiprtc',
'hipblas',
'hipfft',
'hiprand',
'hipsparse',
'hipsolver',
'rccl',
'hipblaslt',
'miopen',
],
check_version='$(rocm-sdk version)')
def initialize():
import rocm_sdk
rocm_sdk.initialize_process(preload_shortnames=[
'amd_comgr',
'amdhip64',
'roctx64',
'hiprtc',
'hipblas',
'hipfft',
'hiprand',
'hipsparse',
'hipsolver',
'rccl',
'hipblaslt',
'miopen',
],
check_version='$(rocm-sdk version)')
" > torch/_rocm_init.py
```

Expand Down
23 changes: 21 additions & 2 deletions external-builds/pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ Now checkout repositories:

```bash
python pytorch_torch_repo.py checkout
python pytorch_torch_audio_repo.py checkout
python pytorch_torch_vision_repo.py checkout
python pytorch_audio_repo.py checkout
python pytorch_vision_repo.py checkout
```

- On Windows, use shorter paths to avoid command length limits:
Expand Down Expand Up @@ -194,3 +194,22 @@ To create patches

1. Commit your change(s) within the relevant source folder(s)
1. Run the `save-patches` subcommand of the relevant source management script(s)

## Alternate Branches / Patch Sets

### PyTorch Nightly

This checks out the `nightly` branches from https://github.com/pytorch,
tracking the latest pytorch.org nightly release:

- https://github.com/pytorch/pytorch/tree/nightly
- https://github.com/pytorch/audio/tree/nightly
- https://github.com/pytorch/vision/tree/nightly

```
python pytorch_torch_repo.py checkout --repo-hashtag nightly
python pytorch_audio_repo.py checkout --repo-hashtag nightly
python pytorch_vision_repo.py checkout --repo-hashtag nightly
# Note that triton will be checked out at the PyTorch pin.
python pytorch_triton_repo.py checkout
```
137 changes: 131 additions & 6 deletions external-builds/pytorch/build_prod_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@

```
# For therock-nightly-python
build_prod_wheels.py
build_prod_wheels.py \
install-rocm \
--index-url https://d2awnip2yjpvqn.cloudfront.net/v2/gfx110X-dgpu/

# For therock-dev-python (unstable but useful for testing outside of prod)
build_prod_wheels.py
build_prod_wheels.py \
install-rocm \
--index-url https://d25kgig7rdsyks.cloudfront.net/v2/gfx110X-dgpu/
```
Expand Down Expand Up @@ -122,11 +122,43 @@
import subprocess
import sys
import tempfile
import textwrap

script_dir = Path(__file__).resolve().parent

is_windows = platform.system() == "Windows"

# List of library preloads for Linux to generate into _rocm_init.py
LINUX_LIBRARY_PRELOADS = [
"amd_comgr",
"amdhip64",
"rocprofiler-sdk-roctx", # Linux only for the moment.
"roctx64", # Linux only for the moment.
"hiprtc",
"hipblas",
"hipfft",
"hiprand",
"hipsparse",
"hipsolver",
"rccl", # Linux only for the moment.
"hipblaslt",
"miopen",
]

# List of library preloads for Linux to generate into _rocm_init.py
WINDOWS_LIBRARY_PRELOADS = [
"amd_comgr",
"amdhip64",
"hiprtc",
"hipblas",
"hipfft",
"hiprand",
"hipsparse",
"hipsolver",
"hipblaslt",
"miopen",
]


def exec(args: list[str | Path], cwd: Path, env: dict[str, str] | None = None):
args = [str(arg) for arg in args]
Expand Down Expand Up @@ -189,6 +221,24 @@ def get_rocm_path(path_name: str) -> Path:
)


def get_rocm_init_contents(args: argparse.Namespace):
"""Gets the contents of the _rocm_init.py file to add to the build."""
sdk_version = get_rocm_sdk_version()
library_preloads = (
WINDOWS_LIBRARY_PRELOADS if is_windows else LINUX_LIBRARY_PRELOADS
)
library_preloads_formatted = ", ".join(f"'{s}'" for s in library_preloads)
return textwrap.dedent(
f"""
def initialize():
import rocm_sdk
rocm_sdk.initialize_process(
preload_shortnames=[{library_preloads_formatted}],
check_version='{sdk_version}')
"""
)


def remove_dir_if_exists(dir: Path):
if dir.exists():
print(f"++ Removing {dir}")
Expand Down Expand Up @@ -255,6 +305,15 @@ def do_install_rocm(args: argparse.Namespace):
print(f"Installed version: {get_rocm_sdk_version()}")


def add_env_compiler_flags(env: dict[str, str], flagname: str, *compiler_flags: str):
current = env.get(flagname, "")
append = ""
for compiler_flag in compiler_flags:
append += f" {compiler_flag}"
env[flagname] = f"{current}{append}"
print(f"-- Appended {flagname}+={append}")


def do_build(args: argparse.Namespace):
if args.install_rocm:
do_install_rocm(args)
Expand Down Expand Up @@ -296,6 +355,7 @@ def do_build(args: argparse.Namespace):
env: dict[str, str] = {
"CMAKE_PREFIX_PATH": str(cmake_prefix),
"ROCM_HOME": str(root_dir),
"ROCM_PATH": str(root_dir),
"PYTORCH_ROCM_ARCH": pytorch_rocm_arch,
# TODO: Figure out what is blocking GLOO and enable.
"USE_GLOO": "OFF",
Expand Down Expand Up @@ -330,9 +390,30 @@ def do_build(args: argparse.Namespace):
}
)

# Workaround missing devicelib bitcode
# TODO: When "ROCM_PATH" and/or "ROCM_HOME" is set in the environment, the
# clang frontend ignores its default heuristics and (depending on version)
# finds the wrong path to the device library. This is bad/annoying. But
# the PyTorch build shouldn't even need these to be set. Unfortunately, it
# has been hardcoded for a long time. So we use a clang env var to force
# a specific device lib path to workaround the hack to get pytorch to build.
# This may or may not only affect the Python wheels with their own quirks
# on directory layout.
# Obviously, this should be completely burned with fire once the root causes
# are eliminted.
hip_device_lib_path = get_rocm_path("root") / "llvm" / "amdgcn" / "bitcode"
if not hip_device_lib_path.exists():
print(
"WARNING: Default location of device libs not found. Relying on "
"clang heuristics which are known to be buggy in this configuration"
)
else:
env["HIP_DEVICE_LIB_PATH"] = hip_device_lib_path

# Build triton.
triton_requirement = None
if triton_dir:
if args.build_triton or (args.build_triton is None and triton_dir):
assert triton_dir, "Must specify --triton-dir if --build-triton"
triton_requirement = do_build_triton(args, triton_dir, dict(env))
else:
print("--- Not building triton (no --triton-dir)")
Expand All @@ -346,13 +427,23 @@ def do_build(args: argparse.Namespace):
print("--- Not building pytorch (no --pytorch-dir)")

# Build pytorch audio.
if pytorch_audio_dir:
if args.build_pytorch_audio or (
args.build_pytorch_audio is None and pytorch_audio_dir
):
assert (
pytorch_audio_dir
), "Must specify --pytorch-audio-dir if --build-pytorch-audio"
do_build_pytorch_audio(args, pytorch_audio_dir, dict(env))
else:
print("--- Not build pytorch-audio (no --pytorch-audio-dir)")

# Build pytorch vision.
if pytorch_vision_dir:
if args.build_pytorch_vision or (
args.build_pytorch_vision is None and pytorch_vision_dir
):
assert (
pytorch_vision_dir
), "Must specify --pytorch-vision-dir if --build-pytorch-vision"
do_build_pytorch_vision(args, pytorch_vision_dir, dict(env))
else:
print("--- Not build pytorch-vision (no --pytorch-vision-dir)")
Expand Down Expand Up @@ -414,6 +505,7 @@ def do_build_pytorch(
pytorch_build_version = (pytorch_dir / "version.txt").read_text().strip()
pytorch_build_version += args.version_suffix
print(f" Default PYTORCH_BUILD_VERSION: {pytorch_build_version}")
env["USE_ROCM"] = "ON"
env["PYTORCH_BUILD_VERSION"] = pytorch_build_version
env["PYTORCH_BUILD_NUMBER"] = args.pytorch_build_number

Expand All @@ -428,10 +520,13 @@ def do_build_pytorch(
f"--- PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {env['PYTORCH_EXTRA_INSTALL_REQUIREMENTS']}"
)

# Add the _rocm_init.py file.
(pytorch_dir / "torch" / "_rocm_init.py").write_text(get_rocm_init_contents(args))

# Workaround missing features on windows.
if is_windows:
env.update(
{
"USE_ROCM": "ON",
"USE_FLASH_ATTENTION": "0",
"USE_MEM_EFF_ATTENTION": "0",
"DISTUTILS_USE_SDK": "1",
Expand All @@ -446,6 +541,17 @@ def do_build_pytorch(
}
)

if not is_windows:
# Prepend the ROCm sysdeps dir so that we use bundled libraries.
# While a decent thing to be doing, this is presently required because:
# TODO: include/rocm_smi/kfd_ioctl.h is included without its advertised
# transitive includes. This triggers a compilation error for a missing
# libdrm/drm.h.
sysdeps_dir = get_rocm_path("root") / "lib" / "rocm_sysdeps"
assert sysdeps_dir.exists(), f"No sysdeps directory found: {sysdeps_dir}"
add_env_compiler_flags(env, "CXXFLAGS", f"-I{sysdeps_dir / 'include'}")
add_env_compiler_flags(env, "LDFLAGS", f"-L{sysdeps_dir / 'lib'}")

print("+++ Uninstalling pytorch:")
exec(
[sys.executable, "-m", "pip", "uninstall", "torch", "-y"],
Expand Down Expand Up @@ -634,6 +740,25 @@ def add_common(p: argparse.ArgumentParser):
build_p.add_argument(
"--pytorch-build-number", default="1", help="Build number to append to version"
)
build_p.add_argument(
"--build-triton",
action=argparse.BooleanOptionalAction,
default=None,
help="Enable building of triton (requires --triton-dir)",
)
build_p.add_argument(
"--build-pytorch-audio",
action=argparse.BooleanOptionalAction,
default=None,
help="Enable building of torch audio (requires --pytorch-audio-dir)",
)
build_p.add_argument(
"--build-pytorch-vision",
action=argparse.BooleanOptionalAction,
default=None,
help="Enable building of torch vision (requires --pytorch-vision-dir)",
)

today = date.today()
formatted_date = today.strftime("%Y%m%d")
build_p.add_argument(
Expand Down
5 changes: 4 additions & 1 deletion external-builds/pytorch/pytorch_triton_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def do_checkout(args: argparse.Namespace):
else:
# Derive the commit pin base on ci commit.
args.repo_hashtag = get_triton_pin(torch_dir)
build_env["TRITON_WHEEL_VERSION_SUFFIX"] = f"+git{args.repo_hashtag[:8]}"
# Latest triton calculates its own git hash and TRITON_WHEEL_VERSION_SUFFIX
# goes after the "+". Older versions must supply their own "+". We just
# leave it out entirely to avoid version errors.
build_env["TRITON_WHEEL_VERSION_SUFFIX"] = ""
print(f"Triton CI commit pin: {args.repo_hashtag}")

def _do_hipify(args: argparse.Namespace):
Expand Down
Loading