
Applying gradient checkpointing to unet in branch svd #27

@Jiayi-hit

Description


Hello. There seems to be a bug in the UNet code that handles gradient checkpointing.

To save CUDA memory, I attempted to use gradient checkpointing by adding `unet.enable_gradient_checkpointing()` to my train.py.
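For reference, here is a minimal sketch of what that looks like in a training script. It uses the stock diffusers `UNetSpatioTemporalConditionModel` and a public SVD checkpoint purely for illustration; CameraCtrl's own UNet subclass would be loaded and enabled the same way:

```python
import torch
from diffusers import UNetSpatioTemporalConditionModel

# illustrative checkpoint; substitute the weights you actually train from
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="unet",
    torch_dtype=torch.float32,
)
unet.train()
# sets gradient_checkpointing = True on every submodule that supports it,
# trading recomputation in backward for lower activation memory
unet.enable_gradient_checkpointing()
```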

However, I encountered an error:

```
File "CameraCtrl/cameractrl/models/attention.py", line 40, in forward
    pose_feature = pose_feature.permute(0, 3, 4, 2, 1).reshape(batch_size * seq_length, num_frames, -1)
AttributeError: 'NoneType' object has no attribute 'permute'
```

It seems that pose_feature was None instead of the expected tensor.

The cause appears to be in unet_3d_blocks.py. In the forward methods of CrossAttnDownBlockSpatioTemporalPoseCond, UNetMidBlockSpatioTemporalPoseCond, and CrossAttnUpBlockSpatioTemporalPoseCond, this code

```python
if self.training and self.gradient_checkpointing:

    def create_custom_forward(module, return_dict=None):
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            else:
                return module(*inputs)

        return custom_forward

    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
    # only the resnet is checkpointed; the tensors passed here are all
    # that custom_forward will ever see
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(resnet),
        hidden_states,
        temb,
        image_only_indicator,
        **ckpt_kwargs,
    )
    # attn is called directly and pose_feature is never passed, so it
    # defaults to None inside the attention forward
    hidden_states = attn(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
        return_dict=False,
    )[0]
```

fails to pass pose_feature to torch.utils.checkpoint.checkpoint, so the attention forward receives None.
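For context, torch.utils.checkpoint.checkpoint re-runs the wrapped function during backward with exactly the arguments it was given, so any input not forwarded through it is simply absent. A toy illustration, with names invented for this example rather than taken from CameraCtrl:

```python
import torch
from torch.utils.checkpoint import checkpoint

def attn_like(hidden_states, pose_feature=None):
    # stand-in for the pose-conditioned attention block: it fails the same
    # way attention.py does whenever pose_feature arrives as None
    return hidden_states + pose_feature.permute(1, 0)

h = torch.randn(3, 3, requires_grad=True)
pose = torch.randn(3, 3)

out = checkpoint(attn_like, h, pose, use_reentrant=False)  # pose is forwarded: fine
out.sum().backward()

try:
    checkpoint(attn_like, h, use_reentrant=False)  # pose_feature left out
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'permute'
```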

After changing this code to

```python
def create_custom_forward(resnet_module, attn_module):
    def custom_forward(hidden_states, temb, image_only_indicator, encoder_hidden_states, pose_feature):
        hidden_states = resnet_module(
            hidden_states,
            temb,
            image_only_indicator=image_only_indicator,
        )
        hidden_states = attn_module(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
            pose_feature=pose_feature,
            return_dict=False,
        )[0]
        return hidden_states

    return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet, attn),
    hidden_states,
    temb,
    image_only_indicator,
    encoder_hidden_states,
    pose_feature,
    **ckpt_kwargs,
)
```

it works and successfully applies gradient checkpointing to the UNet.
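As a quick sanity check that checkpointing is actually paying off, one can compare peak activation memory on a toy stack of blocks (illustrative code, unrelated to the CameraCtrl model; it assumes a CUDA device):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """A small MLP block with an opt-in checkpointed forward."""
    def __init__(self, dim=2048):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.use_ckpt = False

    def forward(self, x):
        if self.training and self.use_ckpt:
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)

def peak_mib(model, x):
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20

model = nn.Sequential(*[Block() for _ in range(8)]).cuda().train()
x = torch.randn(256, 2048, device="cuda")

print(f"without checkpointing: {peak_mib(model, x):.0f} MiB")
for block in model:
    block.use_ckpt = True
print(f"with checkpointing:    {peak_mib(model, x):.0f} MiB")
```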
