From 8b0ddc8a94b39314197ea6db7579365b0f1388ed Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Tue, 30 May 2023 15:42:22 -0700
Subject: [PATCH 1/4] Add FALCON auto-tp support

---
 deepspeed/module_inject/auto_tp.py        |  4 ++++
 deepspeed/module_inject/replace_module.py | 19 ++++++++++++++++---
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 396fe7db2447..d80f0bf612b3 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -108,6 +108,10 @@ def tp_parser(model):
                     gem_list = gem_list + [layer]
                 elif 'down_proj' in layer:
                     gem_list = gem_list + [layer]
+                elif 'self_attention.dense' in layer and 'falcon' in str(
+                        type(module)):  # this is a hack to get the right linear layer for this model!
+                    gem_list = gem_list + [layer]
+
             layer_list = []
             if gem_list != []:
                 gem_list = list(set(gem_list))
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index f0fe81f28714..385194995b6b 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -438,6 +438,10 @@ def update_mp_params(child):
             assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
                 child.num_heads, mp_size)
             child.num_heads = child.num_heads // mp_size
+        if hasattr(child, 'num_kv'):
+            assert child.num_kv % mp_size == 0, "num_kv ({}) must be divisible by mp_size ({})".format(
+                child.num_kv, mp_size)
+            child.num_kv = child.num_kv // mp_size
         if hasattr(child, 'num_attention_heads'):
             assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
                 child.num_attention_heads, mp_size)
@@ -492,9 +496,17 @@ def _replace_module(r_module, prev_name='', prev_class_name=''):
                     continue
             if len(child._buffers) != 0 and state_dict != None:
                 load_buffer(child, state_dict, checking_key)
-            if child.__class__ in linear_policies:
-                setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
-                                                                         conv_linear_layer))
+            if any(isinstance(child, lp) for lp in linear_policies):
+                if child.__class__ in linear_policies:
+                    setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
+                                                                             conv_linear_layer))
+                else:
+                    key = None
+                    for lp in linear_policies:
+                        if isinstance(child, lp):
+                            key = lp
+                    assert key is not None
+                    setattr(r_module, name, linear_policies[key](child, prev_name + '.' + name, conv_linear_layer))
             else:
                 update_mp_params(child)
                 _replace_module(child, name, class_name)
@@ -522,6 +534,7 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module
 
     if checkpoint_dict != None and not config.replace_with_kernel_inject:
+        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
         pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

From 4bbed0e2983b5abfc141a8745a1aacceb4f86413 Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Fri, 30 Jun 2023 10:42:04 -0700
Subject: [PATCH 2/4] refactor multiple if statements into a loop

---
 deepspeed/module_inject/replace_module.py | 45 ++++++------------------
 1 file changed, 8 insertions(+), 37 deletions(-)

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 385194995b6b..da1c140cc012 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -426,42 +426,14 @@ def _slice_embedding(child, name, conv_linear_layer):
     def update_mp_params(child):
         if getattr(child, "replaced", False) == True:
             return
-        if hasattr(child, 'n_heads'):
-            assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
-                child.n_heads, mp_size)
-            child.n_heads = child.n_heads // mp_size
-        if hasattr(child, 'inner_dim'):
-            assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(
-                child.inner_dim, mp_size)
-            child.inner_dim = child.inner_dim // mp_size
-        if hasattr(child, 'num_heads'):
-            assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_heads, mp_size)
-            child.num_heads = child.num_heads // mp_size
-        if hasattr(child, 'num_kv'):
-            assert child.num_kv % mp_size == 0, "num_kv ({}) must be divisible by mp_size ({})".format(
-                child.num_kv, mp_size)
-            child.num_kv = child.num_kv // mp_size
-        if hasattr(child, 'num_attention_heads'):
-            assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_attention_heads, mp_size)
-            child.num_attention_heads = child.num_attention_heads // mp_size
-        if hasattr(child, 'num_attn_heads'):
-            assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_attn_heads, mp_size)
-            child.num_attn_heads = child.num_attn_heads // mp_size
-        if hasattr(child, 'all_head_size'):
-            assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(
-                child.all_head_size, mp_size)
-            child.all_head_size = child.all_head_size // mp_size
-        if hasattr(child, 'embed_dim'):
-            assert child.embed_dim % mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format(
-                child.embed_dim, mp_size)
-            child.embed_dim = child.embed_dim // mp_size
-        if hasattr(child, 'hidden_size'):
-            assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
-                child.hidden_size, mp_size)
-            child.hidden_size = child.hidden_size // mp_size
+        for param in [
+                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
+                "all_head_size", "embed_dim", "hidden_size"
+        ]:
+            if hasattr(child, param):
+                param_val = getattr(child, param)
+                assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
+                setattr(child, param, param_val // mp_size)
         setattr(child, "replaced", True)
 
     conv_linear_layer = False
@@ -534,7 +506,6 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module
 
     if checkpoint_dict != None and not config.replace_with_kernel_inject:
-        # AutoTP shard loading
         checkpoint = checkpoint_dict["checkpoints"]
         pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

From 2ec3bd5bd95d9623691eb1b6782f4cc2ffc4eac2 Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Fri, 30 Jun 2023 11:07:34 -0700
Subject: [PATCH 3/4] add falcon AutoTP unit test

---
 tests/unit/inference/test_inference.py | 34 ++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index a9da94d5d30f..b0f16542586d 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -380,6 +380,40 @@ def test(
         assert assert_fn(bs_output, ds_output)
 
 
+@pytest.mark.seq_inference
+@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
+class TestAutoTP(DistributedTest):
+    world_size = 2
+
+    def test(
+        self,
+        model_w_task,
+        query,
+        inf_kwargs,
+        assert_fn,
+    ):
+        model, task = model_w_task
+        dtype = torch.float16
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        # We have to load these large models on CPU with pipeline because not
+        # enough GPU memory
+        pipe = pipeline(task, model=model, torch_dtype=dtype, device=torch.device("cpu"), framework="pt")
+        bs_output = pipe(query, **inf_kwargs)
+
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=self.world_size,
+                                              dtype=dtype,
+                                              replace_with_kernel_inject=False)
+        # Switch device to GPU so that input tensors are not on CPU
+        pipe.device = torch.device(get_accelerator().device_name(local_rank))
+        ds_output = pipe(query, **inf_kwargs)
+
+        print(local_rank, "baseline", bs_output)
+        print(local_rank, "deepspeed", ds_output)
+        assert assert_fn(bs_output, ds_output)
+
+
 @pytest.mark.seq_inference
 @pytest.mark.parametrize(
     "model_w_task, injection_policy",

From 6ade7583033488a4957d1267ae31a464a95333da Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Fri, 30 Jun 2023 13:56:55 -0700
Subject: [PATCH 4/4] added (skipped) unit test, refactored code to be more
 readable

---
 deepspeed/module_inject/replace_module.py | 24 +++++++++---------
 tests/unit/inference/test_inference.py    | 30 ++++++++++++++---------
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index da1c140cc012..4753f55b9ea3 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -468,17 +468,19 @@ def _replace_module(r_module, prev_name='', prev_class_name=''):
                     continue
             if len(child._buffers) != 0 and state_dict != None:
                 load_buffer(child, state_dict, checking_key)
-            if any(isinstance(child, lp) for lp in linear_policies):
-                if child.__class__ in linear_policies:
-                    setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
-                                                                             conv_linear_layer))
-                else:
-                    key = None
-                    for lp in linear_policies:
-                        if isinstance(child, lp):
-                            key = lp
-                    assert key is not None
-                    setattr(r_module, name, linear_policies[key](child, prev_name + '.' + name, conv_linear_layer))
+            if child.__class__ in linear_policies:
+                setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
+                                                                         conv_linear_layer))
+            elif any(isinstance(child, lp) for lp in linear_policies):
+                # Added for falcon model support
+                # Note: isinstance will account for class inheritance, child.__class__ does not
+                key = None
+                for lp in linear_policies:
+                    if isinstance(child, lp):
+                        key = lp
+                        break
+                assert key is not None
+                setattr(r_module, name, linear_policies[key](child, prev_name + '.' + name, conv_linear_layer))
             else:
                 update_mp_params(child)
                 _replace_module(child, name, class_name)
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index b0f16542586d..c9073e66b67b 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -13,7 +13,7 @@
 from unit.common import DistributedTest
 from packaging import version as pkg_version
 from deepspeed.ops.op_builder import OpBuilder
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
 from huggingface_hub import HfApi
@@ -383,7 +383,7 @@ def test(
 @pytest.mark.seq_inference
 @pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
 class TestAutoTP(DistributedTest):
-    world_size = 2
+    world_size = 1
 
     def test(
         self,
@@ -392,26 +392,32 @@ def test(
         inf_kwargs,
         assert_fn,
     ):
+        # TODO: enable this test for H100 tests
+        pytest.skip("Not enough GPU memory for this on V100 runners")
         model, task = model_w_task
-        dtype = torch.float16
+        dtype = torch.bfloat16
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
 
         # We have to load these large models on CPU with pipeline because not
         # enough GPU memory
-        pipe = pipeline(task, model=model, torch_dtype=dtype, device=torch.device("cpu"), framework="pt")
-        bs_output = pipe(query, **inf_kwargs)
-
-        pipe.model = deepspeed.init_inference(pipe.model,
-                                              mp_size=self.world_size,
-                                              dtype=dtype,
-                                              replace_with_kernel_inject=False)
+        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+        pipe = pipeline(task,
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=dtype,
+                        trust_remote_code=True,
+                        device=torch.device("cpu"),
+                        framework="pt")
+        #bs_output = pipe(query, **inf_kwargs)
+
+        pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
         # Switch device to GPU so that input tensors are not on CPU
         pipe.device = torch.device(get_accelerator().device_name(local_rank))
         ds_output = pipe(query, **inf_kwargs)
 
-        print(local_rank, "baseline", bs_output)
+        #print(local_rank, "baseline", bs_output)
         print(local_rank, "deepspeed", ds_output)
-        assert assert_fn(bs_output, ds_output)
+        #assert assert_fn(bs_output, ds_output)
 
 
 @pytest.mark.seq_inference