diff --git a/flow_control/flow/module_raft.py b/flow_control/flow/module_raft.py index 8ba9949..86de79b 100644 --- a/flow_control/flow/module_raft.py +++ b/flow_control/flow/module_raft.py @@ -74,7 +74,7 @@ def _totorch(self, array): """ return torch.from_numpy(array)[None].float().permute(0, 3, 1, 2).cuda() - def step(self, img0, img1): + def step(self, img0, img1, store_flow: bool=True): """ compute flow @@ -100,7 +100,7 @@ def step(self, img0, img1): test_mode=True ) - self.flow_prev = forward_interpolate(flow_low[0])[None].cuda() + self.flow_prev = forward_interpolate(flow_low[0])[None].cuda() if store_flow else None return padder.unpad(flow_up[0]).permute(1, 2, 0).detach().cpu().numpy() def warp(self, x, flow, mode="bilinear"):