diff --git a/README.md b/README.md
index ac950788..772a9163 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,17 @@ The following runs `elicit` on the Cartesian product of the listed models and da
 elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
 ```
 
+## Running big models
+For big models that cannot fit on a single GPU, you'll need to use multiple
+GPUs per model.
+
+Here is an example that runs a single 8-bit llama-65b model on 2 A40s, which
+have ~50 GB of memory each.
+
+```
+elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 --int8
+```
+
 ## Caching
 
 The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 2a4c36e2..ad040ed6 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -1,7 +1,6 @@
 """Functions for extracting the hidden states of a model."""
 import logging
 import os
-from contextlib import nullcontext, redirect_stdout
 from dataclasses import InitVar, dataclass, replace
 from itertools import zip_longest
 from typing import Any, Iterable, Literal
@@ -34,13 +33,16 @@
     float_to_int16,
     infer_label_column,
     infer_num_classes,
-    instantiate_model,
     instantiate_tokenizer,
     is_autoregressive,
     prevent_name_conflicts,
     select_split,
     select_train_val_splits,
-    select_usable_devices,
+)
+from ..utils.multi_gpu import (
+    ModelDevices,
+    instantiate_model_with_devices,
+    select_devices_multi_gpus,
 )
 from .dataset_name import (
     DatasetDictWithName,
@@ -149,29 +151,33 @@ def explode(self) -> list["Extract"]:
 def extract_hiddens(
     cfg: "Extract",
     *,
-    device: str | torch.device = "cpu",
+    devices: ModelDevices,
     split_type: Literal["train", "val"] = "train",
     rank: int = 0,
     world_size: int = 1,
 ) -> Iterable[dict]:
     """Run inference on a model with a set of prompts, yielding the hidden states."""
+    first_device = (
+        devices.first_device if isinstance(devices, ModelDevices) else devices
+    )
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+    is_verbose = rank == 0
+
     # Silence datasets logging messages from all but the first process
-    if rank != 0:
+    if not is_verbose:
         filterwarnings("ignore")
         logging.disable(logging.CRITICAL)
 
     ds_names = cfg.datasets
     assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."
- # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its - # welcome message on every rank - with redirect_stdout(None) if rank != 0 else nullcontext(): - model = instantiate_model(cfg.model, device=device, load_in_8bit=cfg.int8) - tokenizer = instantiate_tokenizer( - cfg.model, truncation_side="left", verbose=rank == 0 - ) + model = instantiate_model_with_devices( + cfg=cfg, device_config=devices, is_verbose=is_verbose + ) + tokenizer = instantiate_tokenizer( + cfg.model, truncation_side="left", verbose=is_verbose + ) is_enc_dec = model.config.is_encoder_decoder if is_enc_dec and cfg.use_encoder_states: @@ -225,7 +231,7 @@ def extract_hiddens( num_variants, num_choices, model.config.hidden_size, - device=device, + device=first_device, dtype=torch.int16, ) for layer_idx in layer_indices @@ -233,7 +239,7 @@ def extract_hiddens( lm_logits = torch.empty( num_variants, num_choices, - device=device, + device=first_device, dtype=torch.float32, ) text_questions = [] @@ -254,8 +260,7 @@ def extract_hiddens( add_special_tokens=True, return_tensors="pt", text_target=target, # type: ignore[arg-type] - ).to(device) - + ).to(first_device) input_ids = assert_type(Tensor, encoding.input_ids) if is_enc_dec: answer = assert_type(Tensor, encoding.labels) @@ -265,8 +270,7 @@ def extract_hiddens( # Don't include [CLS] and [SEP] in the answer add_special_tokens=False, return_tensors="pt", - ).to(device) - + ).to(first_device) answer = assert_type(Tensor, encoding2.input_ids) input_ids = torch.cat([input_ids, answer], dim=-1) @@ -413,13 +417,16 @@ def extract( disable_cache: bool = False, highlight_color: Color = "cyan", num_gpus: int = -1, + gpus_per_model: int = 1, min_gpu_mem: int | None = None, split_type: Literal["train", "val", None] = None, ) -> DatasetDictWithName: """Extract hidden states from a model and return a `DatasetDict` containing them.""" info, features = hidden_features(cfg) - devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) + devices: list[ModelDevices] = select_devices_multi_gpus( + gpus_per_model=gpus_per_model, num_gpus=num_gpus, min_memory=min_gpu_mem + ) limits = cfg.max_examples splits = assert_type(SplitDict, info.splits) @@ -455,7 +462,7 @@ def extract( ), gen_kwargs=dict( cfg=[cfg] * len(devices), - device=devices, + devices=devices, rank=list(range(len(devices))), split_type=[ty] * len(devices), world_size=[len(devices)] * len(devices), diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 84818c83..90b3eedd 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -30,7 +30,7 @@ def create_config_id( config_kwargs["gen_kwargs"] = { k: v[0] for k, v in config_kwargs.get("gen_kwargs", {}).items() - if k not in ("device", "rank", "world_size") + if k not in ("devices", "rank", "world_size") } return super().create_config_id(config_kwargs, custom_features) diff --git a/elk/run.py b/elk/run.py index 85f244c7..d84f20cb 100644 --- a/elk/run.py +++ b/elk/run.py @@ -49,6 +49,7 @@ class Run(ABC, Serializable): num_gpus: int = -1 out_dir: Path | None = None disable_cache: bool = field(default=False, to_dict=False) + gpus_per_model: int = 1 def execute( self, @@ -61,6 +62,7 @@ def execute( disable_cache=self.disable_cache, highlight_color=highlight_color, num_gpus=self.num_gpus, + gpus_per_model=self.gpus_per_model, min_gpu_mem=self.min_gpu_mem, split_type=split_type, ) diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index a4294298..9783c41e 100644 --- a/elk/utils/gpu_utils.py +++ 
b/elk/utils/gpu_utils.py @@ -164,3 +164,18 @@ def select_usable_devices( print(f"Using {len(selection)} of {num_visible} GPUs: {selection}") return [f"cuda:{i}" for i in selection] + + +def get_available_memory_for_devices() -> dict[str, int]: + # PyNVML and PyTorch device indices should agree when CUDA_VISIBLE_DEVICES is + # not set. We need them to agree so that the PyNVML indices match the PyTorch + # indices, and we don't have to do any complex error-prone conversions. + num_visible = torch.cuda.device_count() + num_installed = pynvml.nvmlDeviceGetCount() + assert num_installed == num_visible, "PyNVML and PyTorch disagree on GPU count" + output = {} + # Get free memory for each GPU + for i in range(num_installed): + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + output[f"cuda:{i}"] = int(pynvml.nvmlDeviceGetMemoryInfo(handle).free) + return output diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 9f429921..26319cea 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import transformers from transformers import ( @@ -20,44 +22,59 @@ _AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES +def determine_dtypes( + model_str: str, + is_cpu: bool, + load_in_8bit: bool, +) -> torch.dtype | str: + model_cfg = AutoConfig.from_pretrained(model_str) + + # When the torch_dtype is None, this generally means the model is fp32, because + # the config was probably created before the `torch_dtype` field was added. + fp32_weights = model_cfg.torch_dtype in (None, torch.float32) + + # Required by `bitsandbytes` to load in 8-bit. + if load_in_8bit: + # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint + # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and + # we can't guarantee that there won't be overflow if we downcast to fp16. + if fp32_weights: + raise ValueError("Cannot load in 8-bit if weights are fp32") + + torch_dtype = torch.float16 + + # CPUs generally don't support anything other than fp32. + elif is_cpu: + torch_dtype = torch.float32 + + # If the model is fp32 but bf16 is available, convert to bf16. + # Usually models with fp32 weights were actually trained in bf16, and + # converting them doesn't hurt performance. + elif fp32_weights and torch.cuda.is_bf16_supported(): + torch_dtype = torch.bfloat16 + print("Weights seem to be fp32, but bf16 is available. Loading in bf16.") + else: + torch_dtype = "auto" + return torch_dtype + + def instantiate_model( model_str: str, - device: str | torch.device = "cpu", + load_in_8bit: bool, + is_cpu: bool, + torch_dtype: Optional[torch.dtype] = None, **kwargs, ) -> PreTrainedModel: """Instantiate a model string with the appropriate `Auto` class.""" - device = torch.device(device) - kwargs["device_map"] = {"": device} with prevent_name_conflicts(): model_cfg = AutoConfig.from_pretrained(model_str) - - # When the torch_dtype is None, this generally means the model is fp32, because - # the config was probably created before the `torch_dtype` field was added. - fp32_weights = model_cfg.torch_dtype in (None, torch.float32) - - # Required by `bitsandbytes` to load in 8-bit. - if kwargs.get("load_in_8bit"): - # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint - # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and - # we can't guarantee that there won't be overflow if we downcast to fp16. 
- if fp32_weights: - raise ValueError("Cannot load in 8-bit if weights are fp32") - - kwargs["torch_dtype"] = torch.float16 - - # CPUs generally don't support anything other than fp32. - elif device.type == "cpu": - kwargs["torch_dtype"] = torch.float32 - - # If the model is fp32 but bf16 is available, convert to bf16. - # Usually models with fp32 weights were actually trained in bf16, and - # converting them doesn't hurt performance. - elif fp32_weights and torch.cuda.is_bf16_supported(): - kwargs["torch_dtype"] = torch.bfloat16 - print("Weights seem to be fp32, but bf16 is available. Loading in bf16.") - else: - kwargs["torch_dtype"] = "auto" + # If a torch_dtype was not specified, try to infer it. + kwargs["torch_dtype"] = torch_dtype or determine_dtypes( + model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit + ) + # Add load_in_8bit to kwargs + kwargs["load_in_8bit"] = load_in_8bit archs = model_cfg.architectures if not isinstance(archs, list): @@ -70,7 +87,6 @@ def instantiate_model( if arch_str.endswith(suffix): model_cls = getattr(transformers, arch_str) return model_cls.from_pretrained(model_str, **kwargs) - return AutoModel.from_pretrained(model_str, **kwargs) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py new file mode 100644 index 00000000..f6b945cc --- /dev/null +++ b/elk/utils/multi_gpu.py @@ -0,0 +1,185 @@ +from contextlib import nullcontext, redirect_stdout +from dataclasses import dataclass +from typing import TYPE_CHECKING, Type + +import torch +from accelerate import infer_auto_device_map, init_empty_weights +from torch.nn import Module +from transformers import PreTrainedModel + +from elk.utils import instantiate_model, select_usable_devices +from elk.utils.gpu_utils import get_available_memory_for_devices + +if TYPE_CHECKING: + from elk import Extract + + +@dataclass +class ModelDevices: + # The devices to instantiate a single model on + first_device: str + other_devices: list[str] + + @property + def is_single_gpu(self) -> bool: + return len(self.other_devices) == 0 + + @property + def used_devices(self) -> list[str]: + return [self.first_device] + self.other_devices + + @property + def has_cpu_device(self) -> bool: + devices = [torch.device(device) for device in self.used_devices] + return any(device.type == "cpu" for device in devices) + + +def instantiate_model_with_devices( + cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs +) -> PreTrainedModel: + first_device = device_config.first_device + + # TODO: Maybe we should ensure the device map is the same + # for all the extract processes? This is because the device map + # can affect performance highly and its annoying if one process + # is using a different device map than the others. + device_map = ( + {"": first_device} + if device_config.is_single_gpu + else create_device_map( + model_str=cfg.model, + load_in_8bit=cfg.int8, + model_devices=device_config, + verbose=is_verbose, + ) + ) + if is_verbose: + print( + f"Using {len(device_config.used_devices)} gpu(s) to" + f" instantiate a single model." 
+        )
+    with redirect_stdout(None) if not is_verbose else nullcontext():
+        model = instantiate_model(
+            cfg.model,
+            device_map=device_map,
+            load_in_8bit=cfg.int8,
+            is_cpu=device_config.has_cpu_device,
+            **kwargs,
+        )
+    return model
+
+
+def create_device_map(
+    model_str: str,
+    load_in_8bit: bool,
+    model_devices: ModelDevices,
+    verbose: bool,
+) -> dict[str, str]:
+    """Creates a device map for a model running on multiple GPUs."""
+    with init_empty_weights():
+        # First instantiate an empty model so we can inspect its layer classes.
+        # We have to pass load_in_8bit=False because it's incompatible with
+        # init_empty_weights.
+        model = instantiate_model(
+            model_str=model_str,
+            load_in_8bit=False,
+            is_cpu=False,
+        )
+
+    # e.g. {"cuda:0": 16000, "cuda:1": 16000}
+    max_memory_all_devices: dict[str, int] = get_available_memory_for_devices()
+    # Now restrict that to the devices we actually want to use
+    used_devices = model_devices.used_devices
+    max_memory_used_devices: dict[str, int | float] = {
+        device: max_memory_all_devices[device] for device in used_devices
+    }
+    # Decrease the memory budget for the first device, because we're going
+    # to create additional tensors on it
+    max_memory_used_devices[model_devices.first_device] = (
+        max_memory_used_devices[model_devices.first_device] * 0.6
+    )
+    # If loading in 8-bit, double the budget: the empty model above is (probably)
+    # fp16, since we currently can't instantiate an empty model in 8-bit, while
+    # the real int8 weights will take roughly half the space.
+    devices_accounted_8bit = (
+        {
+            device: max_memory_used_devices[device] * 2
+            for device in max_memory_used_devices
+        }
+        if load_in_8bit
+        else max_memory_used_devices
+    )
+
+    # Make sure individual transformer layers are not split across devices,
+    # because they contain residual connections.
+    # See https://huggingface.co/docs/accelerate/usage_guides/big_modeling
+    # Otherwise we get an error like this:
+    # RuntimeError: Expected all tensors to be on the same device,
+    # but found at least two devices, cuda:0 and cuda:1
+    maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model)
+    dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else []
+    autodevice_map = infer_auto_device_map(
+        model, no_split_module_classes=dont_split, max_memory=devices_accounted_8bit
+    )
+
+    if verbose:
+        print(f"Autodevice map: {autodevice_map}")
+    assert "disk" not in autodevice_map.values(), (
+        f"Unable to fit {model_str} into the given memory for {used_devices}."
+        f" Try increasing gpus_per_model?"
+    )
+    return autodevice_map
+
+
+def select_devices_multi_gpus(
+    gpus_per_model: int,
+    num_gpus: int,
+    min_memory: int | None = None,
+) -> list[ModelDevices]:
+    if gpus_per_model == 1:
+        devices = select_usable_devices(num_gpus, min_memory=min_memory)
+        return [
+            ModelDevices(first_device=device, other_devices=[]) for device in devices
+        ]
+    else:
+        devices = select_usable_devices(num_gpus, min_memory=min_memory)
+        # How many models can we create from the selected devices?
+        models_to_create = len(devices) // gpus_per_model
+        print(
+            f"Allocating devices for {models_to_create} models with {gpus_per_model}"
+            f" GPUs each"
+        )
+        configs = split_devices_into_model_devices(
+            devices=devices,
+            gpus_per_model=gpus_per_model,
+            models_to_create=models_to_create,
+        )
+        print(f"Models will be instantiated on {configs}")
+        return configs
+
+
+def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None:
+    """Get the class of the transformer layer used by the given model."""
+    total_params = sum(p.numel() for p in model.parameters())
+    for module in model.modules():
+        if isinstance(module, torch.nn.ModuleList):
+            module_params = sum(p.numel() for p in module.parameters())
+            if module_params > total_params / 2:
+                type_of_cls = type(module[0])
+                print(f"Found transformer layer of type {type_of_cls}")
+                return type_of_cls
+
+    return None
+
+
+def split_devices_into_model_devices(
+    devices: list[str], gpus_per_model: int, models_to_create: int
+) -> list[ModelDevices]:
+    assert len(devices) >= gpus_per_model * models_to_create, "Not enough devices"
+    configs: list[ModelDevices] = []
+    while len(configs) < models_to_create:
+        first_device = devices.pop(0)
+        other_devices = devices[: gpus_per_model - 1]
+        devices = devices[gpus_per_model - 1 :]
+        configs.append(ModelDevices(first_device, other_devices))
+    return configs
diff --git a/tests/test_split_devices.py b/tests/test_split_devices.py
new file mode 100644
index 00000000..da85b051
--- /dev/null
+++ b/tests/test_split_devices.py
@@ -0,0 +1,34 @@
+from elk.utils.multi_gpu import ModelDevices, split_devices_into_model_devices
+
+
+def test_split_2_devices_1_gpu_per_model():
+    devices = ["a", "b"]
+    gpus_per_model = 1
+    models_to_create = 2
+    assert split_devices_into_model_devices(
+        devices=devices,
+        gpus_per_model=gpus_per_model,
+        models_to_create=models_to_create,
+    ) == [ModelDevices("a", []), ModelDevices("b", [])]
+
+
+def test_split_4_devices_2_gpus_per_model():
+    devices = ["a", "b", "c", "d"]
+    gpus_per_model = 2
+    models_to_create = 2
+    assert split_devices_into_model_devices(
+        devices=devices,
+        gpus_per_model=gpus_per_model,
+        models_to_create=models_to_create,
+    ) == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])]
+
+
+def test_split_7_devices_3_gpus_per_model():
+    devices = ["a", "b", "c", "d", "e", "f", "g"]
+    gpus_per_model = 3
+    models_to_create = 2
+    assert split_devices_into_model_devices(
+        devices=devices,
+        gpus_per_model=gpus_per_model,
+        models_to_create=models_to_create,
+    ) == [ModelDevices("a", ["b", "c"]), ModelDevices("d", ["e", "f"])]
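
A further test in the same style could pin down the leftover-GPU case, where the selected devices don't divide evenly into models. This is a sketch rather than part of the patch; the test name is made up, but the expected result follows from `split_devices_into_model_devices` as written above:

```python
from elk.utils.multi_gpu import ModelDevices, split_devices_into_model_devices


def test_split_5_devices_2_gpus_per_model_leaves_one_unused():
    # Only two models fit on five devices at two GPUs each; "e" is left unused.
    assert split_devices_into_model_devices(
        devices=["a", "b", "c", "d", "e"],
        gpus_per_model=2,
        models_to_create=2,
    ) == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])]
```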
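The memory budgeting in `create_device_map` (keep ~40% of the first device free for the hidden-state buffers and tokenized inputs that `extract_hiddens` places there, and double every budget when loading in 8-bit because the planning model is fp16) can be illustrated with a small standalone sketch. `budget_for_devices` and the memory figures below are hypothetical and for illustration only:

```python
def budget_for_devices(
    free_memory: dict[str, int], first_device: str, load_in_8bit: bool
) -> dict[str, float]:
    """Mirror the budget arithmetic used when building the device map."""
    budget = {device: float(mem) for device, mem in free_memory.items()}
    # Keep ~40% of the first device free for activations and inputs.
    budget[first_device] *= 0.6
    # int8 weights take about half the space of the fp16 planning model,
    # so the effective budget doubles when loading in 8-bit.
    if load_in_8bit:
        budget = {device: mem * 2 for device, mem in budget.items()}
    return budget


# Two GPUs with the same amount of free memory (arbitrary units), loading in 8-bit:
print(budget_for_devices({"cuda:0": 48_000, "cuda:1": 48_000}, "cuda:0", True))
# {'cuda:0': 57600.0, 'cuda:1': 96000.0}
```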