Thanks for open-sourcing the model and code! I have been diving into the implementation and noticed an interesting pattern regarding how action sequences are handled.
I noticed that both inference and inference_streaming use the same action-repeat logic, in which every action is repeated 12 times:

    current_frame = 0
    selections = [12]

    while current_frame < num_frames:
        rd_frame = selections[random.randint(0, len(selections) - 1)]
        rd = random.randint(0, len(data) - 1)
        k = data[rd]['keyboard_condition']
        if mouse == True:
            m = data[rd]['mouse_condition']

        if current_frame == 0:
            keyboard_condition[:1] = k[:1]
            if mouse == True:
                mouse_condition[:1] = m[:1]
            current_frame = 1
        else:
            rd_frame = min(rd_frame, num_frames - current_frame)
            repeat_time = rd_frame // 4
            keyboard_condition[current_frame:current_frame+rd_frame] = k.repeat(repeat_time, 1)
            if mouse == True:
                mouse_condition[current_frame:current_frame+rd_frame] = m.repeat(repeat_time, 1)
            current_frame += rd_frame

and

    if replace != None:
        if current_start_frame == 0:
            last_frame_num = 1 + 4 * (num_frame_per_block - 1)
        else:
            last_frame_num = 4 * num_frame_per_block
        final_frame = 1 + 4 * (current_start_frame + num_frame_per_block-1)
        if mode != 'templerun':
            conditional_dict["mouse_cond"][:, -last_frame_num + final_frame: final_frame] = replace['mouse'][None, None, :].repeat(1, last_frame_num, 1)
        conditional_dict["keyboard_cond"][:, -last_frame_num + final_frame: final_frame] = replace['keyboard'][None, None, :].repeat(1, last_frame_num, 1)
    if mode != 'templerun':
        new_cond["mouse_cond"] = conditional_dict["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
    new_cond["keyboard_cond"] = conditional_dict["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]

Also, I noticed that in your action_module you handle the action condition with a sliding window of size 12 and stride 4:

    pad_t = self.vae_time_compression_ratio * self.windows_size
    if self.enable_mouse and mouse_condition is not None:
        pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1)
        mouse_condition = torch.cat([pad, mouse_condition], dim=1)
        if is_causal and kv_cache_mouse is not None:
            mouse_condition = mouse_condition[:, self.vae_time_compression_ratio*(N_feats - num_frame_per_block - self.windows_size) + pad_t:, :]
            group_mouse = [mouse_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(num_frame_per_block)]
        else:
            group_mouse = [mouse_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(N_feats)]
        group_mouse = torch.stack(group_mouse, dim = 1)
        S = th * tw
        group_mouse = group_mouse.unsqueeze(-1).expand(B, num_frame_per_block, pad_t, C, S)
        group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(B * S, num_frame_per_block, pad_t * C)
        group_mouse = torch.cat([hidden_states, group_mouse], dim = -1)
        group_mouse = self.mouse_mlp(group_mouse)

    if self.enable_keyboard and keyboard_condition is not None:
        pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1)
        keyboard_condition = torch.cat([pad, keyboard_condition], dim=1)
        if is_causal and kv_cache_keyboard is not None:
            keyboard_condition = keyboard_condition[:, self.vae_time_compression_ratio*(N_feats - num_frame_per_block - self.windows_size) + pad_t:, :] # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:]
            keyboard_condition = self.keyboard_embed(keyboard_condition)
            group_keyboard = [keyboard_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(num_frame_per_block)]
        else:
            keyboard_condition = self.keyboard_embed(keyboard_condition)
            group_keyboard = [keyboard_condition[:, self.vae_time_compression_ratio*(i - self.windows_size) + pad_t:i * self.vae_time_compression_ratio + pad_t,:] for i in range(N_feats)]
        group_keyboard = torch.stack(group_keyboard, dim = 1) # B F RW C
        group_keyboard = group_keyboard.reshape(shape=(group_keyboard.shape[0],group_keyboard.shape[1],-1))

So I am wondering: does your training data also follow this pattern of repeating each action 12 times, or was the model trained on raw, high-frequency action data? An answer would be very helpful!
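
To make the question concrete, here is a minimal pure-Python sketch of how I read the index arithmetic in the non-causal branch. It is not the repository's code. It assumes vae_time_compression_ratio = 4 and windows_size = 3 (which is what gives a window of 12 and a stride of 4), the helper names held_actions and window_indices are hypothetical, and it uses sequential integer ids in place of the randomly sampled keyboard/mouse vectors so that each 12-frame hold is easy to tell apart.

```python
# Assumed values, not confirmed by the repository:
VAE_T = 4    # stands in for vae_time_compression_ratio
WINDOW = 3   # stands in for windows_size, so pad_t = 12
HOLD = 12    # from selections = [12]


def held_actions(num_frames, hold=HOLD):
    """Action id per pixel frame: frame 0 on its own, then blocks of `hold` frames.

    Mirrors the layout produced by the sampling loop, but with sequential
    ids (hypothetical) instead of random draws from `data`.
    """
    ids = [0]
    action = 1
    while len(ids) < num_frames:
        n = min(hold, num_frames - len(ids))
        ids.extend([action] * n)
        action += 1
    return ids


def window_indices(i, vae_t=VAE_T, window=WINDOW):
    """Original frame indices gathered for latent frame i.

    The first frame is replicated pad_t times in front of the sequence,
    so padded positions map back to frame 0.
    """
    pad_t = vae_t * window
    start = vae_t * (i - window) + pad_t  # index into the padded sequence
    stop = i * vae_t + pad_t
    return [max(p - pad_t, 0) for p in range(start, stop)]


num_frames = 57
ids = held_actions(num_frames)
n_feats = 1 + (num_frames - 1) // VAE_T  # latent frames, assuming 4x compression

for i in range(n_feats):
    idx = window_indices(i)
    seen = sorted({ids[j] for j in idx})
    print(f"latent {i:2d}: frames {idx[0]:2d}..{idx[-1]:2d} -> action ids {seen}")
```

Under these assumptions each window holds 12 frames, the same length as one hold, but the holds start at frame 1 while the windows start at multiples of 4, so the two never line up exactly. That is part of why I am asking how the actions were laid out during training.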