Does the training data follow the action sequence repetition pattern? #57

@H1yori233

Thanks for open-sourcing the model and code! I have been diving into the implementation and noticed an interesting pattern regarding how action sequences are handled.

I noticed that both inference and inference_streaming use the same action-repetition logic: every action is repeated 12 times.

    current_frame = 0
    selections = [12]

    while current_frame < num_frames:
        rd_frame = selections[random.randint(0, len(selections) - 1)]
        rd = random.randint(0, len(data) - 1)
        k = data[rd]['keyboard_condition']
        if mouse == True:
            m = data[rd]['mouse_condition']
        
        if current_frame == 0:
            keyboard_condition[:1] = k[:1]
            if mouse == True:
                mouse_condition[:1] = m[:1]
            current_frame = 1
        else:
            rd_frame = min(rd_frame, num_frames - current_frame)
            repeat_time = rd_frame // 4
            keyboard_condition[current_frame:current_frame+rd_frame] = k.repeat(repeat_time, 1)
            if mouse == True:
                mouse_condition[current_frame:current_frame+rd_frame] = m.repeat(repeat_time, 1)
            current_frame += rd_frame
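
To make the effect of this loop concrete, here is a minimal pure-Python sketch of the resulting schedule. It is not the repository code: `build_schedule` and the toy action names are hypothetical, and it assumes that frame 0 gets its own action and that every later sampled action is held for 12 frames (truncated at the end of the clip).

```python
import random


def build_schedule(actions, num_frames, hold=12, seed=0):
    """Return a per-frame action list where each sampled action is held for `hold` frames.

    Hypothetical helper: frame 0 is sampled on its own, and every later
    segment is `hold` frames long, shortened only if the clip ends first.
    """
    rng = random.Random(seed)
    schedule = [rng.choice(actions)]  # frame 0 is handled separately
    while len(schedule) < num_frames:
        segment = min(hold, num_frames - len(schedule))
        schedule.extend([rng.choice(actions)] * segment)
    return schedule


# 1 initial frame + 3 segments of 12 frames
schedule = build_schedule(["W", "A", "S", "D"], num_frames=37)
```

Under these assumptions the action can only change at frames 1, 13, 25, and so on, never inside a 12-frame segment.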

and, where the conditions are overwritten per block:

    if replace != None:
        if current_start_frame == 0:
            last_frame_num = 1 + 4 * (num_frame_per_block - 1)
        else:
            last_frame_num = 4 * num_frame_per_block
        final_frame = 1 + 4 * (current_start_frame + num_frame_per_block-1)
        if mode != 'templerun':
            conditional_dict["mouse_cond"][:, -last_frame_num + final_frame: final_frame] = replace['mouse'][None, None, :].repeat(1, last_frame_num, 1)
        conditional_dict["keyboard_cond"][:, -last_frame_num + final_frame: final_frame] = replace['keyboard'][None, None, :].repeat(1, last_frame_num, 1)
    if mode != 'templerun':
        new_cond["mouse_cond"] = conditional_dict["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
    new_cond["keyboard_cond"] = conditional_dict["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
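
The index arithmetic in this snippet can be checked in isolation. The sketch below reproduces only the slice bounds; `block_slice` is a hypothetical helper, and `num_frame_per_block = 3` is an assumed value chosen for illustration (the factor 4 is the constant hard-coded in the snippet).

```python
def block_slice(current_start_frame, num_frame_per_block, ratio=4):
    """Return (start, stop) of the frame range overwritten for one block.

    Hypothetical helper that repeats the arithmetic from the snippet above.
    """
    if current_start_frame == 0:
        # The first block covers 1 leading frame plus `ratio` frames per later latent frame.
        last_frame_num = 1 + ratio * (num_frame_per_block - 1)
    else:
        last_frame_num = ratio * num_frame_per_block
    final_frame = 1 + ratio * (current_start_frame + num_frame_per_block - 1)
    return final_frame - last_frame_num, final_frame


# Assumed block size of 3 latent frames; blocks start at latent frames 0, 3, 6.
slices = [block_slice(start, 3) for start in (0, 3, 6)]
# -> [(0, 9), (9, 21), (21, 33)]
```

With that assumed block size, the first block overwrites 9 frames and each later block overwrites 12 frames with a single repeated action.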

Also, I noticed that in your action_module, the action condition is handled with a sliding window of size 12 and stride 4:

        pad_t = self.vae_time_compression_ratio * self.windows_size
        if self.enable_mouse and mouse_condition is not None:
            pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1)
            mouse_condition = torch.cat([pad, mouse_condition], dim=1)
            if is_causal and kv_cache_mouse is not None: 
                mouse_condition = mouse_condition[:, self.vae_time_compression_ratio*(N_feats - num_frame_per_block - self.windows_size) + pad_t:, :] 
                group_mouse = [mouse_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(num_frame_per_block)]
            else:
                group_mouse = [mouse_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(N_feats)]
                
            group_mouse = torch.stack(group_mouse, dim = 1)

            S = th * tw 
            group_mouse = group_mouse.unsqueeze(-1).expand(B, num_frame_per_block, pad_t, C, S)
            group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(B * S, num_frame_per_block, pad_t * C) 

            group_mouse = torch.cat([hidden_states, group_mouse], dim = -1)
            group_mouse = self.mouse_mlp(group_mouse)
        if self.enable_keyboard and keyboard_condition is not None:
            pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1)
            keyboard_condition = torch.cat([pad, keyboard_condition], dim=1)
            if is_causal and kv_cache_keyboard is not None:
                keyboard_condition = keyboard_condition[:, self.vae_time_compression_ratio*(N_feats - num_frame_per_block - self.windows_size) + pad_t:, :] # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:]
                keyboard_condition = self.keyboard_embed(keyboard_condition)
                group_keyboard = [keyboard_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(num_frame_per_block)]
            else:
                keyboard_condition = self.keyboard_embed(keyboard_condition)
                group_keyboard = [keyboard_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(N_feats)]
            group_keyboard = torch.stack(group_keyboard, dim = 1) # B F RW C
            group_keyboard = group_keyboard.reshape(shape=(group_keyboard.shape[0],group_keyboard.shape[1],-1))
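
As a sanity check on the window size and stride, the slice bounds from the non-causal branch can be computed on their own. This is a sketch under assumptions: `window_bounds` is a hypothetical helper, and `ratio = 4` and `windows_size = 3` are values inferred from the 12 and 4 mentioned above, not read from the config.

```python
def window_bounds(i, ratio=4, windows_size=3):
    """Return the (start, stop) action-frame range read for latent frame i.

    Indices are relative to the unpadded condition. Negative values fall in
    the left padding, which the quoted code fills with copies of frame 0.
    """
    pad_t = ratio * windows_size
    start = ratio * (i - windows_size) + pad_t  # index into the padded tensor
    stop = i * ratio + pad_t
    return start - pad_t, stop - pad_t


bounds = [window_bounds(i) for i in range(4)]
# -> [(-12, 0), (-8, 4), (-4, 8), (0, 12)]
```

Each window spans 12 action frames and consecutive windows are shifted by 4, which matches the size and stride described above under these assumed values.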

So I am wondering: does your training data also follow this pattern of repeating each action 12 times, or was the model trained on raw, high-frequency action data? An answer would be very helpful!
