diff --git a/src/opentau/scripts/visualize_dataset.py b/src/opentau/scripts/visualize_dataset.py index 3eba00b..1098890 100644 --- a/src/opentau/scripts/visualize_dataset.py +++ b/src/opentau/scripts/visualize_dataset.py @@ -172,7 +172,23 @@ def visualize_dataset( save: bool = False, output_dir: Path | None = None, urdf: Path | None = None, + joint_names: list[str] | None = None, ) -> Path | None: + r""" + Visualize data of a given episode of a LeRobotDataset with rerun. + + Args: + dataset: The dataset to visualize. + episode_index: The index of the episode to visualize. + batch_size: Batch size for loading data. Defaults to 32. + num_workers: Number of workers for data loading. Defaults to 0. + mode: Visualization mode, either "local" or "distant". Defaults to "local". + web_port: Web port for rerun when in "distant" mode. Defaults to 9090. + save: Whether to save the visualization as a .rrd file instead of spawning a viewer. Defaults to False. + output_dir: Directory to save the .rrd file if `save` is True. Required if `save` is True. Defaults to None. + urdf: Path to a URDF file to load and visualize alongside the dataset. Defaults to None. + joint_names: List of joint names for each state dimension, in order. Used for associating state dimensions with URDF joints. If not provided, state names from dataset metadata will be used. Defaults to None. + """ if save: assert output_dir is not None, ( "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." @@ -202,16 +218,15 @@ def visualize_dataset( # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix gc.collect() + urdf_joints = {} + if joint_names is None: + joint_names = dataset.meta.info.get("features", {}).get("observation.state", {}).get("names", []) + if urdf: rr.log_file_from_path(urdf, static=True) urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf) - urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"] - print( - "Assuming the dataset state dimensions correspond to URDF joints in order:\n", - "\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)), - ) - else: - urdf_joints = [] + urdf_joints = {jnt.name: jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"} + print("Assuming joints are ordered as: ", joint_names) if mode == "distant": rr.serve_web_viewer(open_browser=False, web_port=web_port) @@ -236,12 +251,12 @@ def visualize_dataset( # display each dimension of observed state space (e.g. agent position in joint space) if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): - rr.log(f"state/{dim_idx}", _rr_scalar(val.item())) - # Assuming the state dimensions correspond to URDF joints in order. - # TODO(shuheng): allow overriding with a mapping from state dim to joint name. - if dim_idx < len(urdf_joints): - joint = urdf_joints[dim_idx] + states = batch["observation.state"][i] + for dim_idx, val in enumerate(states): + jnt_name = joint_names[dim_idx] if dim_idx < len(joint_names) else str(dim_idx) + rr.log(f"state/{jnt_name}", _rr_scalar(val.item())) + if jnt_name in urdf_joints: + joint = urdf_joints[jnt_name] transform = joint.compute_transform(float(val)) rr.log("URDF", transform) @@ -364,6 +379,16 @@ def parse_args() -> dict: "which will be used if this argument is not provided." ), ) + parser.add_argument( + "--joint-names", + type=str, + default="", + help=( + "If provided, a comma-separated list of joint names for each state dimension, in order. " + "This is used to associate state dimensions with URDF joints for visualization. " + "If not provided, the script will use the state names in dataset metadata (info.json)" + ), + ) args = parser.parse_args() return vars(args) @@ -375,6 +400,8 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") urdf_package_dir = kwargs.pop("urdf_package_dir") + joint_names = kwargs.pop("joint_names") + kwargs["joint_names"] = None if not joint_names else joint_names.split(",") if urdf_package_dir: os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix()