From 79a0be010e30df648b4b80fe2a62b4cc12d2f7c3 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:26:03 +0800 Subject: [PATCH 01/42] add llama map --- elk/extraction/extraction.py | 31 +++---- elk/extraction/llama/device_configs.py | 75 ++++++++++++++++ elk/extraction/llama/device_map.py | 117 +++++++++++++++++++++++++ 3 files changed, 208 insertions(+), 15 deletions(-) create mode 100644 elk/extraction/llama/device_configs.py create mode 100644 elk/extraction/llama/device_map.py diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 5446cd5c..2dd535b3 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -26,6 +26,11 @@ from transformers import AutoConfig, PreTrainedModel from transformers.modeling_outputs import Seq2SeqLMOutput +from .llama.device_configs import ( + Llama65bDeviceConfig, + select_devices_or_llama_65b_configs, + instantiate_model_or_llama, +) from ..promptsource import DatasetTemplates from ..utils import ( Color, @@ -40,7 +45,6 @@ prevent_name_conflicts, select_split, select_train_val_splits, - select_usable_devices, ) from .dataset_name import ( DatasetDictWithName, @@ -144,11 +148,16 @@ def explode(self) -> list["Extract"]: def extract_hiddens( cfg: "Extract", *, - device: str | torch.device = "cpu", + device_config: str | Llama65bDeviceConfig = "cpu", split_type: Literal["train", "val"] = "train", rank: int = 0, world_size: int = 1, ) -> Iterable[dict]: + device = ( + device_config + if not isinstance(device_config, Llama65bDeviceConfig) + else device_config.first_device + ) """Run inference on a model with a set of prompts, yielding the hidden states.""" os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -160,20 +169,10 @@ def extract_hiddens( ds_names = cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
- if cfg.int8: - # Required by `bitsandbytes` - dtype = torch.float16 - elif device == "cpu": - dtype = torch.float32 - else: - dtype = "auto" - # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its # welcome message on every rank with redirect_stdout(None) if rank != 0 else nullcontext(): - model = instantiate_model( - cfg.model, device_map={"": device}, load_in_8bit=cfg.int8, torch_dtype=dtype - ) + model = instantiate_model_or_llama(cfg=cfg, device_config=device_config) tokenizer = instantiate_tokenizer( cfg.model, truncation_side="left", verbose=rank == 0 ) @@ -395,7 +394,9 @@ def extract( """Extract hidden states from a model and return a `DatasetDict` containing them.""" info, features = hidden_features(cfg) - devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) + devices: Sequence[str | Llama65bDeviceConfig] = select_devices_or_llama_65b_configs( + model_name=cfg.model, num_gpus=num_gpus, min_memory=min_gpu_mem + ) limits = cfg.max_examples splits = assert_type(SplitDict, info.splits) @@ -431,7 +432,7 @@ def extract( ), gen_kwargs=dict( cfg=[cfg] * len(devices), - device=devices, + device_config=devices, rank=list(range(len(devices))), split_type=[ty] * len(devices), world_size=[len(devices)] * len(devices), diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py new file mode 100644 index 00000000..bc5497d8 --- /dev/null +++ b/elk/extraction/llama/device_configs.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass +from typing import Sequence + +import torch +from transformers import PreTrainedModel + +from elk import Extract +from elk.extraction.llama.device_map import get_llama_65b_8bit_device_map +from elk.utils import select_usable_devices, instantiate_model + + +@dataclass +class Llama65bDeviceConfig: + first_device: str + second_device: str + + +def select_devices_or_llama_65b_configs( + model_name: str, + num_gpus: int, + min_memory: int | None = None, +) -> Sequence[str | Llama65bDeviceConfig]: + if "llama-65b" not in model_name: + return select_usable_devices(num_gpus, min_memory=min_memory) + else: + print( + f"You've selected a llama-65b model, which requires at least two GPUs." + f"Each GPU must have at least 40 GiB of memory." + ) + print("Note that we will force the model to use 8-bit") + assert num_gpus >= 2, "llama-65b models require at least two GPUs" + # how many pairs of 2 gpus are specified? 
+ num_pairs = num_gpus // 2 + print(f"Will create {num_pairs} llama workers ") + forty_gb = 42_949_672_960 + llama_workers_config = [] + for i in range(num_pairs): + devices = select_usable_devices(num_gpus, min_memory=forty_gb) + llama_workers_config.append( + Llama65bDeviceConfig(first_device=devices[0], second_device=devices[1]) + ) + return llama_workers_config + + +def instantiate_model_or_llama( + cfg: Extract, device_config: str | Llama65bDeviceConfig, **kwargs +) -> PreTrainedModel: + is_llama_65b = isinstance(device_config, Llama65bDeviceConfig) + first_device = device_config.first_device if is_llama_65b else device_config + if cfg.int8 or is_llama_65b: + # Required by `bitsandbytes` + dtype = torch.float16 + elif device_config == "cpu": + dtype = torch.float32 + else: + dtype = "auto" + if is_llama_65b: + model = instantiate_model( + cfg.model, + device_map=get_llama_65b_8bit_device_map( + first_device=first_device, second_device=device_config.second_device + ), + load_in_8bit=True, + torch_dtype=dtype, + **kwargs, + ) + else: + model = instantiate_model( + cfg.model, + device_map={"": first_device}, + load_in_8bit=cfg.int8, + torch_dtype=dtype, + **kwargs, + ) + return model diff --git a/elk/extraction/llama/device_map.py b/elk/extraction/llama/device_map.py new file mode 100644 index 00000000..fda64928 --- /dev/null +++ b/elk/extraction/llama/device_map.py @@ -0,0 +1,117 @@ +import torch +from accelerate import init_empty_weights, infer_auto_device_map + +from elk.utils import instantiate_model + + +def get_suggested_map(model_str: str, used_dtype: torch.dtype) -> dict[str, int]: + """Util function to get the suggested map for a given model string and dtype + Usually doesn't work out of the box, you'll need to manually + change the attention module + to the same device as the lm_head due to the residual connection. + """ + with init_empty_weights(): + # you need to first instantiate the model to get the suggested map + model = instantiate_model(model_str, torch_dtype=used_dtype) + suggested_map = infer_auto_device_map(model) + return suggested_map + + +def get_llama_65b_8bit_device_map( + first_device: str | torch.device, second_device: str | torch.device +) -> dict[str, str | torch.device]: + """ + This assumes that you are using 2 GPUs, with at least 40GB of memory each. 
+ and that you are using 8bit + """ + return { + "model.embed_tokens": first_device, + "model.layers.0": first_device, + "model.layers.1": first_device, + "model.layers.2": first_device, + "model.layers.3": first_device, + "model.layers.4": first_device, + "model.layers.5": first_device, + "model.layers.6": first_device, + "model.layers.7": first_device, + "model.layers.8": first_device, + "model.layers.9": first_device, + "model.layers.10": first_device, + "model.layers.11": first_device, + "model.layers.12": first_device, + "model.layers.13": first_device, + "model.layers.14": first_device, + "model.layers.15": first_device, + "model.layers.16": first_device, + "model.layers.17": first_device, + "model.layers.18": first_device, + "model.layers.19": first_device, + "model.layers.20": first_device, + "model.layers.21": first_device, + "model.layers.22": first_device, + "model.layers.23": first_device, + "model.layers.24": first_device, + "model.layers.25": first_device, + "model.layers.26": first_device, + "model.layers.27.self_attn": first_device, + "model.layers.27.mlp.gate_proj": first_device, + "model.layers.27.mlp.down_proj": first_device, + "model.layers.27.mlp.up_proj": first_device, + "model.layers.27.mlp.act_fn": first_device, + "model.layers.27.input_layernorm": first_device, + "model.layers.27.post_attention_layernorm": first_device, + "model.layers.28": first_device, + "model.layers.29": first_device, + "model.layers.30": first_device, + "model.layers.31": first_device, + "model.layers.32": first_device, + "model.layers.33": first_device, + "model.layers.34": second_device, + "model.layers.35": second_device, + "model.layers.36": second_device, + "model.layers.37": second_device, + "model.layers.38": second_device, + "model.layers.39": second_device, + "model.layers.40": second_device, + "model.layers.41": second_device, + "model.layers.42": second_device, + "model.layers.43": second_device, + "model.layers.44": second_device, + "model.layers.45": second_device, + "model.layers.46": second_device, + "model.layers.47": second_device, + "model.layers.48": second_device, + "model.layers.49": second_device, + "model.layers.50": second_device, + "model.layers.51": second_device, + "model.layers.52": second_device, + "model.layers.53": second_device, + "model.layers.54": second_device, + "model.layers.55": second_device, + "model.layers.56": second_device, + "model.layers.57": second_device, + "model.layers.58": second_device, + "model.layers.59": second_device, + "model.layers.60": second_device, + "model.layers.61": second_device, + "model.layers.62": second_device, + "model.layers.63": second_device, + "model.layers.64": second_device, + "model.layers.65": second_device, + "model.layers.66": second_device, + "model.layers.67": second_device, + "model.layers.68": second_device, + "model.layers.69": second_device, + "model.layers.70": second_device, + "model.layers.71": second_device, + "model.layers.72": second_device, + "model.layers.73": second_device, + "model.layers.74": second_device, + "model.layers.75": second_device, + "model.layers.76": second_device, + "model.layers.77": second_device, + "model.layers.78": second_device, + "model.layers.79": second_device, + "model.norm": second_device, + "lm_head": first_device, + } From d692b7ffa1298740c7fdd93bb1de553d1485d7f8 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:27:44 +0800 Subject: [PATCH 02/42] add typechecking if --- elk/extraction/llama/device_configs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 
deletions(-) diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index bc5497d8..47374a71 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -1,13 +1,13 @@ from dataclasses import dataclass -from typing import Sequence +from typing import Sequence, TYPE_CHECKING import torch from transformers import PreTrainedModel - -from elk import Extract from elk.extraction.llama.device_map import get_llama_65b_8bit_device_map from elk.utils import select_usable_devices, instantiate_model +if TYPE_CHECKING: + from elk import Extract @dataclass class Llama65bDeviceConfig: @@ -43,7 +43,7 @@ def select_devices_or_llama_65b_configs( def instantiate_model_or_llama( - cfg: Extract, device_config: str | Llama65bDeviceConfig, **kwargs + cfg: "Extract", device_config: str | Llama65bDeviceConfig, **kwargs ) -> PreTrainedModel: is_llama_65b = isinstance(device_config, Llama65bDeviceConfig) first_device = device_config.first_device if is_llama_65b else device_config From 91d98d2b7820a22dff90920067fb29675ea7bec0 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:31:48 +0800 Subject: [PATCH 03/42] allocate the device properly --- elk/extraction/llama/device_configs.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index 47374a71..b0e1d5ea 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from elk import Extract + @dataclass class Llama65bDeviceConfig: first_device: str @@ -34,11 +35,18 @@ def select_devices_or_llama_65b_configs( print(f"Will create {num_pairs} llama workers ") forty_gb = 42_949_672_960 llama_workers_config = [] - for i in range(num_pairs): - devices = select_usable_devices(num_gpus, min_memory=forty_gb) - llama_workers_config.append( - Llama65bDeviceConfig(first_device=devices[0], second_device=devices[1]) + devices = select_usable_devices(num_gpus, min_memory=forty_gb) + # split the devices into pairs + configs = [] + while len(configs) < num_pairs: + first_device = devices.pop() + second_device = devices.pop() + configs.append( + Llama65bDeviceConfig( + first_device=first_device, second_device=second_device + ) ) + return llama_workers_config From aa5ee9e6745ad74d3ebe956983d452d689d0ac0d Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:34:04 +0800 Subject: [PATCH 04/42] print to debug --- elk/extraction/llama/device_configs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index b0e1d5ea..d504ec97 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -46,6 +46,7 @@ def select_devices_or_llama_65b_configs( first_device=first_device, second_device=second_device ) ) + print(f"Created {len(configs)} llama workers") return llama_workers_config From 7cbfb9a7c03766af907da512d2de0a337fbe2f92 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:36:34 +0800 Subject: [PATCH 05/42] change to device config --- elk/extraction/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 84818c83..53af2f5e 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -30,7 +30,7 @@ def create_config_id( config_kwargs["gen_kwargs"] = { k: v[0] for k, v in 
config_kwargs.get("gen_kwargs", {}).items() - if k not in ("device", "rank", "world_size") + if k not in ("device_config", "rank", "world_size") } return super().create_config_id(config_kwargs, custom_features) From cfb3200024869e1c78e46bfe55225a1395d453e2 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:44:04 +0800 Subject: [PATCH 06/42] more logs --- elk/extraction/generator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 53af2f5e..3137b847 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -27,6 +27,11 @@ def create_config_id( # to erase the world_size dimension so that the config id is the same no matter # how many processes are used. We also remove the explicit device, rank, and # world_size keys. + new = {} + for k, v in config_kwargs.get("gen_kwargs", {}).items(): + if k not in ("device_config", "rank", "world_size"): + new[k] = v[0] + print("new", new) config_kwargs["gen_kwargs"] = { k: v[0] for k, v in config_kwargs.get("gen_kwargs", {}).items() From 8cb4dec569242ce056b7cf8803d42ec381207fc7 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:44:56 +0800 Subject: [PATCH 07/42] print the value --- elk/extraction/generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 3137b847..f75116a7 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -30,6 +30,7 @@ def create_config_id( new = {} for k, v in config_kwargs.get("gen_kwargs", {}).items(): if k not in ("device_config", "rank", "world_size"): + print(f"key {k}: {v}") new[k] = v[0] print("new", new) config_kwargs["gen_kwargs"] = { From 9e8e321d305ea4e347cb4bf828f5f45994516a05 Mon Sep 17 00:00:00 2001 From: James Chua Date: Mon, 1 May 2023 19:47:29 +0800 Subject: [PATCH 08/42] fix not returning configs --- elk/extraction/generator.py | 6 ------ elk/extraction/llama/device_configs.py | 3 +-- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index f75116a7..53af2f5e 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -27,12 +27,6 @@ def create_config_id( # to erase the world_size dimension so that the config id is the same no matter # how many processes are used. We also remove the explicit device, rank, and # world_size keys. 
- new = {} - for k, v in config_kwargs.get("gen_kwargs", {}).items(): - if k not in ("device_config", "rank", "world_size"): - print(f"key {k}: {v}") - new[k] = v[0] - print("new", new) config_kwargs["gen_kwargs"] = { k: v[0] for k, v in config_kwargs.get("gen_kwargs", {}).items() diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index d504ec97..febfbbdd 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -34,7 +34,6 @@ def select_devices_or_llama_65b_configs( num_pairs = num_gpus // 2 print(f"Will create {num_pairs} llama workers ") forty_gb = 42_949_672_960 - llama_workers_config = [] devices = select_usable_devices(num_gpus, min_memory=forty_gb) # split the devices into pairs configs = [] @@ -48,7 +47,7 @@ def select_devices_or_llama_65b_configs( ) print(f"Created {len(configs)} llama workers") - return llama_workers_config + return configs def instantiate_model_or_llama( From 36ea2e6b03288b1c326bd67a0916a59cf39a2b66 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 May 2023 11:51:17 +0000 Subject: [PATCH 09/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/extraction/extraction.py | 11 +++++------ elk/extraction/llama/device_configs.py | 9 +++++---- elk/extraction/llama/device_map.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 2dd535b3..f6747000 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -26,11 +26,6 @@ from transformers import AutoConfig, PreTrainedModel from transformers.modeling_outputs import Seq2SeqLMOutput -from .llama.device_configs import ( - Llama65bDeviceConfig, - select_devices_or_llama_65b_configs, - instantiate_model_or_llama, -) from ..promptsource import DatasetTemplates from ..utils import ( Color, @@ -39,7 +34,6 @@ float32_to_int16, infer_label_column, infer_num_classes, - instantiate_model, instantiate_tokenizer, is_autoregressive, prevent_name_conflicts, @@ -51,6 +45,11 @@ parse_dataset_string, ) from .generator import _GeneratorBuilder +from .llama.device_configs import ( + Llama65bDeviceConfig, + instantiate_model_or_llama, + select_devices_or_llama_65b_configs, +) from .prompt_loading import load_prompts diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index febfbbdd..831f3384 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -1,10 +1,11 @@ from dataclasses import dataclass -from typing import Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence import torch from transformers import PreTrainedModel + from elk.extraction.llama.device_map import get_llama_65b_8bit_device_map -from elk.utils import select_usable_devices, instantiate_model +from elk.utils import instantiate_model, select_usable_devices if TYPE_CHECKING: from elk import Extract @@ -25,8 +26,8 @@ def select_devices_or_llama_65b_configs( return select_usable_devices(num_gpus, min_memory=min_memory) else: print( - f"You've selected a llama-65b model, which requires at least two GPUs." - f"Each GPU must have at least 40 GiB of memory." + "You've selected a llama-65b model, which requires at least two GPUs." + "Each GPU must have at least 40 GiB of memory." 
) print("Note that we will force the model to use 8-bit") assert num_gpus >= 2, "llama-65b models require at least two GPUs" diff --git a/elk/extraction/llama/device_map.py b/elk/extraction/llama/device_map.py index fda64928..931a11b0 100644 --- a/elk/extraction/llama/device_map.py +++ b/elk/extraction/llama/device_map.py @@ -1,5 +1,5 @@ import torch -from accelerate import init_empty_weights, infer_auto_device_map +from accelerate import infer_auto_device_map, init_empty_weights from elk.utils import instantiate_model From a59d0a2fafd8c4cf349e01831fd478b0107a3596 Mon Sep 17 00:00:00 2001 From: James Chua Date: Tue, 2 May 2023 00:10:50 +0800 Subject: [PATCH 10/42] test the effect of not returning the past key values --- elk/extraction/llama/device_configs.py | 19 ++++++++++--------- tests/test_smoke_elicit.py | 4 ++-- tests/test_smoke_eval.py | 2 +- tests/test_truncated_eigh.py | 2 +- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index 831f3384..98b92cdd 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -63,22 +63,23 @@ def instantiate_model_or_llama( dtype = torch.float32 else: dtype = "auto" - if is_llama_65b: + if not is_llama_65b: model = instantiate_model( cfg.model, - device_map=get_llama_65b_8bit_device_map( - first_device=first_device, second_device=device_config.second_device - ), - load_in_8bit=True, + device_map={"": first_device}, + load_in_8bit=cfg.int8, torch_dtype=dtype, - **kwargs, ) else: model = instantiate_model( cfg.model, - device_map={"": first_device}, - load_in_8bit=cfg.int8, + device_map=get_llama_65b_8bit_device_map( + first_device=first_device, second_device=device_config.second_device + ), + load_in_8bit=True, torch_dtype=dtype, - **kwargs, + # Testing to see if this fixes increased memory usage + # over time + use_cache=False, ) return model diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 7cf0e8c9..aed8e51e 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -7,7 +7,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 dataset_name = "imdb" elicit = Elicit( data=Extract( @@ -38,7 +38,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 dataset_name = "imdb" elicit = Elicit( data=Extract( diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index d58db6cd..683e718a 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -19,7 +19,7 @@ def setup_elicit( tmp_path: Path, dataset_name="imdb", model_path="sshleifer/tiny-gpt2", - min_mem=10 * 1024**2, + min_mem=10 * 1024 ** 2, is_ccs: bool = True, ) -> Elicit: """Setup elicit config for testing, execute elicit, and save output to tmp_path. diff --git a/tests/test_truncated_eigh.py b/tests/test_truncated_eigh.py index 5241f1c0..84a3de87 100644 --- a/tests/test_truncated_eigh.py +++ b/tests/test_truncated_eigh.py @@ -11,7 +11,7 @@ def random_symmetric_matrix(n: int, k: int) -> torch.Tensor: assert k <= n, "Rank k should be less than or equal to the matrix size n." 
# Generate random n x k matrix A with elements drawn from a uniform distribution - A = torch.rand(n, k) / k**0.5 + A = torch.rand(n, k) / k ** 0.5 # Create a diagonal matrix D with k eigenvalues evenly distributed around zero eigenvalues = torch.linspace(-1, 1, k) From 5fc1b5f979d16b1225258d008257588a9080d9f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 May 2023 16:11:24 +0000 Subject: [PATCH 11/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_smoke_elicit.py | 4 ++-- tests/test_smoke_eval.py | 2 +- tests/test_truncated_eigh.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index aed8e51e..7cf0e8c9 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -7,7 +7,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 dataset_name = "imdb" elicit = Elicit( data=Extract( @@ -38,7 +38,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 dataset_name = "imdb" elicit = Elicit( data=Extract( diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index 683e718a..d58db6cd 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -19,7 +19,7 @@ def setup_elicit( tmp_path: Path, dataset_name="imdb", model_path="sshleifer/tiny-gpt2", - min_mem=10 * 1024 ** 2, + min_mem=10 * 1024**2, is_ccs: bool = True, ) -> Elicit: """Setup elicit config for testing, execute elicit, and save output to tmp_path. diff --git a/tests/test_truncated_eigh.py b/tests/test_truncated_eigh.py index 84a3de87..5241f1c0 100644 --- a/tests/test_truncated_eigh.py +++ b/tests/test_truncated_eigh.py @@ -11,7 +11,7 @@ def random_symmetric_matrix(n: int, k: int) -> torch.Tensor: assert k <= n, "Rank k should be less than or equal to the matrix size n." 
# Generate random n x k matrix A with elements drawn from a uniform distribution - A = torch.rand(n, k) / k ** 0.5 + A = torch.rand(n, k) / k**0.5 # Create a diagonal matrix D with k eigenvalues evenly distributed around zero eigenvalues = torch.linspace(-1, 1, k) From c2ee397b54733e868808f0e4ae89e2b5b2cb0b04 Mon Sep 17 00:00:00 2001 From: James Chua Date: Tue, 2 May 2023 00:44:17 +0800 Subject: [PATCH 12/42] add kwargs --- elk/extraction/llama/device_configs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/elk/extraction/llama/device_configs.py b/elk/extraction/llama/device_configs.py index 98b92cdd..bfeaf342 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -69,6 +69,7 @@ def instantiate_model_or_llama( device_map={"": first_device}, load_in_8bit=cfg.int8, torch_dtype=dtype, + **kwargs, ) else: model = instantiate_model( @@ -78,8 +79,6 @@ def instantiate_model_or_llama( ), load_in_8bit=True, torch_dtype=dtype, - # Testing to see if this fixes increased memory usage - # over time - use_cache=False, + **kwargs, ) return model From 6abddbf521f774305daff8c61f49412ff1bfd897 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 01:37:20 +0800 Subject: [PATCH 13/42] add device map 0 --- llama_device_map.py | 194 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 llama_device_map.py diff --git a/llama_device_map.py b/llama_device_map.py new file mode 100644 index 00000000..dd48cba4 --- /dev/null +++ b/llama_device_map.py @@ -0,0 +1,194 @@ +import argparse +import random +import time +from threading import Thread + +import torch +from accelerate import infer_auto_device_map, init_empty_weights +from tqdm import tqdm +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +from elk.extraction import PromptConfig +from elk.extraction.extraction import ( + Extract, + temp_extract_input_ids_cached, +) +from elk.utils import instantiate_model +from llama_overwrite import overwrite_30b, overwrite_65b + + +def pad_tensors(tensors, device, pad_value=0): + max_len = max([t.size(-1) for t in tensors]) + padded_tensors = [] + attention_masks = [] + for _t in tensors: + t = _t.to(device) + pad_len = max_len - t.size(-1) + padded_tensor = torch.cat( + [torch.full((1, pad_len), pad_value, dtype=t.dtype, device=device), t], + dim=-1, + ) + attention_mask = torch.cat( + [ + torch.zeros((1, pad_len), dtype=torch.bool, device=device), + torch.ones_like(t), + ], + dim=-1, + ) + padded_tensors.append(padded_tensor) + attention_masks.append(attention_mask) + return torch.cat(padded_tensors, dim=0), torch.cat(attention_masks, dim=0) + + +def batch_ids( + input_ids_unbatched: list[torch.Tensor], batch_size: int +) -> list[tuple[torch.Tensor, torch.Tensor]]: + input_ids_unbatched_sorted = sorted(input_ids_unbatched, key=lambda x: x.size(-1)) + output = [] + input_buffer = [] + for input_id_args in input_ids_unbatched_sorted: + input_buffer.append(input_id_args) + if len(input_buffer) == batch_size: + batch_input_ids, attention_mask = pad_tensors(input_buffer, device=0) + output.append((batch_input_ids, attention_mask)) + input_buffer = [] + if input_buffer: # Process remaining input_ids in the buffer + batch_input_ids, attention_mask = pad_tensors(input_buffer, device=0) + output.append((batch_input_ids, attention_mask)) + return output + + +def inference_worker( + model, + batched_input_ids: list[tuple[torch.Tensor, torch.Tensor]], + use_tqdm=False, +): + 
batched_input_ids_use_tqdm: list[tuple[torch.Tensor, torch.Tensor]] = ( + tqdm(batched_input_ids, desc="Inference") if use_tqdm else batched_input_ids + ) + + for input_ids, attention_mask in batched_input_ids_use_tqdm: + with torch.no_grad(): + model(input_ids, attention_mask=attention_mask) + + +def main(args): + model_str = args.model + num_threads = args.threads + use_8bit = args.use_8bit + batch_size = args.batch_size + use_llama_override: bool = args.use_llama_override + print("Batch size:", batch_size) + + cfg = Extract(model=model_str, prompts=PromptConfig(datasets=["imdb"])) + + print("Extracting input ids...") + input_ids_list = temp_extract_input_ids_cached( + cfg=cfg, device="cpu", split_type="train" + ) + temp_extract_input_ids_cached(cfg=cfg, device="cpu", split_type="val") + # bring all the tensors to device 0 + + print("Number of input ids:", len(input_ids_list)) + device_tensors = [t.to(0) for t in input_ids_list] + device_tensors_batched = batch_ids(device_tensors, batch_size=batch_size) + # shuffle so we can tqdm more accurately + device_tensors_batched = random.sample( + device_tensors_batched, len(device_tensors_batched) + ) + print("Number of batches:", len(device_tensors_batched)) + + print("Instantiating model...") + used_dtype = torch.float16 if use_8bit else "auto" + + if use_8bit: + print("Using 8bit") + else: + print("Using 16bit") + with init_empty_weights(): + # Kinda dumb but you need to first insantiate on the CPU to get the layer class + model = instantiate_model(model_str, torch_dtype=used_dtype, device_map={"": 0}) + + # Hack to take into account that its 8bit + # min_gpu_mem * 2 if use_8bit else min_gpu_mem + dont_split = [LlamaDecoderLayer.__name__] + print("Dont split:", dont_split) + forty_gb = 40 * 1024 * 1024 * 1024 + autodevice_map = infer_auto_device_map( + model, no_split_module_classes=dont_split, max_memory={0: forty_gb, 1: forty_gb} + ) + print("Auto device map:", autodevice_map) + + device_map_override = ( + ( + overwrite_30b + if "30b" in model_str + else overwrite_65b + if "65b" in model_str + else {} + ) + if use_llama_override + else {} + ) + + # autodevice_map["lm_head"] = 0 + print("Device map overwrite:", device_map_override) + # Then instantiate on the GPU + model = instantiate_model( + model_str, + torch_dtype=used_dtype, + device_map=device_map_override or autodevice_map, + load_in_8bit=use_8bit, + ) + time_start = time.time() + time_end = time.time() + # in minutes + print("Compilation time:", (time_end - time_start) / 60) + + input_ids_chunks = [ + device_tensors_batched[i::num_threads] for i in range(num_threads) + ] + + threads = [] + for i in range(num_threads): + input_ids_queue = input_ids_chunks[i] + use_tqdm = i == 0 + t = Thread(target=inference_worker, args=(model, input_ids_queue, use_tqdm)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run inference with specified model") + parser.add_argument( + "--model", + type=str, + required=True, + help='Model string, e.g., "huggyllama/llama-13b"', + ) + parser.add_argument( + "--num_gpus", type=int, default=8, help="Number of GPUs to run on" + ) + default_bytes = 40 * 1024 * 1024 * 1024 + parser.add_argument( + "--min_gpu_mem", type=int, default=default_bytes, help="Min GPU memory per GPU" + ) + parser.add_argument( + "--threads", type=int, default=2, help="Number of threads to run" + ) + # store_true means that if you pass in --use_8bit, it will be True, otherwise False + 
parser.add_argument("--use_8bit", action="store_true", help="Whether to use 8bit") + parser.add_argument( + "--batch_size", type=int, default=8, help="Batch size for inference" + ) + # store_true means that if you pass in --use_llama_override, it will be True, otherwise False + parser.add_argument( + "--use_llama_override", + action="store_true", + help="Whether to use llama override", + ) + args = parser.parse_args() + + main(args) From 495354485dc315cfda31fb9e66084a4de081ef53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 May 2023 17:38:21 +0000 Subject: [PATCH 14/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- llama_device_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_device_map.py b/llama_device_map.py index dd48cba4..847d5362 100644 --- a/llama_device_map.py +++ b/llama_device_map.py @@ -5,6 +5,7 @@ import torch from accelerate import infer_auto_device_map, init_empty_weights +from llama_overwrite import overwrite_30b, overwrite_65b from tqdm import tqdm from transformers.models.llama.modeling_llama import LlamaDecoderLayer @@ -14,7 +15,6 @@ temp_extract_input_ids_cached, ) from elk.utils import instantiate_model -from llama_overwrite import overwrite_30b, overwrite_65b def pad_tensors(tensors, device, pad_value=0): From a2ceeb95852ab9672e545a5da4736107f975129d Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 01:42:34 +0800 Subject: [PATCH 15/42] fix 8bit mem --- llama_device_map.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/llama_device_map.py b/llama_device_map.py index 847d5362..dc43e865 100644 --- a/llama_device_map.py +++ b/llama_device_map.py @@ -113,8 +113,17 @@ def main(args): dont_split = [LlamaDecoderLayer.__name__] print("Dont split:", dont_split) forty_gb = 40 * 1024 * 1024 * 1024 + + max_memory = ( + {0: forty_gb, 1: forty_gb} + if not use_8bit + # this is a hack since infer_auto_device_map doesn't detect 8bit + # even if we load it in 8bit + # for big models, it'll start allocating to disk + else {0: forty_gb * 2, 1: forty_gb * 2} + ) autodevice_map = infer_auto_device_map( - model, no_split_module_classes=dont_split, max_memory={0: forty_gb, 1: forty_gb} + model, no_split_module_classes=dont_split, max_memory=max_memory ) print("Auto device map:", autodevice_map) From 5409b01cabbe639a2333f41c38b7fbd697be912c Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 03:57:40 +0800 Subject: [PATCH 16/42] make pyright happy --- elk/extraction/extraction.py | 35 ++--- elk/extraction/generator.py | 2 +- elk/extraction/llama/device_configs.py | 170 ++++++++++++++++++------- elk/utils/gpu_utils.py | 12 ++ llama_device_map.py | 10 +- tests/test_split_devices.py | 47 +++++++ 6 files changed, 211 insertions(+), 65 deletions(-) create mode 100644 tests/test_split_devices.py diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 0618cbe3..9a8b43b7 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -46,9 +46,9 @@ ) from .generator import _GeneratorBuilder from .llama.device_configs import ( - Llama65bDeviceConfig, - instantiate_model_or_llama, - select_devices_or_llama_65b_configs, + ModelDevices, + instantiate_model_with_devices, + select_devices_multi_gpus, ) from .prompt_loading import load_prompts @@ -147,21 +147,23 @@ def explode(self) -> list["Extract"]: def extract_hiddens( cfg: "Extract", *, - 
device_config: str | Llama65bDeviceConfig = "cpu", + device_config: ModelDevices, split_type: Literal["train", "val"] = "train", rank: int = 0, world_size: int = 1, ) -> Iterable[dict]: - device = ( + first_device = ( device_config - if not isinstance(device_config, Llama65bDeviceConfig) + if not isinstance(device_config, ModelDevices) else device_config.first_device ) """Run inference on a model with a set of prompts, yielding the hidden states.""" os.environ["TOKENIZERS_PARALLELISM"] = "false" + is_verbose = rank == 0 + # Silence datasets logging messages from all but the first process - if rank != 0: + if not is_verbose: filterwarnings("ignore") logging.disable(logging.CRITICAL) @@ -171,7 +173,9 @@ def extract_hiddens( # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its # welcome message on every rank with redirect_stdout(None) if rank != 0 else nullcontext(): - model = instantiate_model_or_llama(cfg=cfg, device_config=device_config) + model = instantiate_model_with_devices( + cfg=cfg, device_config=device_config, is_verbose=is_verbose + ) tokenizer = instantiate_tokenizer( cfg.model, truncation_side="left", verbose=rank == 0 ) @@ -216,7 +220,7 @@ def extract_hiddens( num_variants, num_choices, model.config.hidden_size, - device=device, + device=first_device, dtype=torch.int16, ) for layer_idx in layer_indices @@ -224,7 +228,7 @@ def extract_hiddens( lm_logits = torch.empty( num_variants, num_choices, - device=device, + device=first_device, dtype=torch.float32, ) text_questions = [] @@ -249,7 +253,7 @@ def extract_hiddens( return_tensors="pt", text_target=target, # type: ignore[arg-type] truncation=True, - ).to(device) + ).to(first_device) input_ids = assert_type(Tensor, encoding.input_ids) if is_enc_dec: @@ -260,7 +264,7 @@ def extract_hiddens( # Don't include [CLS] and [SEP] in the answer add_special_tokens=False, return_tensors="pt", - ).to(device) + ).to(first_device) answer = assert_type(Tensor, encoding2.input_ids) input_ids = torch.cat([input_ids, answer], dim=-1) @@ -389,14 +393,15 @@ def extract( disable_cache: bool = False, highlight_color: Color = "cyan", num_gpus: int = -1, + gpus_per_model: int = 1, min_gpu_mem: int | None = None, split_type: Literal["train", "val", None] = None, ) -> DatasetDictWithName: """Extract hidden states from a model and return a `DatasetDict` containing them.""" info, features = hidden_features(cfg) - devices: Sequence[str | Llama65bDeviceConfig] = select_devices_or_llama_65b_configs( - model_name=cfg.model, num_gpus=num_gpus, min_memory=min_gpu_mem + devices: list[ModelDevices] = select_devices_multi_gpus( + gpus_per_model=gpus_per_model, num_gpus=num_gpus, min_memory=min_gpu_mem ) limits = cfg.max_examples splits = assert_type(SplitDict, info.splits) @@ -433,7 +438,7 @@ def extract( ), gen_kwargs=dict( cfg=[cfg] * len(devices), - device_config=devices, + devices=devices, rank=list(range(len(devices))), split_type=[ty] * len(devices), world_size=[len(devices)] * len(devices), diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 53af2f5e..90b3eedd 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -30,7 +30,7 @@ def create_config_id( config_kwargs["gen_kwargs"] = { k: v[0] for k, v in config_kwargs.get("gen_kwargs", {}).items() - if k not in ("device_config", "rank", "world_size") + if k not in ("devices", "rank", "world_size") } return super().create_config_id(config_kwargs, custom_features) diff --git a/elk/extraction/llama/device_configs.py 
b/elk/extraction/llama/device_configs.py index bfeaf342..13bee8e5 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/extraction/llama/device_configs.py @@ -1,84 +1,166 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Type, Dict import torch +from accelerate import init_empty_weights, infer_auto_device_map +from torch import dtype +from torch.nn import Module from transformers import PreTrainedModel -from elk.extraction.llama.device_map import get_llama_65b_8bit_device_map from elk.utils import instantiate_model, select_usable_devices +from elk.utils.gpu_utils import get_available_memory_for_devices +from tests.test_split_devices import split_devices_into_model_devices if TYPE_CHECKING: from elk import Extract @dataclass -class Llama65bDeviceConfig: +class ModelDevices: + # The devices to instantiate a single model on first_device: str - second_device: str + other_devices: list[str] + @property + def is_single_gpu(self) -> bool: + return len(self.other_devices) == 0 -def select_devices_or_llama_65b_configs( - model_name: str, + @property + def used_devices(self) -> list[str]: + return [self.first_device] + self.other_devices + + +def select_devices_multi_gpus( + gpus_per_model: int, num_gpus: int, min_memory: int | None = None, -) -> Sequence[str | Llama65bDeviceConfig]: - if "llama-65b" not in model_name: - return select_usable_devices(num_gpus, min_memory=min_memory) +) -> list[ModelDevices]: + if gpus_per_model == 1: + devices = select_usable_devices(num_gpus, min_memory=min_memory) + return [ + ModelDevices(first_device=devices, other_devices=[]) for devices in devices + ] else: + # how many models can we create? + models_to_create = num_gpus // gpus_per_model print( - "You've selected a llama-65b model, which requires at least two GPUs." - "Each GPU must have at least 40 GiB of memory." + f"Will instantiate {models_to_create} models with {gpus_per_model} GPUs each" ) - print("Note that we will force the model to use 8-bit") - assert num_gpus >= 2, "llama-65b models require at least two GPUs" - # how many pairs of 2 gpus are specified? 
- num_pairs = num_gpus // 2 - print(f"Will create {num_pairs} llama workers ") - forty_gb = 42_949_672_960 - devices = select_usable_devices(num_gpus, min_memory=forty_gb) - # split the devices into pairs - configs = [] - while len(configs) < num_pairs: - first_device = devices.pop() - second_device = devices.pop() - configs.append( - Llama65bDeviceConfig( - first_device=first_device, second_device=second_device - ) - ) - print(f"Created {len(configs)} llama workers") - + devices = select_usable_devices(num_gpus, min_memory=min_memory) + configs = split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) + print(f"Models will be instantiated on {configs}") return configs -def instantiate_model_or_llama( - cfg: "Extract", device_config: str | Llama65bDeviceConfig, **kwargs +def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None: + """Get the class of the transformer layer used by the given model.""" + total_params = sum(p.numel() for p in model.parameters()) + for module in model.modules(): + if isinstance(module, torch.nn.ModuleList): + module_params = sum(p.numel() for p in module.parameters()) + if module_params > total_params / 2: + type_of_cls = type(module[0]) + print(f"Found transformer layer of type {type_of_cls}") + return type_of_cls + + return None + + +def create_device_map( + model_str: str, + use_8bit: float, + torch_dtype: dtype | str, + model_devices: ModelDevices, + verbose: bool, +) -> dict[str, str]: + """Creates a device map for a model running on multiple GPUs.""" + with init_empty_weights(): + # Need to first instantiate an empty model to get the layer class + model = instantiate_model(model_str=model_str, torch_dtype=torch_dtype) + + # e.g. {"cuda:0": 16000, "cuda:1": 16000} + max_memory_all_devices: dict[str, int] = get_available_memory_for_devices() + # now let's get the available memory for the devices we want to use + used_devices = model_devices.used_devices + max_memory_used_devices: dict[str, int | float] = { + device: max_memory_all_devices[device] for device in used_devices + } + # Decrease the memory potentially used by the first device + # because we're going to create additional tensors on it + max_memory_used_devices[model_devices.first_device] = ( + max_memory_used_devices[model_devices.first_device] * 0.9 + ) + # If 8bit, multiply the memory by 2 + # This is because we instantiated our empty model in (probably) float16 + # We aren't able to instantiate an empty model in 8bit currently + max_memory_used_devices = ( + { + device: max_memory_used_devices[device] * 2 + for device in max_memory_used_devices + } + if use_8bit + else max_memory_used_devices + ) + + """ + Make sure that the transformer layer is not split + because that contains residual connections + See https://huggingface.co/docs/accelerate/usage_guides/big_modeling + Otherwise we get an error like this: + RuntimeError: Expected all tensors to be on the same device, + but found at least two devices, cuda:0 and cuda1 + """ + maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model) + dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else [] + autodevice_map = infer_auto_device_map( + model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices + ) + if verbose: + print(f"Autodevice map: {autodevice_map}") + assert "disk" not in autodevice_map.values(), ( + f"Unable to fit the model {model} into the given memory for {used_devices}." 
+ f" Try increasing gpus_per_model?" + ) + return autodevice_map + + +def instantiate_model_with_devices( + cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs ) -> PreTrainedModel: - is_llama_65b = isinstance(device_config, Llama65bDeviceConfig) + is_llama_65b = isinstance(device_config, ModelDevices) first_device = device_config.first_device if is_llama_65b else device_config - if cfg.int8 or is_llama_65b: + if cfg.int8: # Required by `bitsandbytes` - dtype = torch.float16 + torch_dtype = torch.float16 elif device_config == "cpu": - dtype = torch.float32 + torch_dtype = torch.float32 else: - dtype = "auto" - if not is_llama_65b: + torch_dtype = "auto" + if device_config.is_single_gpu: model = instantiate_model( cfg.model, device_map={"": first_device}, load_in_8bit=cfg.int8, - torch_dtype=dtype, + torch_dtype=torch_dtype, **kwargs, ) else: + device_map = create_device_map( + model_str=cfg.model, + use_8bit=cfg.int8, + torch_dtype=torch_dtype, + model_devices=device_config, + verbose=is_verbose, + ) model = instantiate_model( cfg.model, - device_map=get_llama_65b_8bit_device_map( - first_device=first_device, second_device=device_config.second_device - ), - load_in_8bit=True, - torch_dtype=dtype, + device_map=device_map, + load_in_8bit=cfg.int8, + torch_dtype=torch_dtype, **kwargs, ) return model diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index a4294298..9fc1c8df 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -4,6 +4,7 @@ import time import warnings from functools import cache +from typing import TypeVar, Mapping import pynvml import torch @@ -164,3 +165,14 @@ def select_usable_devices( print(f"Using {len(selection)} of {num_visible} GPUs: {selection}") return [f"cuda:{i}" for i in selection] + + +def get_available_memory_for_devices() -> dict[str, int]: + # Edited from get_max_memory of the accelerate library + for i in range(torch.cuda.device_count()): + _ = torch.tensor([0], device=i) + max_memory = { + f"cuda:{i}": torch.cuda.mem_get_info(i)[0] + for i in range(torch.cuda.device_count()) + } + return max_memory diff --git a/llama_device_map.py b/llama_device_map.py index dc43e865..b60e703a 100644 --- a/llama_device_map.py +++ b/llama_device_map.py @@ -5,7 +5,6 @@ import torch from accelerate import infer_auto_device_map, init_empty_weights -from llama_overwrite import overwrite_30b, overwrite_65b from tqdm import tqdm from transformers.models.llama.modeling_llama import LlamaDecoderLayer @@ -15,6 +14,7 @@ temp_extract_input_ids_cached, ) from elk.utils import instantiate_model +from llama_overwrite import overwrite_30b, overwrite_65b def pad_tensors(tensors, device, pad_value=0): @@ -106,7 +106,7 @@ def main(args): print("Using 16bit") with init_empty_weights(): # Kinda dumb but you need to first insantiate on the CPU to get the layer class - model = instantiate_model(model_str, torch_dtype=used_dtype, device_map={"": 0}) + model = instantiate_model(model_str, torch_dtype=used_dtype) # Hack to take into account that its 8bit # min_gpu_mem * 2 if use_8bit else min_gpu_mem @@ -117,9 +117,9 @@ def main(args): max_memory = ( {0: forty_gb, 1: forty_gb} if not use_8bit - # this is a hack since infer_auto_device_map doesn't detect 8bit - # even if we load it in 8bit - # for big models, it'll start allocating to disk + # this is a hack since infer_auto_device_map can't detect + # that we're using 8bit, since we inited an empty model + # to analyse. 
else {0: forty_gb * 2, 1: forty_gb * 2} ) autodevice_map = infer_auto_device_map( diff --git a/tests/test_split_devices.py b/tests/test_split_devices.py new file mode 100644 index 00000000..7d6fee62 --- /dev/null +++ b/tests/test_split_devices.py @@ -0,0 +1,47 @@ +from elk.extraction.llama.device_configs import ModelDevices + + +def split_devices_into_model_devices( + devices: list[str], gpus_per_model: int, models_to_create: int +) -> list[ModelDevices]: + assert len(devices) >= gpus_per_model * models_to_create + configs = [] + while len(configs) < models_to_create: + first_device = devices.pop(0) + other_devices = devices[: gpus_per_model - 1] + devices = devices[gpus_per_model - 1 :] + configs.append(ModelDevices(first_device, other_devices)) + return configs + + +def test_split_2_devices_1_gpu_per_model(): + devices = ["a", "b"] + gpus_per_model = 1 + models_to_create = 2 + assert split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) == [ModelDevices("a", []), ModelDevices("b", [])] + + +def test_split_4_devices_2_gpus_per_model(): + devices = ["a", "b", "c", "d"] + gpus_per_model = 2 + models_to_create = 2 + assert split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])] + + +def test_split_7_devices_3_gpus_per_model(): + devices = ["a", "b", "c", "d", "e", "f", "g"] + gpus_per_model = 3 + models_to_create = 2 + assert split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) == [ModelDevices("a", ["b", "c"]), ModelDevices("d", ["e", "f"])] From 0db53b9616cee8716761d4c84f68b13e6c69c88a Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 04:05:38 +0800 Subject: [PATCH 17/42] implement multi gpu --- elk/extraction/extraction.py | 4 +- elk/extraction/llama/device_map.py | 117 ---------- .../device_configs.py => utils/multi_gpu.py} | 152 +++++++------ llama_device_map.py | 203 ------------------ tests/test_split_devices.py | 15 +- 5 files changed, 85 insertions(+), 406 deletions(-) delete mode 100644 elk/extraction/llama/device_map.py rename elk/{extraction/llama/device_configs.py => utils/multi_gpu.py} (91%) delete mode 100644 llama_device_map.py diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 9a8b43b7..755a8f36 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -45,10 +45,10 @@ parse_dataset_string, ) from .generator import _GeneratorBuilder -from .llama.device_configs import ( - ModelDevices, +from ..utils.multi_gpu import ( instantiate_model_with_devices, select_devices_multi_gpus, + ModelDevices, ) from .prompt_loading import load_prompts diff --git a/elk/extraction/llama/device_map.py b/elk/extraction/llama/device_map.py deleted file mode 100644 index 931a11b0..00000000 --- a/elk/extraction/llama/device_map.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -from accelerate import infer_auto_device_map, init_empty_weights - -from elk.utils import instantiate_model - - -def get_suggested_map(model_str: str, used_dtype: torch.dtype) -> dict[str, int]: - """Util function to get the suggested map for a given model string and dtype - Usually doesn't work out of the box, you'll need to manually - change the attention module - to the same device as the lm_head due to the residual connection. 
- """ - with init_empty_weights(): - # you need to first instantiate the model to get the suggested map - model = instantiate_model(model_str, torch_dtype=used_dtype) - suggested_map = infer_auto_device_map(model) - return suggested_map - - -def get_llama_65b_8bit_device_map( - first_device: str | torch.device, second_device: str | torch.device -) -> dict[str, str | torch.device]: - """ - This assumes that you are using 2 GPUs, with at least 40GB of memory each. - and that you are using 8bit - """ - return { - "model.embed_tokens": first_device, - "model.layers.0": first_device, - "model.layers.1": first_device, - "model.layers.2": first_device, - "model.layers.3": first_device, - "model.layers.4": first_device, - "model.layers.5": first_device, - "model.layers.6": first_device, - "model.layers.7": first_device, - "model.layers.8": first_device, - "model.layers.9": first_device, - "model.layers.10": first_device, - "model.layers.11": first_device, - "model.layers.12": first_device, - "model.layers.13": first_device, - "model.layers.14": first_device, - "model.layers.15": first_device, - "model.layers.16": first_device, - "model.layers.17": first_device, - "model.layers.18": first_device, - "model.layers.19": first_device, - "model.layers.20": first_device, - "model.layers.21": first_device, - "model.layers.22": first_device, - "model.layers.23": first_device, - "model.layers.24": first_device, - "model.layers.25": first_device, - "model.layers.26": first_device, - "model.layers.27.self_attn": first_device, - "model.layers.27.mlp.gate_proj": first_device, - "model.layers.27.mlp.down_proj": first_device, - "model.layers.27.mlp.up_proj": first_device, - "model.layers.27.mlp.act_fn": first_device, - "model.layers.27.input_layernorm": first_device, - "model.layers.27.post_attention_layernorm": first_device, - "model.layers.28": first_device, - "model.layers.29": first_device, - "model.layers.30": first_device, - "model.layers.31": first_device, - "model.layers.32": first_device, - "model.layers.33": first_device, - "model.layers.34": second_device, - "model.layers.35": second_device, - "model.layers.36": second_device, - "model.layers.37": second_device, - "model.layers.38": second_device, - "model.layers.39": second_device, - "model.layers.40": second_device, - "model.layers.41": second_device, - "model.layers.42": second_device, - "model.layers.43": second_device, - "model.layers.44": second_device, - "model.layers.45": second_device, - "model.layers.46": second_device, - "model.layers.47": second_device, - "model.layers.48": second_device, - "model.layers.49": second_device, - "model.layers.50": second_device, - "model.layers.51": second_device, - "model.layers.52": second_device, - "model.layers.53": second_device, - "model.layers.54": second_device, - "model.layers.55": second_device, - "model.layers.56": second_device, - "model.layers.57": second_device, - "model.layers.58": second_device, - "model.layers.59": second_device, - "model.layers.60": second_device, - "model.layers.61": second_device, - "model.layers.62": second_device, - "model.layers.63": second_device, - "model.layers.64": second_device, - "model.layers.65": second_device, - "model.layers.66": second_device, - "model.layers.67": second_device, - "model.layers.68": second_device, - "model.layers.69": second_device, - "model.layers.70": second_device, - "model.layers.71": second_device, - "model.layers.72": second_device, - "model.layers.73": second_device, - "model.layers.74": second_device, - "model.layers.75": 
second_device, - "model.layers.76": second_device, - "model.layers.77": second_device, - "model.layers.78": second_device, - "model.layers.79": second_device, - "model.norm": second_device, - "lm_head": first_device, - } diff --git a/elk/extraction/llama/device_configs.py b/elk/utils/multi_gpu.py similarity index 91% rename from elk/extraction/llama/device_configs.py rename to elk/utils/multi_gpu.py index 13bee8e5..50dfbf36 100644 --- a/elk/extraction/llama/device_configs.py +++ b/elk/utils/multi_gpu.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Type, Dict +from typing import Type, TYPE_CHECKING import torch from accelerate import init_empty_weights, infer_auto_device_map @@ -9,7 +9,6 @@ from elk.utils import instantiate_model, select_usable_devices from elk.utils.gpu_utils import get_available_memory_for_devices -from tests.test_split_devices import split_devices_into_model_devices if TYPE_CHECKING: from elk import Extract @@ -30,44 +29,42 @@ def used_devices(self) -> list[str]: return [self.first_device] + self.other_devices -def select_devices_multi_gpus( - gpus_per_model: int, - num_gpus: int, - min_memory: int | None = None, -) -> list[ModelDevices]: - if gpus_per_model == 1: - devices = select_usable_devices(num_gpus, min_memory=min_memory) - return [ - ModelDevices(first_device=devices, other_devices=[]) for devices in devices - ] +def instantiate_model_with_devices( + cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs +) -> PreTrainedModel: + is_llama_65b = isinstance(device_config, ModelDevices) + first_device = device_config.first_device if is_llama_65b else device_config + if cfg.int8: + # Required by `bitsandbytes` + torch_dtype = torch.float16 + elif device_config == "cpu": + torch_dtype = torch.float32 else: - # how many models can we create? 
- models_to_create = num_gpus // gpus_per_model - print( - f"Will instantiate {models_to_create} models with {gpus_per_model} GPUs each" + torch_dtype = "auto" + if device_config.is_single_gpu: + model = instantiate_model( + cfg.model, + device_map={"": first_device}, + load_in_8bit=cfg.int8, + torch_dtype=torch_dtype, + **kwargs, ) - devices = select_usable_devices(num_gpus, min_memory=min_memory) - configs = split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, + else: + device_map = create_device_map( + model_str=cfg.model, + use_8bit=cfg.int8, + torch_dtype=torch_dtype, + model_devices=device_config, + verbose=is_verbose, ) - print(f"Models will be instantiated on {configs}") - return configs - - -def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None: - """Get the class of the transformer layer used by the given model.""" - total_params = sum(p.numel() for p in model.parameters()) - for module in model.modules(): - if isinstance(module, torch.nn.ModuleList): - module_params = sum(p.numel() for p in module.parameters()) - if module_params > total_params / 2: - type_of_cls = type(module[0]) - print(f"Found transformer layer of type {type_of_cls}") - return type_of_cls - - return None + model = instantiate_model( + cfg.model, + device_map=device_map, + load_in_8bit=cfg.int8, + torch_dtype=torch_dtype, + **kwargs, + ) + return model def create_device_map( @@ -128,39 +125,54 @@ def create_device_map( return autodevice_map -def instantiate_model_with_devices( - cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs -) -> PreTrainedModel: - is_llama_65b = isinstance(device_config, ModelDevices) - first_device = device_config.first_device if is_llama_65b else device_config - if cfg.int8: - # Required by `bitsandbytes` - torch_dtype = torch.float16 - elif device_config == "cpu": - torch_dtype = torch.float32 - else: - torch_dtype = "auto" - if device_config.is_single_gpu: - model = instantiate_model( - cfg.model, - device_map={"": first_device}, - load_in_8bit=cfg.int8, - torch_dtype=torch_dtype, - **kwargs, - ) +def select_devices_multi_gpus( + gpus_per_model: int, + num_gpus: int, + min_memory: int | None = None, +) -> list[ModelDevices]: + if gpus_per_model == 1: + devices = select_usable_devices(num_gpus, min_memory=min_memory) + return [ + ModelDevices(first_device=devices, other_devices=[]) for devices in devices + ] else: - device_map = create_device_map( - model_str=cfg.model, - use_8bit=cfg.int8, - torch_dtype=torch_dtype, - model_devices=device_config, - verbose=is_verbose, + # how many models can we create? 
+ models_to_create = num_gpus // gpus_per_model + print( + f"Will instantiate {models_to_create} models with {gpus_per_model} GPUs each" ) - model = instantiate_model( - cfg.model, - device_map=device_map, - load_in_8bit=cfg.int8, - torch_dtype=torch_dtype, - **kwargs, + devices = select_usable_devices(num_gpus, min_memory=min_memory) + configs = split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, ) - return model + print(f"Models will be instantiated on {configs}") + return configs + + +def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None: + """Get the class of the transformer layer used by the given model.""" + total_params = sum(p.numel() for p in model.parameters()) + for module in model.modules(): + if isinstance(module, torch.nn.ModuleList): + module_params = sum(p.numel() for p in module.parameters()) + if module_params > total_params / 2: + type_of_cls = type(module[0]) + print(f"Found transformer layer of type {type_of_cls}") + return type_of_cls + + return None + + +def split_devices_into_model_devices( + devices: list[str], gpus_per_model: int, models_to_create: int +) -> list[ModelDevices]: + assert len(devices) >= gpus_per_model * models_to_create + configs = [] + while len(configs) < models_to_create: + first_device = devices.pop(0) + other_devices = devices[: gpus_per_model - 1] + devices = devices[gpus_per_model - 1 :] + configs.append(ModelDevices(first_device, other_devices)) + return configs diff --git a/llama_device_map.py b/llama_device_map.py deleted file mode 100644 index b60e703a..00000000 --- a/llama_device_map.py +++ /dev/null @@ -1,203 +0,0 @@ -import argparse -import random -import time -from threading import Thread - -import torch -from accelerate import infer_auto_device_map, init_empty_weights -from tqdm import tqdm -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -from elk.extraction import PromptConfig -from elk.extraction.extraction import ( - Extract, - temp_extract_input_ids_cached, -) -from elk.utils import instantiate_model -from llama_overwrite import overwrite_30b, overwrite_65b - - -def pad_tensors(tensors, device, pad_value=0): - max_len = max([t.size(-1) for t in tensors]) - padded_tensors = [] - attention_masks = [] - for _t in tensors: - t = _t.to(device) - pad_len = max_len - t.size(-1) - padded_tensor = torch.cat( - [torch.full((1, pad_len), pad_value, dtype=t.dtype, device=device), t], - dim=-1, - ) - attention_mask = torch.cat( - [ - torch.zeros((1, pad_len), dtype=torch.bool, device=device), - torch.ones_like(t), - ], - dim=-1, - ) - padded_tensors.append(padded_tensor) - attention_masks.append(attention_mask) - return torch.cat(padded_tensors, dim=0), torch.cat(attention_masks, dim=0) - - -def batch_ids( - input_ids_unbatched: list[torch.Tensor], batch_size: int -) -> list[tuple[torch.Tensor, torch.Tensor]]: - input_ids_unbatched_sorted = sorted(input_ids_unbatched, key=lambda x: x.size(-1)) - output = [] - input_buffer = [] - for input_id_args in input_ids_unbatched_sorted: - input_buffer.append(input_id_args) - if len(input_buffer) == batch_size: - batch_input_ids, attention_mask = pad_tensors(input_buffer, device=0) - output.append((batch_input_ids, attention_mask)) - input_buffer = [] - if input_buffer: # Process remaining input_ids in the buffer - batch_input_ids, attention_mask = pad_tensors(input_buffer, device=0) - output.append((batch_input_ids, attention_mask)) - return output - - -def inference_worker( 
- model, - batched_input_ids: list[tuple[torch.Tensor, torch.Tensor]], - use_tqdm=False, -): - batched_input_ids_use_tqdm: list[tuple[torch.Tensor, torch.Tensor]] = ( - tqdm(batched_input_ids, desc="Inference") if use_tqdm else batched_input_ids - ) - - for input_ids, attention_mask in batched_input_ids_use_tqdm: - with torch.no_grad(): - model(input_ids, attention_mask=attention_mask) - - -def main(args): - model_str = args.model - num_threads = args.threads - use_8bit = args.use_8bit - batch_size = args.batch_size - use_llama_override: bool = args.use_llama_override - print("Batch size:", batch_size) - - cfg = Extract(model=model_str, prompts=PromptConfig(datasets=["imdb"])) - - print("Extracting input ids...") - input_ids_list = temp_extract_input_ids_cached( - cfg=cfg, device="cpu", split_type="train" - ) + temp_extract_input_ids_cached(cfg=cfg, device="cpu", split_type="val") - # bring all the tensors to device 0 - - print("Number of input ids:", len(input_ids_list)) - device_tensors = [t.to(0) for t in input_ids_list] - device_tensors_batched = batch_ids(device_tensors, batch_size=batch_size) - # shuffle so we can tqdm more accurately - device_tensors_batched = random.sample( - device_tensors_batched, len(device_tensors_batched) - ) - print("Number of batches:", len(device_tensors_batched)) - - print("Instantiating model...") - used_dtype = torch.float16 if use_8bit else "auto" - - if use_8bit: - print("Using 8bit") - else: - print("Using 16bit") - with init_empty_weights(): - # Kinda dumb but you need to first insantiate on the CPU to get the layer class - model = instantiate_model(model_str, torch_dtype=used_dtype) - - # Hack to take into account that its 8bit - # min_gpu_mem * 2 if use_8bit else min_gpu_mem - dont_split = [LlamaDecoderLayer.__name__] - print("Dont split:", dont_split) - forty_gb = 40 * 1024 * 1024 * 1024 - - max_memory = ( - {0: forty_gb, 1: forty_gb} - if not use_8bit - # this is a hack since infer_auto_device_map can't detect - # that we're using 8bit, since we inited an empty model - # to analyse. 
- else {0: forty_gb * 2, 1: forty_gb * 2} - ) - autodevice_map = infer_auto_device_map( - model, no_split_module_classes=dont_split, max_memory=max_memory - ) - print("Auto device map:", autodevice_map) - - device_map_override = ( - ( - overwrite_30b - if "30b" in model_str - else overwrite_65b - if "65b" in model_str - else {} - ) - if use_llama_override - else {} - ) - - # autodevice_map["lm_head"] = 0 - print("Device map overwrite:", device_map_override) - # Then instantiate on the GPU - model = instantiate_model( - model_str, - torch_dtype=used_dtype, - device_map=device_map_override or autodevice_map, - load_in_8bit=use_8bit, - ) - time_start = time.time() - time_end = time.time() - # in minutes - print("Compilation time:", (time_end - time_start) / 60) - - input_ids_chunks = [ - device_tensors_batched[i::num_threads] for i in range(num_threads) - ] - - threads = [] - for i in range(num_threads): - input_ids_queue = input_ids_chunks[i] - use_tqdm = i == 0 - t = Thread(target=inference_worker, args=(model, input_ids_queue, use_tqdm)) - threads.append(t) - t.start() - - for t in threads: - t.join() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run inference with specified model") - parser.add_argument( - "--model", - type=str, - required=True, - help='Model string, e.g., "huggyllama/llama-13b"', - ) - parser.add_argument( - "--num_gpus", type=int, default=8, help="Number of GPUs to run on" - ) - default_bytes = 40 * 1024 * 1024 * 1024 - parser.add_argument( - "--min_gpu_mem", type=int, default=default_bytes, help="Min GPU memory per GPU" - ) - parser.add_argument( - "--threads", type=int, default=2, help="Number of threads to run" - ) - # store_true means that if you pass in --use_8bit, it will be True, otherwise False - parser.add_argument("--use_8bit", action="store_true", help="Whether to use 8bit") - parser.add_argument( - "--batch_size", type=int, default=8, help="Batch size for inference" - ) - # store_true means that if you pass in --use_llama_override, it will be True, otherwise False - parser.add_argument( - "--use_llama_override", - action="store_true", - help="Whether to use llama override", - ) - args = parser.parse_args() - - main(args) diff --git a/tests/test_split_devices.py b/tests/test_split_devices.py index 7d6fee62..da85b051 100644 --- a/tests/test_split_devices.py +++ b/tests/test_split_devices.py @@ -1,17 +1,4 @@ -from elk.extraction.llama.device_configs import ModelDevices - - -def split_devices_into_model_devices( - devices: list[str], gpus_per_model: int, models_to_create: int -) -> list[ModelDevices]: - assert len(devices) >= gpus_per_model * models_to_create - configs = [] - while len(configs) < models_to_create: - first_device = devices.pop(0) - other_devices = devices[: gpus_per_model - 1] - devices = devices[gpus_per_model - 1 :] - configs.append(ModelDevices(first_device, other_devices)) - return configs +from elk.utils.multi_gpu import ModelDevices, split_devices_into_model_devices def test_split_2_devices_1_gpu_per_model(): From 5ee3c3af15844dd20157ff48f78d892b8ba22cc7 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 04:11:03 +0800 Subject: [PATCH 18/42] add cli --- README.md | 11 ++++++++++ elk/extraction/extraction.py | 10 ++++----- elk/run.py | 2 ++ elk/utils/gpu_utils.py | 1 - elk/utils/multi_gpu.py | 7 ++++--- tests/test_smoke_elicit.py | 4 ++-- tests/test_smoke_eval.py | 2 +- tests/test_split_devices.py | 39 ++++++++++++++++++++++-------------- tests/test_truncated_eigh.py | 2 +- 9 files changed, 50 
insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index ac950788..2950dff9 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,17 @@ The following runs `elicit` on the Cartesian product of the listed models and da elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled ``` +## Running big models +For big models that cannot fit on a single gpu, you'll need to use multiple +gpus per model. + +This is an example to run a single 8bit llama-65b model on 2 A40s that have +~50 GB of memory each. + +``` +elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 +``` + ## Caching The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`. diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index d386ab26..57e533e5 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -40,16 +40,16 @@ select_split, select_train_val_splits, ) +from ..utils.multi_gpu import ( + ModelDevices, + instantiate_model_with_devices, + select_devices_multi_gpus, +) from .dataset_name import ( DatasetDictWithName, parse_dataset_string, ) from .generator import _GeneratorBuilder -from ..utils.multi_gpu import ( - instantiate_model_with_devices, - select_devices_multi_gpus, - ModelDevices, -) from .prompt_loading import load_prompts diff --git a/elk/run.py b/elk/run.py index 65573895..057d3a7c 100644 --- a/elk/run.py +++ b/elk/run.py @@ -46,6 +46,7 @@ class Run(ABC, Serializable): num_gpus: int = -1 out_dir: Path | None = None disable_cache: bool = field(default=False, to_dict=False) + gpus_per_model: int = 1 def execute( self, @@ -58,6 +59,7 @@ def execute( disable_cache=self.disable_cache, highlight_color=highlight_color, num_gpus=self.num_gpus, + gpus_per_model=self.gpus_per_model, min_gpu_mem=self.min_gpu_mem, split_type=split_type, ) diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index 9fc1c8df..4938faa9 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -4,7 +4,6 @@ import time import warnings from functools import cache -from typing import TypeVar, Mapping import pynvml import torch diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 50dfbf36..6d857b3c 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -1,8 +1,8 @@ from dataclasses import dataclass -from typing import Type, TYPE_CHECKING +from typing import TYPE_CHECKING, Type import torch -from accelerate import init_empty_weights, infer_auto_device_map +from accelerate import infer_auto_device_map, init_empty_weights from torch import dtype from torch.nn import Module from transformers import PreTrainedModel @@ -139,7 +139,8 @@ def select_devices_multi_gpus( # how many models can we create? 
models_to_create = num_gpus // gpus_per_model print( - f"Will instantiate {models_to_create} models with {gpus_per_model} GPUs each" + f"Allocating devices for {models_to_create} models with {gpus_per_model}" + f" GPUs each" ) devices = select_usable_devices(num_gpus, min_memory=min_memory) configs = split_devices_into_model_devices( diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 7cf0e8c9..aed8e51e 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -7,7 +7,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 dataset_name = "imdb" elicit = Elicit( data=Extract( @@ -38,7 +38,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 dataset_name = "imdb" elicit = Elicit( data=Extract( diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index d58db6cd..683e718a 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -19,7 +19,7 @@ def setup_elicit( tmp_path: Path, dataset_name="imdb", model_path="sshleifer/tiny-gpt2", - min_mem=10 * 1024**2, + min_mem=10 * 1024 ** 2, is_ccs: bool = True, ) -> Elicit: """Setup elicit config for testing, execute elicit, and save output to tmp_path. diff --git a/tests/test_split_devices.py b/tests/test_split_devices.py index da85b051..8168ebd8 100644 --- a/tests/test_split_devices.py +++ b/tests/test_split_devices.py @@ -5,30 +5,39 @@ def test_split_2_devices_1_gpu_per_model(): devices = ["a", "b"] gpus_per_model = 1 models_to_create = 2 - assert split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, - ) == [ModelDevices("a", []), ModelDevices("b", [])] + assert ( + split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) + == [ModelDevices("a", []), ModelDevices("b", [])] + ) def test_split_4_devices_2_gpus_per_model(): devices = ["a", "b", "c", "d"] gpus_per_model = 2 models_to_create = 2 - assert split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, - ) == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])] + assert ( + split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) + == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])] + ) def test_split_7_devices_3_gpus_per_model(): devices = ["a", "b", "c", "d", "e", "f", "g"] gpus_per_model = 3 models_to_create = 2 - assert split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, - ) == [ModelDevices("a", ["b", "c"]), ModelDevices("d", ["e", "f"])] + assert ( + split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) + == [ModelDevices("a", ["b", "c"]), ModelDevices("d", ["e", "f"])] + ) diff --git a/tests/test_truncated_eigh.py b/tests/test_truncated_eigh.py index 5241f1c0..84a3de87 100644 --- a/tests/test_truncated_eigh.py +++ b/tests/test_truncated_eigh.py @@ -11,7 +11,7 @@ def random_symmetric_matrix(n: int, 
k: int) -> torch.Tensor: assert k <= n, "Rank k should be less than or equal to the matrix size n." # Generate random n x k matrix A with elements drawn from a uniform distribution - A = torch.rand(n, k) / k**0.5 + A = torch.rand(n, k) / k ** 0.5 # Create a diagonal matrix D with k eigenvalues evenly distributed around zero eigenvalues = torch.linspace(-1, 1, k) From 8d6279cf1feec1d8ecd5bd867d1ba6d4d4ee943a Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 04:17:38 +0800 Subject: [PATCH 19/42] redirect only later --- elk/extraction/extraction.py | 16 +++++++--------- elk/utils/multi_gpu.py | 33 +++++++++++++++++++-------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 57e533e5..238208a7 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -170,15 +170,13 @@ def extract_hiddens( ds_names = cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." - # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its - # welcome message on every rank - with redirect_stdout(None) if rank != 0 else nullcontext(): - model = instantiate_model_with_devices( - cfg=cfg, device_config=device_config, is_verbose=is_verbose - ) - tokenizer = instantiate_tokenizer( - cfg.model, truncation_side="left", verbose=rank == 0 - ) + + model = instantiate_model_with_devices( + cfg=cfg, device_config=device_config, is_verbose=is_verbose + ) + tokenizer = instantiate_tokenizer( + cfg.model, truncation_side="left", verbose=is_verbose + ) is_enc_dec = model.config.is_encoder_decoder if is_enc_dec and cfg.use_encoder_states: diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 6d857b3c..1875180b 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -1,3 +1,4 @@ +from contextlib import redirect_stdout, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Type @@ -42,13 +43,16 @@ def instantiate_model_with_devices( else: torch_dtype = "auto" if device_config.is_single_gpu: - model = instantiate_model( - cfg.model, - device_map={"": first_device}, - load_in_8bit=cfg.int8, - torch_dtype=torch_dtype, - **kwargs, - ) + # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its + # welcome message on every rank + with redirect_stdout(None) if not is_verbose else nullcontext(): + model = instantiate_model( + cfg.model, + device_map={"": first_device}, + load_in_8bit=cfg.int8, + torch_dtype=torch_dtype, + **kwargs, + ) else: device_map = create_device_map( model_str=cfg.model, @@ -57,13 +61,14 @@ def instantiate_model_with_devices( model_devices=device_config, verbose=is_verbose, ) - model = instantiate_model( - cfg.model, - device_map=device_map, - load_in_8bit=cfg.int8, - torch_dtype=torch_dtype, - **kwargs, - ) + with redirect_stdout(None) if not is_verbose else nullcontext(): + model = instantiate_model( + cfg.model, + device_map=device_map, + load_in_8bit=cfg.int8, + torch_dtype=torch_dtype, + **kwargs, + ) return model From 47529b6f7cd3f6b86a135a95ed264390887fd8f7 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 04:25:54 +0800 Subject: [PATCH 20/42] add logs and remove llama --- elk/extraction/extraction.py | 1 - elk/utils/multi_gpu.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 238208a7..aa65270a 100644 --- a/elk/extraction/extraction.py +++ 
b/elk/extraction/extraction.py @@ -1,7 +1,6 @@ """Functions for extracting the hidden states of a model.""" import logging import os -from contextlib import nullcontext, redirect_stdout from dataclasses import InitVar, dataclass, replace from itertools import islice, zip_longest from typing import Any, Iterable, Literal diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 1875180b..79957145 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -1,4 +1,4 @@ -from contextlib import redirect_stdout, nullcontext +from contextlib import nullcontext, redirect_stdout from dataclasses import dataclass from typing import TYPE_CHECKING, Type @@ -33,8 +33,7 @@ def used_devices(self) -> list[str]: def instantiate_model_with_devices( cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs ) -> PreTrainedModel: - is_llama_65b = isinstance(device_config, ModelDevices) - first_device = device_config.first_device if is_llama_65b else device_config + first_device = device_config.first_device if cfg.int8: # Required by `bitsandbytes` torch_dtype = torch.float16 @@ -54,6 +53,8 @@ def instantiate_model_with_devices( **kwargs, ) else: + if is_verbose: + print(f"Instantiating the model on multiple GPUs: {device_config.used_devices}") device_map = create_device_map( model_str=cfg.model, use_8bit=cfg.int8, From 4a49aa09a523c11edc737dc94d5fe69e9e7758fb Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 04:29:07 +0800 Subject: [PATCH 21/42] fix keyword --- elk/extraction/extraction.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index aa65270a..8b2c367e 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -146,15 +146,15 @@ def explode(self) -> list["Extract"]: def extract_hiddens( cfg: "Extract", *, - device_config: ModelDevices, + devices: ModelDevices, split_type: Literal["train", "val"] = "train", rank: int = 0, world_size: int = 1, ) -> Iterable[dict]: first_device = ( - device_config - if not isinstance(device_config, ModelDevices) - else device_config.first_device + devices + if not isinstance(devices, ModelDevices) + else devices.first_device ) """Run inference on a model with a set of prompts, yielding the hidden states.""" os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -171,7 +171,7 @@ def extract_hiddens( model = instantiate_model_with_devices( - cfg=cfg, device_config=device_config, is_verbose=is_verbose + cfg=cfg, device_config=devices, is_verbose=is_verbose ) tokenizer = instantiate_tokenizer( cfg.model, truncation_side="left", verbose=is_verbose From 91a06b66c7b6a080db2d7cf7cf83e1474f55eacb Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 13:21:54 +0800 Subject: [PATCH 22/42] try out lm head --- README.md | 2 +- elk/utils/multi_gpu.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2950dff9..f78cf710 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ This is an example to run a single 8bit llama-65b model on 2 A40s that have ~50 GB of memory each. 
``` -elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 +elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 --8int ``` ## Caching diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 79957145..37d3b6de 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -54,7 +54,9 @@ def instantiate_model_with_devices( ) else: if is_verbose: - print(f"Instantiating the model on multiple GPUs: {device_config.used_devices}") + print( + f"Instantiating the model on multiple GPUs: {device_config.used_devices}" + ) device_map = create_device_map( model_str=cfg.model, use_8bit=cfg.int8, @@ -80,6 +82,7 @@ def create_device_map( model_devices: ModelDevices, verbose: bool, ) -> dict[str, str]: + # TODO: Run this before allocating workers """Creates a device map for a model running on multiple GPUs.""" with init_empty_weights(): # Need to first instantiate an empty model to get the layer class @@ -122,6 +125,10 @@ def create_device_map( autodevice_map = infer_auto_device_map( model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices ) + # TODO: remove this which we just testing out + # explicitly set the lm head of autodevice_map to the first device + autodevice_map["lm_head"] = model_devices.first_device + if verbose: print(f"Autodevice map: {autodevice_map}") assert "disk" not in autodevice_map.values(), ( From d74c9b6266fadb025e3764bb15bdff1407e4355a Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 13:30:34 +0800 Subject: [PATCH 23/42] shift it to 0.8 instead --- elk/utils/multi_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 37d3b6de..df9d05d0 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -98,7 +98,7 @@ def create_device_map( # Decrease the memory potentially used by the first device # because we're going to create additional tensors on it max_memory_used_devices[model_devices.first_device] = ( - max_memory_used_devices[model_devices.first_device] * 0.9 + max_memory_used_devices[model_devices.first_device] * 0.8 ) # If 8bit, multiply the memory by 2 # This is because we instantiated our empty model in (probably) float16 From e6eb9c1729be797ca22d47fb02c11e90ecf274a2 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 13:38:29 +0800 Subject: [PATCH 24/42] try hardcoded map --- elk/utils/llama_utils.py | 101 +++++++++++++++++++++++++++++++++++++++ elk/utils/multi_gpu.py | 9 +++- 2 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 elk/utils/llama_utils.py diff --git a/elk/utils/llama_utils.py b/elk/utils/llama_utils.py new file mode 100644 index 00000000..38ec190a --- /dev/null +++ b/elk/utils/llama_utils.py @@ -0,0 +1,101 @@ +import torch + + +def get_llama_65b_8bit_device_map( + first_device: str | torch.device, second_device: str | torch.device +) -> dict[str, str | torch.device]: + """ + This assumes that you are using 2 GPUs, with at least 40GB of memory each. 
+ and that you are using 8bit + """ + return { + "model.embed_tokens": first_device, + "model.layers.0": first_device, + "model.layers.1": first_device, + "model.layers.2": first_device, + "model.layers.3": first_device, + "model.layers.4": first_device, + "model.layers.5": first_device, + "model.layers.6": first_device, + "model.layers.7": first_device, + "model.layers.8": first_device, + "model.layers.9": first_device, + "model.layers.10": first_device, + "model.layers.11": first_device, + "model.layers.12": first_device, + "model.layers.13": first_device, + "model.layers.14": first_device, + "model.layers.15": first_device, + "model.layers.16": first_device, + "model.layers.17": first_device, + "model.layers.18": first_device, + "model.layers.19": first_device, + "model.layers.20": first_device, + "model.layers.21": first_device, + "model.layers.22": first_device, + "model.layers.23": first_device, + "model.layers.24": first_device, + "model.layers.25": first_device, + "model.layers.26": first_device, + "model.layers.27.self_attn": first_device, + "model.layers.27.mlp.gate_proj": first_device, + "model.layers.27.mlp.down_proj": first_device, + "model.layers.27.mlp.up_proj": first_device, + "model.layers.27.mlp.act_fn": first_device, + "model.layers.27.input_layernorm": first_device, + "model.layers.27.post_attention_layernorm": first_device, + "model.layers.28": first_device, + "model.layers.29": first_device, + "model.layers.30": first_device, + "model.layers.31": first_device, + "model.layers.32": first_device, + "model.layers.33": first_device, + "model.layers.34": second_device, + "model.layers.35": second_device, + "model.layers.36": second_device, + "model.layers.37": second_device, + "model.layers.38": second_device, + "model.layers.39": second_device, + "model.layers.40": second_device, + "model.layers.41": second_device, + "model.layers.42": second_device, + "model.layers.43": second_device, + "model.layers.44": second_device, + "model.layers.45": second_device, + "model.layers.46": second_device, + "model.layers.47": second_device, + "model.layers.48": second_device, + "model.layers.49": second_device, + "model.layers.50": second_device, + "model.layers.51": second_device, + "model.layers.52": second_device, + "model.layers.53": second_device, + "model.layers.54": second_device, + "model.layers.55": second_device, + "model.layers.56": second_device, + "model.layers.57": second_device, + "model.layers.58": second_device, + "model.layers.59": second_device, + "model.layers.60": second_device, + "model.layers.61": second_device, + "model.layers.62": second_device, + "model.layers.63": second_device, + "model.layers.64": second_device, + "model.layers.65": second_device, + "model.layers.66": second_device, + "model.layers.67": second_device, + "model.layers.68": second_device, + "model.layers.69": second_device, + "model.layers.70": second_device, + "model.layers.71": second_device, + "model.layers.72": second_device, + "model.layers.73": second_device, + "model.layers.74": second_device, + "model.layers.75": second_device, + "model.layers.76": second_device, + "model.layers.77": second_device, + "model.layers.78": second_device, + "model.layers.79": second_device, + "model.norm": second_device, + "lm_head": second_device, + } diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index df9d05d0..216d0e78 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -10,6 +10,7 @@ from elk.utils import instantiate_model, select_usable_devices from elk.utils.gpu_utils 
import get_available_memory_for_devices +from elk.utils.llama_utils import get_llama_65b_8bit_device_map if TYPE_CHECKING: from elk import Extract @@ -122,8 +123,12 @@ def create_device_map( """ maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model) dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else [] - autodevice_map = infer_auto_device_map( - model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices + # autodevice_map = infer_auto_device_map( + # model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices + # ) + autodevice_map = get_llama_65b_8bit_device_map( + first_device=model_devices.first_device, + second_device=model_devices.other_devices[0], ) # TODO: remove this which we just testing out # explicitly set the lm head of autodevice_map to the first device From 06b1a1104408b1a8a2a850290c98b6163e287e30 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 13:42:01 +0800 Subject: [PATCH 25/42] decrease further for gpu 1 --- elk/utils/llama_utils.py | 101 --------------------------------------- elk/utils/multi_gpu.py | 28 ++++------- 2 files changed, 9 insertions(+), 120 deletions(-) delete mode 100644 elk/utils/llama_utils.py diff --git a/elk/utils/llama_utils.py b/elk/utils/llama_utils.py deleted file mode 100644 index 38ec190a..00000000 --- a/elk/utils/llama_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch - - -def get_llama_65b_8bit_device_map( - first_device: str | torch.device, second_device: str | torch.device -) -> dict[str, str | torch.device]: - """ - This assumes that you are using 2 GPUs, with at least 40GB of memory each. - and that you are using 8bit - """ - return { - "model.embed_tokens": first_device, - "model.layers.0": first_device, - "model.layers.1": first_device, - "model.layers.2": first_device, - "model.layers.3": first_device, - "model.layers.4": first_device, - "model.layers.5": first_device, - "model.layers.6": first_device, - "model.layers.7": first_device, - "model.layers.8": first_device, - "model.layers.9": first_device, - "model.layers.10": first_device, - "model.layers.11": first_device, - "model.layers.12": first_device, - "model.layers.13": first_device, - "model.layers.14": first_device, - "model.layers.15": first_device, - "model.layers.16": first_device, - "model.layers.17": first_device, - "model.layers.18": first_device, - "model.layers.19": first_device, - "model.layers.20": first_device, - "model.layers.21": first_device, - "model.layers.22": first_device, - "model.layers.23": first_device, - "model.layers.24": first_device, - "model.layers.25": first_device, - "model.layers.26": first_device, - "model.layers.27.self_attn": first_device, - "model.layers.27.mlp.gate_proj": first_device, - "model.layers.27.mlp.down_proj": first_device, - "model.layers.27.mlp.up_proj": first_device, - "model.layers.27.mlp.act_fn": first_device, - "model.layers.27.input_layernorm": first_device, - "model.layers.27.post_attention_layernorm": first_device, - "model.layers.28": first_device, - "model.layers.29": first_device, - "model.layers.30": first_device, - "model.layers.31": first_device, - "model.layers.32": first_device, - "model.layers.33": first_device, - "model.layers.34": second_device, - "model.layers.35": second_device, - "model.layers.36": second_device, - "model.layers.37": second_device, - "model.layers.38": second_device, - "model.layers.39": second_device, - "model.layers.40": second_device, - "model.layers.41": second_device, - 
"model.layers.42": second_device, - "model.layers.43": second_device, - "model.layers.44": second_device, - "model.layers.45": second_device, - "model.layers.46": second_device, - "model.layers.47": second_device, - "model.layers.48": second_device, - "model.layers.49": second_device, - "model.layers.50": second_device, - "model.layers.51": second_device, - "model.layers.52": second_device, - "model.layers.53": second_device, - "model.layers.54": second_device, - "model.layers.55": second_device, - "model.layers.56": second_device, - "model.layers.57": second_device, - "model.layers.58": second_device, - "model.layers.59": second_device, - "model.layers.60": second_device, - "model.layers.61": second_device, - "model.layers.62": second_device, - "model.layers.63": second_device, - "model.layers.64": second_device, - "model.layers.65": second_device, - "model.layers.66": second_device, - "model.layers.67": second_device, - "model.layers.68": second_device, - "model.layers.69": second_device, - "model.layers.70": second_device, - "model.layers.71": second_device, - "model.layers.72": second_device, - "model.layers.73": second_device, - "model.layers.74": second_device, - "model.layers.75": second_device, - "model.layers.76": second_device, - "model.layers.77": second_device, - "model.layers.78": second_device, - "model.layers.79": second_device, - "model.norm": second_device, - "lm_head": second_device, - } diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 216d0e78..2ded1256 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -83,7 +83,6 @@ def create_device_map( model_devices: ModelDevices, verbose: bool, ) -> dict[str, str]: - # TODO: Run this before allocating workers """Creates a device map for a model running on multiple GPUs.""" with init_empty_weights(): # Need to first instantiate an empty model to get the layer class @@ -99,7 +98,7 @@ def create_device_map( # Decrease the memory potentially used by the first device # because we're going to create additional tensors on it max_memory_used_devices[model_devices.first_device] = ( - max_memory_used_devices[model_devices.first_device] * 0.8 + max_memory_used_devices[model_devices.first_device] * 0.6 ) # If 8bit, multiply the memory by 2 # This is because we instantiated our empty model in (probably) float16 @@ -113,26 +112,17 @@ def create_device_map( else max_memory_used_devices ) - """ - Make sure that the transformer layer is not split - because that contains residual connections - See https://huggingface.co/docs/accelerate/usage_guides/big_modeling - Otherwise we get an error like this: - RuntimeError: Expected all tensors to be on the same device, - but found at least two devices, cuda:0 and cuda1 - """ + # Make sure that the transformer layer is not split + # because that contains residual connections + # See https://huggingface.co/docs/accelerate/usage_guides/big_modeling + # Otherwise we get an error like this: + # RuntimeError: Expected all tensors to be on the same device, + # but found at least two devices, cuda:0 and cuda1 maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model) dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else [] - # autodevice_map = infer_auto_device_map( - # model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices - # ) - autodevice_map = get_llama_65b_8bit_device_map( - first_device=model_devices.first_device, - second_device=model_devices.other_devices[0], + autodevice_map = infer_auto_device_map( + 
model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices ) - # TODO: remove this which we just testing out - # explicitly set the lm head of autodevice_map to the first device - autodevice_map["lm_head"] = model_devices.first_device if verbose: print(f"Autodevice map: {autodevice_map}") From 64919b31dc5d63f9444a9b0774d844de6f51bd8b Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 13:43:49 +0800 Subject: [PATCH 26/42] fix import --- elk/utils/multi_gpu.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 2ded1256..dd33e152 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -10,8 +10,7 @@ from elk.utils import instantiate_model, select_usable_devices from elk.utils.gpu_utils import get_available_memory_for_devices -from elk.utils.llama_utils import get_llama_65b_8bit_device_map - += if TYPE_CHECKING: from elk import Extract From fe331bc17df2342268ada8bb000b7e9e617369c3 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 13:46:36 +0800 Subject: [PATCH 27/42] remove syntax --- elk/utils/multi_gpu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index dd33e152..77ea9bf3 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -10,7 +10,7 @@ from elk.utils import instantiate_model, select_usable_devices from elk.utils.gpu_utils import get_available_memory_for_devices -= + if TYPE_CHECKING: from elk import Extract @@ -55,7 +55,8 @@ def instantiate_model_with_devices( else: if is_verbose: print( - f"Instantiating the model on multiple GPUs: {device_config.used_devices}" + f"Instantiating the model on multiple GPUs" + f": {device_config.used_devices}" ) device_map = create_device_map( model_str=cfg.model, From 0ed7f313dfc9a5cb29fe44aa40a0dbbfc15e420f Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 14:00:36 +0800 Subject: [PATCH 28/42] try comparing to hardcoding --- elk/utils/llama.py | 104 +++++++++++++++++++++++++++++++++++++++++ elk/utils/multi_gpu.py | 9 +++- 2 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 elk/utils/llama.py diff --git a/elk/utils/llama.py b/elk/utils/llama.py new file mode 100644 index 00000000..c5cf3ba2 --- /dev/null +++ b/elk/utils/llama.py @@ -0,0 +1,104 @@ +import torch +from accelerate import infer_auto_device_map, init_empty_weights + +from elk.utils import instantiate_model + + +def get_llama_65b_8bit_device_map( + first_device: str | torch.device, second_device: str | torch.device +) -> dict[str, str | torch.device]: + """ + This assumes that you are using 2 GPUs, with at least 40GB of memory each. 
+ and that you are using 8bit + """ + return { + "model.embed_tokens": first_device, + "model.layers.0": first_device, + "model.layers.1": first_device, + "model.layers.2": first_device, + "model.layers.3": first_device, + "model.layers.4": first_device, + "model.layers.5": first_device, + "model.layers.6": first_device, + "model.layers.7": first_device, + "model.layers.8": first_device, + "model.layers.9": first_device, + "model.layers.10": first_device, + "model.layers.11": first_device, + "model.layers.12": first_device, + "model.layers.13": first_device, + "model.layers.14": first_device, + "model.layers.15": first_device, + "model.layers.16": first_device, + "model.layers.17": first_device, + "model.layers.18": first_device, + "model.layers.19": first_device, + "model.layers.20": first_device, + "model.layers.21": first_device, + "model.layers.22": first_device, + "model.layers.23": first_device, + "model.layers.24": first_device, + "model.layers.25": first_device, + "model.layers.26": first_device, + "model.layers.27.self_attn": first_device, + "model.layers.27.mlp.gate_proj": first_device, + "model.layers.27.mlp.down_proj": first_device, + "model.layers.27.mlp.up_proj": first_device, + "model.layers.27.mlp.act_fn": first_device, + "model.layers.27.input_layernorm": first_device, + "model.layers.27.post_attention_layernorm": first_device, + "model.layers.28": first_device, + "model.layers.29": first_device, + "model.layers.30": first_device, + "model.layers.31": first_device, + "model.layers.32": first_device, + "model.layers.33": first_device, + "model.layers.34": second_device, + "model.layers.35": second_device, + "model.layers.36": second_device, + "model.layers.37": second_device, + "model.layers.38": second_device, + "model.layers.39": second_device, + "model.layers.40": second_device, + "model.layers.41": second_device, + "model.layers.42": second_device, + "model.layers.43": second_device, + "model.layers.44": second_device, + "model.layers.45": second_device, + "model.layers.46": second_device, + "model.layers.47": second_device, + "model.layers.48": second_device, + "model.layers.49": second_device, + "model.layers.50": second_device, + "model.layers.51": second_device, + "model.layers.52": second_device, + "model.layers.53": second_device, + "model.layers.54": second_device, + "model.layers.55": second_device, + "model.layers.56": second_device, + "model.layers.57": second_device, + "model.layers.58": second_device, + "model.layers.59": second_device, + "model.layers.60": second_device, + "model.layers.61": second_device, + "model.layers.62": second_device, + "model.layers.63": second_device, + "model.layers.64": second_device, + "model.layers.65": second_device, + "model.layers.66": second_device, + "model.layers.67": second_device, + "model.layers.68": second_device, + "model.layers.69": second_device, + "model.layers.70": second_device, + "model.layers.71": second_device, + "model.layers.72": second_device, + "model.layers.73": second_device, + "model.layers.74": second_device, + "model.layers.75": second_device, + "model.layers.76": second_device, + "model.layers.77": second_device, + "model.layers.78": second_device, + "model.layers.79": second_device, + "model.norm": second_device, + "lm_head": first_device, + } diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 77ea9bf3..e935bf41 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -10,6 +10,7 @@ from elk.utils import instantiate_model, select_usable_devices from elk.utils.gpu_utils 
import get_available_memory_for_devices +from elk.utils.llama import get_llama_65b_8bit_device_map if TYPE_CHECKING: from elk import Extract @@ -120,8 +121,12 @@ def create_device_map( # but found at least two devices, cuda:0 and cuda1 maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model) dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else [] - autodevice_map = infer_auto_device_map( - model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices + # autodevice_map = infer_auto_device_map( + # model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices + # ) + autodevice_map = get_llama_65b_8bit_device_map( + first_device=model_devices.first_device, + second_device=model_devices.other_devices[0], ) if verbose: From 69bbf644f37f3fee1aeac1d1225bcbae0fdf735d Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 15:42:14 +0800 Subject: [PATCH 29/42] Revert "try comparing to hardcoding" This reverts commit 0ed7f313dfc9a5cb29fe44aa40a0dbbfc15e420f. --- elk/utils/llama.py | 104 ----------------------------------------- elk/utils/multi_gpu.py | 9 +--- 2 files changed, 2 insertions(+), 111 deletions(-) delete mode 100644 elk/utils/llama.py diff --git a/elk/utils/llama.py b/elk/utils/llama.py deleted file mode 100644 index c5cf3ba2..00000000 --- a/elk/utils/llama.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -from accelerate import infer_auto_device_map, init_empty_weights - -from elk.utils import instantiate_model - - -def get_llama_65b_8bit_device_map( - first_device: str | torch.device, second_device: str | torch.device -) -> dict[str, str | torch.device]: - """ - This assumes that you are using 2 GPUs, with at least 40GB of memory each. - and that you are using 8bit - """ - return { - "model.embed_tokens": first_device, - "model.layers.0": first_device, - "model.layers.1": first_device, - "model.layers.2": first_device, - "model.layers.3": first_device, - "model.layers.4": first_device, - "model.layers.5": first_device, - "model.layers.6": first_device, - "model.layers.7": first_device, - "model.layers.8": first_device, - "model.layers.9": first_device, - "model.layers.10": first_device, - "model.layers.11": first_device, - "model.layers.12": first_device, - "model.layers.13": first_device, - "model.layers.14": first_device, - "model.layers.15": first_device, - "model.layers.16": first_device, - "model.layers.17": first_device, - "model.layers.18": first_device, - "model.layers.19": first_device, - "model.layers.20": first_device, - "model.layers.21": first_device, - "model.layers.22": first_device, - "model.layers.23": first_device, - "model.layers.24": first_device, - "model.layers.25": first_device, - "model.layers.26": first_device, - "model.layers.27.self_attn": first_device, - "model.layers.27.mlp.gate_proj": first_device, - "model.layers.27.mlp.down_proj": first_device, - "model.layers.27.mlp.up_proj": first_device, - "model.layers.27.mlp.act_fn": first_device, - "model.layers.27.input_layernorm": first_device, - "model.layers.27.post_attention_layernorm": first_device, - "model.layers.28": first_device, - "model.layers.29": first_device, - "model.layers.30": first_device, - "model.layers.31": first_device, - "model.layers.32": first_device, - "model.layers.33": first_device, - "model.layers.34": second_device, - "model.layers.35": second_device, - "model.layers.36": second_device, - "model.layers.37": second_device, - "model.layers.38": second_device, - "model.layers.39": second_device, - 
"model.layers.40": second_device, - "model.layers.41": second_device, - "model.layers.42": second_device, - "model.layers.43": second_device, - "model.layers.44": second_device, - "model.layers.45": second_device, - "model.layers.46": second_device, - "model.layers.47": second_device, - "model.layers.48": second_device, - "model.layers.49": second_device, - "model.layers.50": second_device, - "model.layers.51": second_device, - "model.layers.52": second_device, - "model.layers.53": second_device, - "model.layers.54": second_device, - "model.layers.55": second_device, - "model.layers.56": second_device, - "model.layers.57": second_device, - "model.layers.58": second_device, - "model.layers.59": second_device, - "model.layers.60": second_device, - "model.layers.61": second_device, - "model.layers.62": second_device, - "model.layers.63": second_device, - "model.layers.64": second_device, - "model.layers.65": second_device, - "model.layers.66": second_device, - "model.layers.67": second_device, - "model.layers.68": second_device, - "model.layers.69": second_device, - "model.layers.70": second_device, - "model.layers.71": second_device, - "model.layers.72": second_device, - "model.layers.73": second_device, - "model.layers.74": second_device, - "model.layers.75": second_device, - "model.layers.76": second_device, - "model.layers.77": second_device, - "model.layers.78": second_device, - "model.layers.79": second_device, - "model.norm": second_device, - "lm_head": first_device, - } diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index e935bf41..77ea9bf3 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -10,7 +10,6 @@ from elk.utils import instantiate_model, select_usable_devices from elk.utils.gpu_utils import get_available_memory_for_devices -from elk.utils.llama import get_llama_65b_8bit_device_map if TYPE_CHECKING: from elk import Extract @@ -121,12 +120,8 @@ def create_device_map( # but found at least two devices, cuda:0 and cuda1 maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model) dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else [] - # autodevice_map = infer_auto_device_map( - # model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices - # ) - autodevice_map = get_llama_65b_8bit_device_map( - first_device=model_devices.first_device, - second_device=model_devices.other_devices[0], + autodevice_map = infer_auto_device_map( + model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices ) if verbose: From 8c6386c364c159a3df74f982f3cff3a244815a31 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 3 May 2023 23:27:02 +0800 Subject: [PATCH 30/42] add comment on future improvement --- elk/utils/multi_gpu.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index ab929484..f7cf11a3 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -42,6 +42,10 @@ def instantiate_model_with_devices( else: torch_dtype = "auto" + # TODO: Maybe we should ensure the device map is the same + # for all the extract processes? This is because the device map + # can affect performance highly and its annoying if one process + # is using a different device map than the others. 
device_map = ( {"": first_device} if device_config.is_single_gpu From 0182a649044002e84e0756481066b882a09a2377 Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 00:39:01 +0800 Subject: [PATCH 31/42] print --- README.md | 2 +- elk/utils/multi_gpu.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f78cf710..1b4b28c8 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ This is an example to run a single 8bit llama-65b model on 2 A40s that have ~50 GB of memory each. ``` -elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 --8int +elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 --int8 ``` ## Caching diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index f7cf11a3..127c7cbc 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -97,10 +97,12 @@ def create_device_map( max_memory_used_devices[model_devices.first_device] = ( max_memory_used_devices[model_devices.first_device] * 0.6 ) + if use_8bit: + print("Using 8bit") # If 8bit, multiply the memory by 2 # This is because we instantiated our empty model in (probably) float16 # We aren't able to instantiate an empty model in 8bit currently - max_memory_used_devices = ( + devices_accounted_8bit = ( { device: max_memory_used_devices[device] * 2 for device in max_memory_used_devices @@ -118,7 +120,7 @@ def create_device_map( maybe_transformer_class: Type[Module] | None = get_transformer_layer_cls(model) dont_split = [maybe_transformer_class.__name__] if maybe_transformer_class else [] autodevice_map = infer_auto_device_map( - model, no_split_module_classes=dont_split, max_memory=max_memory_used_devices + model, no_split_module_classes=dont_split, max_memory=devices_accounted_8bit ) if verbose: From df1c0ff23c436c784d6a54cd91ea00a8b7cbba98 Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 01:33:40 +0800 Subject: [PATCH 32/42] load in 8bit correctly --- elk/utils/hf_utils.py | 44 ++++++++++++++++++++++++++++-------------- elk/utils/multi_gpu.py | 28 +++++++++++++-------------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 9f429921..522fbdd1 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -20,15 +20,11 @@ _AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES -def instantiate_model( +def determine_dtypes( model_str: str, - device: str | torch.device = "cpu", - **kwargs, -) -> PreTrainedModel: - """Instantiate a model string with the appropriate `Auto` class.""" - device = torch.device(device) - kwargs["device_map"] = {"": device} - + is_cpu: bool, + load_in_8bit: bool, +) -> torch.dtype | str: with prevent_name_conflicts(): model_cfg = AutoConfig.from_pretrained(model_str) @@ -37,27 +33,47 @@ def instantiate_model( fp32_weights = model_cfg.torch_dtype in (None, torch.float32) # Required by `bitsandbytes` to load in 8-bit. - if kwargs.get("load_in_8bit"): + if load_in_8bit: # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and # we can't guarantee that there won't be overflow if we downcast to fp16. if fp32_weights: raise ValueError("Cannot load in 8-bit if weights are fp32") - kwargs["torch_dtype"] = torch.float16 + torch_dtype = torch.float16 # CPUs generally don't support anything other than fp32. 
- elif device.type == "cpu": - kwargs["torch_dtype"] = torch.float32 + elif is_cpu: + torch_dtype = torch.float32 # If the model is fp32 but bf16 is available, convert to bf16. # Usually models with fp32 weights were actually trained in bf16, and # converting them doesn't hurt performance. elif fp32_weights and torch.cuda.is_bf16_supported(): - kwargs["torch_dtype"] = torch.bfloat16 + torch_dtype = torch.bfloat16 print("Weights seem to be fp32, but bf16 is available. Loading in bf16.") else: - kwargs["torch_dtype"] = "auto" + torch_dtype = "auto" + return torch_dtype + + +def instantiate_model( + model_str: str, + load_in_8bit: bool, + is_cpu: bool, + **kwargs, +) -> PreTrainedModel: + """Instantiate a model string with the appropriate `Auto` class.""" + + with prevent_name_conflicts(): + model_cfg = AutoConfig.from_pretrained(model_str) + # If a torch_dtype was not specified, try to infer it. + if "torch_dtype" not in kwargs: + kwargs["torch_dtype"] = determine_dtypes( + model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit + ) + # Add load_in_8bit to kwargs + kwargs["load_in_8bit"] = load_in_8bit archs = model_cfg.architectures if not isinstance(archs, list): diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 127c7cbc..a8395bd8 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -29,18 +29,16 @@ def is_single_gpu(self) -> bool: def used_devices(self) -> list[str]: return [self.first_device] + self.other_devices + @property + def has_cpu_device(self) -> bool: + devices = [torch.device(device) for device in self.used_devices] + return any(device.type == "cpu" for device in devices) + def instantiate_model_with_devices( cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs ) -> PreTrainedModel: first_device = device_config.first_device - if cfg.int8: - # Required by `bitsandbytes` - torch_dtype = torch.float16 - elif device_config == "cpu": - torch_dtype = torch.float32 - else: - torch_dtype = "auto" # TODO: Maybe we should ensure the device map is the same # for all the extract processes? This is because the device map @@ -51,8 +49,7 @@ def instantiate_model_with_devices( if device_config.is_single_gpu else create_device_map( model_str=cfg.model, - use_8bit=cfg.int8, - torch_dtype=torch_dtype, + load_in_8bit=cfg.int8, model_devices=device_config, verbose=is_verbose, ) @@ -67,7 +64,7 @@ def instantiate_model_with_devices( cfg.model, device_map=device_map, load_in_8bit=cfg.int8, - torch_dtype=torch_dtype, + is_cpu=device_config.has_cpu_device, **kwargs, ) return model @@ -75,15 +72,16 @@ def instantiate_model_with_devices( def create_device_map( model_str: str, - use_8bit: float, - torch_dtype: dtype | str, + load_in_8bit: bool, model_devices: ModelDevices, verbose: bool, ) -> dict[str, str]: """Creates a device map for a model running on multiple GPUs.""" with init_empty_weights(): # Need to first instantiate an empty model to get the layer class - model = instantiate_model(model_str=model_str, torch_dtype=torch_dtype) + model = instantiate_model( + model_str=model_str, load_in_8bit=load_in_8bit, is_cpu=False + ) # e.g. 
{"cuda:0": 16000, "cuda:1": 16000} max_memory_all_devices: dict[str, int] = get_available_memory_for_devices() @@ -97,7 +95,7 @@ def create_device_map( max_memory_used_devices[model_devices.first_device] = ( max_memory_used_devices[model_devices.first_device] * 0.6 ) - if use_8bit: + if load_in_8bit: print("Using 8bit") # If 8bit, multiply the memory by 2 # This is because we instantiated our empty model in (probably) float16 @@ -107,7 +105,7 @@ def create_device_map( device: max_memory_used_devices[device] * 2 for device in max_memory_used_devices } - if use_8bit + if load_in_8bit else max_memory_used_devices ) From 6d9e9ea2956fc6e43add6bf612f9e34d29adb12e Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 01:36:38 +0800 Subject: [PATCH 33/42] add comment --- elk/utils/multi_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index a8395bd8..b4dc0736 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -79,9 +79,9 @@ def create_device_map( """Creates a device map for a model running on multiple GPUs.""" with init_empty_weights(): # Need to first instantiate an empty model to get the layer class - model = instantiate_model( - model_str=model_str, load_in_8bit=load_in_8bit, is_cpu=False - ) + # We need to specify load_in_8bit False because its incompatible with + # init_empty_weights + model = instantiate_model(model_str=model_str, load_in_8bit=False, is_cpu=False) # e.g. {"cuda:0": 16000, "cuda:1": 16000} max_memory_all_devices: dict[str, int] = get_available_memory_for_devices() From 02602cb507225f072639d9ad30d958fc2ee93943 Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 01:47:47 +0800 Subject: [PATCH 34/42] try passing float16? --- elk/extraction/extraction.py | 6 +----- elk/utils/hf_utils.py | 10 ++++++---- elk/utils/multi_gpu.py | 8 ++++++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 971ad305..ad040ed6 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -157,9 +157,7 @@ def extract_hiddens( world_size: int = 1, ) -> Iterable[dict]: first_device = ( - devices - if not isinstance(devices, ModelDevices) - else devices.first_device + devices if not isinstance(devices, ModelDevices) else devices.first_device ) """Run inference on a model with a set of prompts, yielding the hidden states.""" os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -174,7 +172,6 @@ def extract_hiddens( ds_names = cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
- model = instantiate_model_with_devices( cfg=cfg, device_config=devices, is_verbose=is_verbose ) @@ -182,7 +179,6 @@ def extract_hiddens( cfg.model, truncation_side="left", verbose=is_verbose ) - is_enc_dec = model.config.is_encoder_decoder if is_enc_dec and cfg.use_encoder_states: assert hasattr(model, "get_encoder") and callable(model.get_encoder) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 522fbdd1..a1fc8bf5 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import transformers from transformers import ( @@ -61,6 +63,7 @@ def instantiate_model( model_str: str, load_in_8bit: bool, is_cpu: bool, + torch_dtype: Optional[torch.dtype] = None, **kwargs, ) -> PreTrainedModel: """Instantiate a model string with the appropriate `Auto` class.""" @@ -68,10 +71,9 @@ def instantiate_model( with prevent_name_conflicts(): model_cfg = AutoConfig.from_pretrained(model_str) # If a torch_dtype was not specified, try to infer it. - if "torch_dtype" not in kwargs: - kwargs["torch_dtype"] = determine_dtypes( - model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit - ) + kwargs["torch_dtype"] = torch_dtype or determine_dtypes( + model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit + ) # Add load_in_8bit to kwargs kwargs["load_in_8bit"] = load_in_8bit diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index b4dc0736..4fe4c0de 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -4,7 +4,6 @@ import torch from accelerate import infer_auto_device_map, init_empty_weights -from torch import dtype from torch.nn import Module from transformers import PreTrainedModel @@ -81,7 +80,12 @@ def create_device_map( # Need to first instantiate an empty model to get the layer class # We need to specify load_in_8bit=False because it's incompatible with # init_empty_weights - model = instantiate_model(model_str=model_str, load_in_8bit=False, is_cpu=False) + model = instantiate_model( + model_str=model_str, + load_in_8bit=False, + is_cpu=False, + torch_dtype=torch.float16, + ) # e.g. {"cuda:0": 16000, "cuda:1": 16000} max_memory_all_devices: dict[str, int] = get_available_memory_for_devices() From d3a8f2934cf074aa589d85eed663783d3cfc8a7e Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 01:51:19 +0800 Subject: [PATCH 35/42] prevent mem issues?
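The loop removed below pre-allocated a tiny tensor on every visible GPU before the free-memory query. Presumably the problem is that even a one-element allocation forces CUDA context creation on each device, and the context itself reserves memory (typically a few hundred MiB), which can skew the readings or tip an already busy GPU into OOM. A minimal sketch of the effect, using PyNVML so that the measurement itself does not create a context (illustrative only, not part of this patch):

```
import pynvml
import torch

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
free_before = pynvml.nvmlDeviceGetMemoryInfo(handle).free

_ = torch.tensor([0], device="cuda:0")  # forces CUDA context creation on cuda:0

free_after = pynvml.nvmlDeviceGetMemoryInfo(handle).free
print(f"context init consumed ~{(free_before - free_after) / 2**20:.0f} MiB")
```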
--- elk/utils/gpu_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index 4938faa9..c5b64e30 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -168,8 +168,6 @@ def select_usable_devices( def get_available_memory_for_devices() -> dict[str, int]: # Edited from get_max_memory of the accelerate library - for i in range(torch.cuda.device_count()): - _ = torch.tensor([0], device=i) max_memory = { f"cuda:{i}": torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count()) From bf827ea03edc58998eecbe37d06fe49456da1460 Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 01:57:47 +0800 Subject: [PATCH 36/42] add logs --- elk/utils/hf_utils.py | 51 +++++++++++++++++++++--------------------- elk/utils/multi_gpu.py | 1 - 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index a1fc8bf5..4c03c942 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -27,36 +27,35 @@ def determine_dtypes( is_cpu: bool, load_in_8bit: bool, ) -> torch.dtype | str: - with prevent_name_conflicts(): - model_cfg = AutoConfig.from_pretrained(model_str) + model_cfg = AutoConfig.from_pretrained(model_str) - # When the torch_dtype is None, this generally means the model is fp32, because - # the config was probably created before the `torch_dtype` field was added. - fp32_weights = model_cfg.torch_dtype in (None, torch.float32) + # When the torch_dtype is None, this generally means the model is fp32, because + # the config was probably created before the `torch_dtype` field was added. + fp32_weights = model_cfg.torch_dtype in (None, torch.float32) - # Required by `bitsandbytes` to load in 8-bit. - if load_in_8bit: - # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint - # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and - # we can't guarantee that there won't be overflow if we downcast to fp16. - if fp32_weights: - raise ValueError("Cannot load in 8-bit if weights are fp32") + # Required by `bitsandbytes` to load in 8-bit. + if load_in_8bit: + # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint + # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and + # we can't guarantee that there won't be overflow if we downcast to fp16. + if fp32_weights: + raise ValueError("Cannot load in 8-bit if weights are fp32") - torch_dtype = torch.float16 + torch_dtype = torch.float16 - # CPUs generally don't support anything other than fp32. - elif is_cpu: - torch_dtype = torch.float32 + # CPUs generally don't support anything other than fp32. + elif is_cpu: + torch_dtype = torch.float32 - # If the model is fp32 but bf16 is available, convert to bf16. - # Usually models with fp32 weights were actually trained in bf16, and - # converting them doesn't hurt performance. - elif fp32_weights and torch.cuda.is_bf16_supported(): - torch_dtype = torch.bfloat16 - print("Weights seem to be fp32, but bf16 is available. Loading in bf16.") - else: - torch_dtype = "auto" - return torch_dtype + # If the model is fp32 but bf16 is available, convert to bf16. + # Usually models with fp32 weights were actually trained in bf16, and + # converting them doesn't hurt performance. + elif fp32_weights and torch.cuda.is_bf16_supported(): + torch_dtype = torch.bfloat16 + print("Weights seem to be fp32, but bf16 is available. 
Loading in bf16.") + else: + torch_dtype = "auto" + return torch_dtype def instantiate_model( @@ -88,7 +87,7 @@ def instantiate_model( if arch_str.endswith(suffix): model_cls = getattr(transformers, arch_str) return model_cls.from_pretrained(model_str, **kwargs) - + print(f"Loading model with {kwargs}") return AutoModel.from_pretrained(model_str, **kwargs) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 4fe4c0de..477cc395 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -84,7 +84,6 @@ def create_device_map( model_str=model_str, load_in_8bit=False, is_cpu=False, - torch_dtype=torch.float16, ) # e.g. {"cuda:0": 16000, "cuda:1": 16000} From 301e6e2dcdf146b6cd2e02ea75055120e63ba0fe Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 02:03:12 +0800 Subject: [PATCH 37/42] try only adding load_in_8bit if we really need to --- elk/utils/hf_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 4c03c942..8200c91b 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -74,7 +74,8 @@ def instantiate_model( model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit ) # Add load_in_8bit to kwargs - kwargs["load_in_8bit"] = load_in_8bit + if load_in_8bit: + kwargs["load_in_8bit"] = load_in_8bit archs = model_cfg.architectures if not isinstance(archs, list): From 6b6bb6fd4cb58b6c9ccb3fa8460de755bb2fc8aa Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 02:09:49 +0800 Subject: [PATCH 38/42] catch max mem --- elk/utils/gpu_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index c5b64e30..5a9b1ac9 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -167,9 +167,13 @@ def select_usable_devices( def get_available_memory_for_devices() -> dict[str, int]: - # Edited from get_max_memory of the accelerate library - max_memory = { - f"cuda:{i}": torch.cuda.mem_get_info(i)[0] - for i in range(torch.cuda.device_count()) - } + # Edited from get_max_memory of the accelerate library to + # catch out of memory errors + max_memory = {} + for i in range(torch.cuda.device_count()): + try: + max_memory[f"cuda:{i}"]: torch.cuda.mem_get_info(i)[0] + except RuntimeError: + max_memory[f"cuda:{i}"]: 0 + return max_memory From a5b3d5fca34c4fb7770f50d294148296ad5a6b59 Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 02:09:56 +0800 Subject: [PATCH 39/42] Revert "try only adding load_in_8bit if we really need to" This reverts commit 301e6e2dcdf146b6cd2e02ea75055120e63ba0fe. 
--- elk/utils/hf_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 8200c91b..4c03c942 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -74,8 +74,7 @@ def instantiate_model( model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit ) # Add load_in_8bit to kwargs - if load_in_8bit: - kwargs["load_in_8bit"] = load_in_8bit + kwargs["load_in_8bit"] = load_in_8bit archs = model_cfg.architectures if not isinstance(archs, list): From 99db2a01ba84937b3ed85983580fa7ee98b13f3f Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 02:19:24 +0800 Subject: [PATCH 40/42] try out means of memory --- elk/utils/gpu_utils.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index 5a9b1ac9..9783c41e 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -167,13 +167,15 @@ def select_usable_devices( def get_available_memory_for_devices() -> dict[str, int]: - # Edited from get_max_memory of the accelerate library to - # catch out of memory errors - max_memory = {} - for i in range(torch.cuda.device_count()): - try: - max_memory[f"cuda:{i}"] = torch.cuda.mem_get_info(i)[0] - except RuntimeError: - max_memory[f"cuda:{i}"] = 0 - - return max_memory + # PyNVML and PyTorch device indices should agree when CUDA_VISIBLE_DEVICES is + # not set. We need them to agree so that the PyNVML indices match the PyTorch + # indices, and we don't have to do any complex error-prone conversions. + num_visible = torch.cuda.device_count() + num_installed = pynvml.nvmlDeviceGetCount() + assert num_installed == num_visible, "PyNVML and PyTorch disagree on GPU count" + output = {} + # Get free memory for each GPU + for i in range(num_installed): + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + output[f"cuda:{i}"] = int(pynvml.nvmlDeviceGetMemoryInfo(handle).free) + return output From 55b18ab9037080621473d5c0cedd1082a688d6cd Mon Sep 17 00:00:00 2001 From: James Chua Date: Thu, 4 May 2023 02:23:19 +0800 Subject: [PATCH 41/42] remove debug print --- elk/utils/hf_utils.py | 1 - elk/utils/multi_gpu.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 4c03c942..26319cea 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -87,7 +87,6 @@ def instantiate_model( if arch_str.endswith(suffix): model_cls = getattr(transformers, arch_str) return model_cls.from_pretrained(model_str, **kwargs) - print(f"Loading model with {kwargs}") return AutoModel.from_pretrained(model_str, **kwargs) diff --git a/elk/utils/multi_gpu.py b/elk/utils/multi_gpu.py index 477cc395..f6b945cc 100644 --- a/elk/utils/multi_gpu.py +++ b/elk/utils/multi_gpu.py @@ -98,8 +98,6 @@ def create_device_map( max_memory_used_devices[model_devices.first_device] = ( max_memory_used_devices[model_devices.first_device] * 0.6 ) - if load_in_8bit: - print("Using 8bit") # If 8bit, multiply the memory by 2 # This is because we instantiated our empty model in (probably) float16 # We aren't able to instantiate an empty model in 8bit currently From fa52400053cd37ac96c150a96ffc3947b476fbd0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 May 2023 18:25:43 +0000 Subject: [PATCH 42/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 2 +- tests/test_smoke_elicit.py | 4 ++--
tests/test_smoke_eval.py | 2 +- tests/test_split_devices.py | 39 ++++++++++++++---------------------- tests/test_truncated_eigh.py | 2 +- 5 files changed, 20 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 1b4b28c8..772a9163 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ gpus per model. This is an example to run a single 8bit llama-65b model on 2 A40s that have ~50 GB of memory each. - + ``` elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 --int8 ``` diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index aed8e51e..7cf0e8c9 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -7,7 +7,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 dataset_name = "imdb" elicit = Elicit( data=Extract( @@ -38,7 +38,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024 ** 2 + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 dataset_name = "imdb" elicit = Elicit( data=Extract( diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index 683e718a..d58db6cd 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -19,7 +19,7 @@ def setup_elicit( tmp_path: Path, dataset_name="imdb", model_path="sshleifer/tiny-gpt2", - min_mem=10 * 1024 ** 2, + min_mem=10 * 1024**2, is_ccs: bool = True, ) -> Elicit: """Setup elicit config for testing, execute elicit, and save output to tmp_path. diff --git a/tests/test_split_devices.py b/tests/test_split_devices.py index 8168ebd8..da85b051 100644 --- a/tests/test_split_devices.py +++ b/tests/test_split_devices.py @@ -5,39 +5,30 @@ def test_split_2_devices_1_gpu_per_model(): devices = ["a", "b"] gpus_per_model = 1 models_to_create = 2 - assert ( - split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, - ) - == [ModelDevices("a", []), ModelDevices("b", [])] - ) + assert split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) == [ModelDevices("a", []), ModelDevices("b", [])] def test_split_4_devices_2_gpus_per_model(): devices = ["a", "b", "c", "d"] gpus_per_model = 2 models_to_create = 2 - assert ( - split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, - ) - == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])] - ) + assert split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) == [ModelDevices("a", ["b"]), ModelDevices("c", ["d"])] def test_split_7_devices_3_gpus_per_model(): devices = ["a", "b", "c", "d", "e", "f", "g"] gpus_per_model = 3 models_to_create = 2 - assert ( - split_devices_into_model_devices( - devices=devices, - gpus_per_model=gpus_per_model, - models_to_create=models_to_create, - ) - == [ModelDevices("a", ["b", "c"]), ModelDevices("d", ["e", "f"])] - ) + assert split_devices_into_model_devices( + devices=devices, + gpus_per_model=gpus_per_model, + models_to_create=models_to_create, + ) == [ModelDevices("a", ["b", "c"]), ModelDevices("d", ["e", "f"])] diff --git a/tests/test_truncated_eigh.py 
b/tests/test_truncated_eigh.py index 84a3de87..5241f1c0 100644 --- a/tests/test_truncated_eigh.py +++ b/tests/test_truncated_eigh.py @@ -11,7 +11,7 @@ def random_symmetric_matrix(n: int, k: int) -> torch.Tensor: assert k <= n, "Rank k should be less than or equal to the matrix size n." # Generate random n x k matrix A with elements drawn from a uniform distribution - A = torch.rand(n, k) / k ** 0.5 + A = torch.rand(n, k) / k**0.5 # Create a diagonal matrix D with k eigenvalues evenly distributed around zero eigenvalues = torch.linspace(-1, 1, k)