From 1403304ae6d660f5d3bfb32d671cd2b59a390eca Mon Sep 17 00:00:00 2001
From: liuyuliang
Date: Wed, 6 Apr 2022 16:57:09 +0800
Subject: [PATCH 1/2] change to fit refactored schedule

---
 features/pipeline_parallel/resnet.py            |  9 ++----
 .../hybrid_parallel/train_with_engine.py        | 10 +------
 .../hybrid_parallel/train_with_trainer.py       | 13 +--------
 .../pipeline_parallel/vit.py                    |  9 ++----
 language/DeepNet/train_deepnet_decoder.py       | 12 --------
 language/bert/sequene_parallel/train.py         | 29 +++----------------
 language/gpt/train_gpt.py                       | 12 --------
 7 files changed, 10 insertions(+), 84 deletions(-)

diff --git a/features/pipeline_parallel/resnet.py b/features/pipeline_parallel/resnet.py
index 3a781f0..b7708d4 100644
--- a/features/pipeline_parallel/resnet.py
+++ b/features/pipeline_parallel/resnet.py
@@ -156,7 +156,7 @@ def build_cifar(batch_size):
 BATCH_SIZE = 64
 NUM_EPOCHS = 2
 NUM_CHUNKS = 1
-CONFIG = dict(parallel=dict(pipeline=2))
+CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
 
 
 def train():
@@ -184,12 +184,7 @@ def train():
                                                        train_dataloader, test_dataloader, lr_scheduler)
     timer = MultiTimer()
 
-    if NUM_CHUNKS == 1:
-        schedule = PipelineSchedule(num_microbatches=4)
-    else:
-        schedule = InterleavedPipelineSchedule(num_microbatches=4, num_model_chunks=NUM_CHUNKS)
-
-    trainer = Trainer(engine=engine, timer=timer, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, timer=timer, logger=logger)
 
     hook_list = [
         hooks.LossHook(),
diff --git a/image/vision_transformer/hybrid_parallel/train_with_engine.py b/image/vision_transformer/hybrid_parallel/train_with_engine.py
index ec106b6..2dd4ccb 100644
--- a/image/vision_transformer/hybrid_parallel/train_with_engine.py
+++ b/image/vision_transformer/hybrid_parallel/train_with_engine.py
@@ -130,14 +130,6 @@ def train_imagenet():
         scatter_gather = True
     else:
         scatter_gather = False
-    if use_pipeline:
-        logger.info('Build PipelineSchedule', ranks=[0])
-        schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                    tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
-        schedule.pre_processing(engine)
-
-    if schedule is None:
-        schedule = NonPipelineSchedule()
 
     data_iter = iter(train_dataloader)
 
@@ -155,7 +147,7 @@ def train_imagenet():
             progress = range(len(train_dataloader))
         for _ in progress:
             engine.zero_grad()
-            schedule.forward_backward_step(engine, data_iter, return_output_label=False)
+            engine.execute_schedule(data_iter, return_output_label=False)
             engine.step()
             lr_scheduler.step()
 
diff --git a/image/vision_transformer/hybrid_parallel/train_with_trainer.py b/image/vision_transformer/hybrid_parallel/train_with_trainer.py
index 74d4f81..f5d81d2 100644
--- a/image/vision_transformer/hybrid_parallel/train_with_trainer.py
+++ b/image/vision_transformer/hybrid_parallel/train_with_trainer.py
@@ -118,23 +118,12 @@ def train_imagenet():
 
     logger.info("Engine is built", ranks=[0])
 
-    # create schedule
-    schedule = None
-    tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
-    if gpc.is_initialized(ParallelMode.PARALLEL_1D):
-        scatter_gather = True
-    else:
-        scatter_gather = False
-    if use_pipeline:
-        logger.info('Build PipelineSchedule', ranks=[0])
-        schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                    tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
 
     # create timer
     timer = MultiTimer()
 
     # create trainer
-    trainer = Trainer(engine=engine, logger=logger, timer=timer, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("Trainer is built", ranks=[0])
 
     # create a list of useful hooks
diff --git a/image/vision_transformer/pipeline_parallel/vit.py b/image/vision_transformer/pipeline_parallel/vit.py
index 1e30fcc..9afa6ae 100644
--- a/image/vision_transformer/pipeline_parallel/vit.py
+++ b/image/vision_transformer/pipeline_parallel/vit.py
@@ -147,7 +147,7 @@ def build_cifar(batch_size):
 BATCH_SIZE = 128
 NUM_EPOCHS = 2
 NUM_CHUNKS = 1
-CONFIG = dict(parallel=dict(pipeline=2))
+CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
 
 
 def train():
@@ -174,12 +174,7 @@ def train():
                                                        train_dataloader, test_dataloader)
     timer = MultiTimer()
 
-    if NUM_CHUNKS == 1:
-        schedule = PipelineSchedule(num_microbatches=4)
-    else:
-        schedule = InterleavedPipelineSchedule(num_microbatches=4, num_model_chunks=NUM_CHUNKS)
-
-    trainer = Trainer(engine=engine, timer=timer, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, timer=timer, logger=logger)
 
     hook_list = [
         hooks.LossHook(),
diff --git a/language/DeepNet/train_deepnet_decoder.py b/language/DeepNet/train_deepnet_decoder.py
index 4a3a80c..d6a073b 100644
--- a/language/DeepNet/train_deepnet_decoder.py
+++ b/language/DeepNet/train_deepnet_decoder.py
@@ -71,24 +71,12 @@ def main():
     global_batch_size = gpc.config.BATCH_SIZE * \
         gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
     logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
-    tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
-    schedule = None
-    if use_pipeline:
-        if use_interleaved:
-            logger.info('Build InterleavedPipelineSchedule', ranks=[0])
-            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
-        else:
-            logger.info('Build PipelineSchedule', ranks=[0])
-            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
 
     timier = MultiTimer()
 
     trainer = Trainer(
         engine=engine,
         logger=logger,
-        schedule=schedule,
         timer=timier
     )
 
diff --git a/language/bert/sequene_parallel/train.py b/language/bert/sequene_parallel/train.py
index d720f69..fa57023 100644
--- a/language/bert/sequene_parallel/train.py
+++ b/language/bert/sequene_parallel/train.py
@@ -18,19 +18,6 @@
 from model.bert import build_pipeline_bert
 
 
-def get_tensor_shape():
-    if not gpc.is_initialized(ParallelMode.PIPELINE):
-        return None
-
-    dp_size = gpc.get_world_size(ParallelMode.DATA)
-    if gpc.is_initialized(ParallelMode.SEQUENCE):
-        seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
-    else:
-        seq_size = 1
-    tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
-                    gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
-                    gpc.config.HIDDEN_SIZE)
-    return tensor_shape
 
 
 def process_batch_data(batch_data):
@@ -157,15 +144,7 @@ def main():
         criterion,
     )
 
-    # schedule
-    schedule = None
-    tensor_shape = get_tensor_shape()
-    if use_pipeline:
-        logger.info('Build PipelineSchedule', ranks=[0])
-        schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                    tensor_shape=tensor_shape, scatter_gather_tensors=False,
-                                    batch_data_process_func=process_batch_data)
-        schedule.pre_processing(engine)
+
     # build timer
     timer = MultiTimer()
 
@@ -185,7 +164,7 @@ def main():
         engine.train()
         if use_pipeline:
             engine.zero_grad()
-            _, _, train_loss = schedule.forward_backward_step(engine, train_data_iter, return_output_label=False)
+            _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False)
             success, grad_norm, num_zeros_in_grad = engine.step()
         else:
             tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
@@ -211,8 +190,8 @@ def main():
             for j in range(gpc.config.EVAL_ITERS):
                 with torch.no_grad():
                     if use_pipeline:
-                        _, _, eval_loss = schedule.forward_backward_step(
-                            engine, valid_data_iter, forward_only=True, return_output_label=False)
+                        _, _, eval_loss = engine.execute_schedule(
+                            valid_data_iter, forward_only=True, return_output_label=False)
                     else:
                         tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                             validloader)
diff --git a/language/gpt/train_gpt.py b/language/gpt/train_gpt.py
index 2480bf3..8060d4b 100644
--- a/language/gpt/train_gpt.py
+++ b/language/gpt/train_gpt.py
@@ -81,24 +81,12 @@ def main():
     global_batch_size = gpc.config.BATCH_SIZE * \
         gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
     logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
-    tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
-    schedule = None
-    if use_pipeline:
-        if use_interleaved:
-            logger.info('Build InterleavedPipelineSchedule', ranks=[0])
-            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
-        else:
-            logger.info('Build PipelineSchedule', ranks=[0])
-            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
 
     timier = MultiTimer()
 
     trainer = Trainer(
         engine=engine,
         logger=logger,
-        schedule=schedule,
         timer=timier
     )
 

From d1938fe13325fe36d2af5760d2f2196d6820c070 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306
Date: Thu, 1 Dec 2022 17:20:26 +0800
Subject: [PATCH 2/2] [autoparallel] update resnet demo

---
 .../auto_parallel/auto_parallel_demo.py | 21 +++++++++----------
 image/resnet/auto_parallel/config.py    |  2 ++
 2 files changed, 12 insertions(+), 11 deletions(-)
 create mode 100644 image/resnet/auto_parallel/config.py

diff --git a/image/resnet/auto_parallel/auto_parallel_demo.py b/image/resnet/auto_parallel/auto_parallel_demo.py
index 429a99e..288430b 100644
--- a/image/resnet/auto_parallel/auto_parallel_demo.py
+++ b/image/resnet/auto_parallel/auto_parallel_demo.py
@@ -16,19 +16,17 @@ from titans.utils import barrier_context
 
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions, DataloaderOption
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 
 DATA_ROOT = Path(os.environ.get('DATA', './data'))
-BATCH_SIZE = 1024
-NUM_EPOCHS = 10
 
 
 def main():
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch(config='./config.py')
 
     logger = get_dist_logger()
 
@@ -52,16 +50,16 @@ def main():
 
     train_dataloader = get_dataloader(
         dataset=train_dataset,
-        add_sampler=False,
+        add_sampler=True,
         shuffle=True,
-        batch_size=BATCH_SIZE,
+        batch_size=gpc.config.BATCH_SIZE,
         pin_memory=True,
     )
 
     test_dataloader = get_dataloader(
         dataset=test_dataset,
         add_sampler=False,
-        batch_size=BATCH_SIZE,
+        batch_size=gpc.config.BATCH_SIZE,
         pin_memory=True,
     )
 
@@ -73,13 +71,13 @@ def main():
     # trace the model with meta data
     tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([1024, 3, 32, 32]).to('meta')}
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
     # prepare info for solver
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
@@ -106,9 +104,9 @@ def main():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 
     # lr_scheduler
-    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)
+    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
 
-    for epoch in range(NUM_EPOCHS):
+    for epoch in range(gpc.config.NUM_EPOCHS):
         gm.train()
         if gpc.get_global_rank() == 0:
             train_dl = tqdm(train_dataloader)
@@ -121,6 +119,7 @@ def main():
             output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
             train_loss = criterion(output, label)
             train_loss.backward(train_loss)
+            torch.cuda.synchronize()
             optimizer.step()
             lr_scheduler.step()
 
diff --git a/image/resnet/auto_parallel/config.py b/image/resnet/auto_parallel/config.py
new file mode 100644
index 0000000..feaef04
--- /dev/null
+++ b/image/resnet/auto_parallel/config.py
@@ -0,0 +1,2 @@
+BATCH_SIZE = 128
+NUM_EPOCHS = 10
\ No newline at end of file
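
Note for readers porting their own scripts: below is a minimal sketch of the training-loop pattern the first patch migrates the examples to. The micro-batch count now lives in the launch config (NUM_MICRO_BATCHES) and the engine drives the pipeline schedule itself through engine.execute_schedule(), so no PipelineSchedule object is built or passed to the Trainer. The sketch reuses only calls that appear in the diffs above; the train() wrapper and its arguments are illustrative assumptions, not code from the repository.

    # Sketch of the refactored-schedule usage shown in the diffs above.
    # Assumption: `model`, `optimizer`, `criterion` and `train_dataloader`
    # are constructed by the caller, as in the example scripts.
    import colossalai
    from colossalai.logging import get_dist_logger

    # NUM_MICRO_BATCHES in the config replaces the explicit
    # PipelineSchedule(num_microbatches=...) object.
    CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))


    def train(model, optimizer, criterion, train_dataloader):
        colossalai.launch_from_torch(config=CONFIG)
        logger = get_dist_logger()

        # The engine returned here owns the schedule configured above.
        engine, train_dataloader, *_ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

        data_iter = iter(train_dataloader)
        engine.train()
        for _ in range(len(train_dataloader)):
            engine.zero_grad()
            # Replaces schedule.forward_backward_step(engine, data_iter, ...)
            engine.execute_schedule(data_iter, return_output_label=False)
            engine.step()
        logger.info('training finished', ranks=[0])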