Description
Hello. It seems there is a bug in the UNet code that handles gradient checkpointing.
To save CUDA memory, I tried to enable gradient checkpointing by adding this line to my train.py: `unet.enable_gradient_checkpointing()`.
However, I encountered this error:

```
File "CameraCtrl/cameractrl/models/attention.py", line 40, in forward
    pose_feature = pose_feature.permute(0, 3, 4, 2, 1).reshape(batch_size * seq_length, num_frames, -1)
AttributeError: 'NoneType' object has no attribute 'permute'
```

It seems that pose_feature is None instead of the expected tensor.
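For context, here is a minimal sketch of how the error is triggered (illustrative only, not the actual train.py): the buggy branch in unet_3d_blocks.py is guarded by `self.training and self.gradient_checkpointing`, so the error only appears once both flags are set, which is why plain inference never hits it.

```python
# Illustrative sketch, not the actual train.py.
unet.enable_gradient_checkpointing()  # sets gradient_checkpointing = True on the blocks
unet.train()                          # sets self.training = True
# ...any forward pass through the pose-conditioned blocks now raises
# AttributeError: 'NoneType' object has no attribute 'permute'
```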
The cause appears to be inside unet_3d_blocks.py. In the forward of CrossAttnDownBlockSpatioTemporalPoseCond, UNetMidBlockSpatioTemporalPoseCond, and CrossAttnUpBlockSpatioTemporalPoseCond, the gradient-checkpointing branch looks like this:
```python
if self.training and self.gradient_checkpointing:

    def create_custom_forward(module, return_dict=None):
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            else:
                return module(*inputs)

        return custom_forward

    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(resnet),
        hidden_states,
        temb,
        image_only_indicator,
        **ckpt_kwargs,
    )
    hidden_states = attn(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
        return_dict=False,
    )[0]
```
This branch never passes pose_feature along: the attn(...) call omits it (and attn is not routed through torch.utils.checkpoint.checkpoint at all), so the attention module's forward falls back to its default pose_feature=None.
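A standalone repro of this failure mode, independent of CameraCtrl (the Attn class below is a hypothetical stand-in for the pose-conditioned attention module): when a caller omits a keyword argument whose default is None, the tensor op inside raises exactly this AttributeError.

```python
import torch

class Attn(torch.nn.Module):
    # Hypothetical stand-in: pose_feature defaults to None, as in
    # cameractrl/models/attention.py.
    def forward(self, hidden_states, pose_feature=None):
        pose_feature = pose_feature.permute(0, 3, 4, 2, 1)  # fails when the caller omits pose_feature
        return hidden_states

attn = Attn()
attn(torch.randn(1, 4))  # AttributeError: 'NoneType' object has no attribute 'permute'
```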
After changing this code to:
```python
def create_custom_forward(resnet_module, attn_module):
    def custom_forward(hidden_states, temb, image_only_indicator, encoder_hidden_states, pose_feature):
        hidden_states = resnet_module(
            hidden_states,
            temb,
            image_only_indicator=image_only_indicator,
        )
        hidden_states = attn_module(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
            pose_feature=pose_feature,
            return_dict=False,
        )[0]
        return hidden_states

    return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet, attn),
    hidden_states,
    temb,
    image_only_indicator,
    encoder_hidden_states,
    pose_feature,
    **ckpt_kwargs,
)
```
gradient checkpointing is successfully applied to the UNet and training runs without the error.
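To double-check that checkpointing is actually paying off after the fix, one can compare peak CUDA memory across a training step. This is a hedged sketch: run_step is a hypothetical placeholder for one forward/backward pass of the existing train.py loop.

```python
import torch

def peak_memory_mb(run_step):
    # Measure the peak CUDA memory used by one training step.
    torch.cuda.reset_peak_memory_stats()
    run_step()  # hypothetical: one forward/backward pass from train.py
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

print("without checkpointing:", peak_memory_mb(run_step), "MB")
unet.enable_gradient_checkpointing()
print("with checkpointing:", peak_memory_mb(run_step), "MB")
```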