Skip to content

Commit e3e785c

Browse files
mingxu1067ashors1
authored andcommitted
Addind a comment to get_input_bld
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
1 parent 0bd4a53 commit e3e785c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

praxis/contrib/gpu/scripts_gpu/te_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, laye
5050

5151
@staticmethod
5252
def get_input_bld(original_bld, batch_axes, mdl_axis):
53+
# This is used to specify the sharding pattern of inputs to TransformerLayers.
5354
raise NotImplementedError
5455

5556
@staticmethod
@@ -89,7 +90,7 @@ def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, laye
8990
return layer_p
9091

9192
@staticmethod
92-
def get_input_bld(original_bld, _, _):
93+
def get_input_bld(original_bld, *_):
9394
return original_bld
9495

9596
@staticmethod

0 commit comments

Comments
 (0)