From a9e8b6581129ffec9b16d2ad0b62db77c4229d27 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sat, 7 Feb 2026 01:11:56 -0800 Subject: [PATCH 1/3] feat: support setting joint names for each state dim in URDF viz --- src/opentau/scripts/visualize_dataset.py | 38 ++++++++++++++++-------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/opentau/scripts/visualize_dataset.py b/src/opentau/scripts/visualize_dataset.py index 3eba00b..fc310ad 100644 --- a/src/opentau/scripts/visualize_dataset.py +++ b/src/opentau/scripts/visualize_dataset.py @@ -172,6 +172,7 @@ def visualize_dataset( save: bool = False, output_dir: Path | None = None, urdf: Path | None = None, + joint_names: list[str] = None, ) -> Path | None: if save: assert output_dir is not None, ( @@ -202,16 +203,15 @@ def visualize_dataset( # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix gc.collect() + urdf_joints = {} + if joint_names is None: + joint_names = dataset.meta.info.get("features", {}).get("observation.state", {}).get("names", []) + if urdf: rr.log_file_from_path(urdf, static=True) urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf) - urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"] - print( - "Assuming the dataset state dimensions correspond to URDF joints in order:\n", - "\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)), - ) - else: - urdf_joints = [] + urdf_joints = {jnt.name: jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"} + print("Assuming joints are ordered as: ", joint_names) if mode == "distant": rr.serve_web_viewer(open_browser=False, web_port=web_port) @@ -236,12 +236,12 @@ def visualize_dataset( # display each dimension of observed state space (e.g. agent position in joint space) if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): - rr.log(f"state/{dim_idx}", _rr_scalar(val.item())) - # Assuming the state dimensions correspond to URDF joints in order. - # TODO(shuheng): allow overriding with a mapping from state dim to joint name. - if dim_idx < len(urdf_joints): - joint = urdf_joints[dim_idx] + states = batch["observation.state"][i] + for dim_idx, val in enumerate(states): + jnt_name = joint_names[dim_idx] if dim_idx < len(joint_names) else str(dim_idx) + rr.log(f"state/{jnt_name}", _rr_scalar(val.item())) + if jnt_name in urdf_joints: + joint = urdf_joints[jnt_name] transform = joint.compute_transform(float(val)) rr.log("URDF", transform) @@ -364,6 +364,16 @@ def parse_args() -> dict: "which will be used if this argument is not provided." ), ) + parser.add_argument( + "--joint-names", + type=str, + default="", + help=( + "If provided, a comma-separated list of joint names for each state dimension, in order. " + "This is used to associate state dimensions with URDF joints for visualization. " + "If not provided, the script will use the state names in dataset metadata (info.json)" + ), + ) args = parser.parse_args() return vars(args) @@ -375,6 +385,8 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") urdf_package_dir = kwargs.pop("urdf_package_dir") + joint_names = kwargs.pop("joint_names") + kwargs["joint_names"] = None if not joint_names else joint_names.split(",") if urdf_package_dir: os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix() From 54f766aedb6e9e8ca06ac38e4e3f90fcbfc0d5ed Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sat, 7 Feb 2026 15:03:02 -0800 Subject: [PATCH 2/3] docs: improve type hinting for viz script Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/opentau/scripts/visualize_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opentau/scripts/visualize_dataset.py b/src/opentau/scripts/visualize_dataset.py index fc310ad..f2843e5 100644 --- a/src/opentau/scripts/visualize_dataset.py +++ b/src/opentau/scripts/visualize_dataset.py @@ -172,7 +172,7 @@ def visualize_dataset( save: bool = False, output_dir: Path | None = None, urdf: Path | None = None, - joint_names: list[str] = None, + joint_names: list[str] | None = None, ) -> Path | None: if save: assert output_dir is not None, ( From 1220d1af03b359fc43c6b90dfcf1e5f02107ddde Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sat, 7 Feb 2026 15:07:12 -0800 Subject: [PATCH 3/3] feat: add google style visualization script --- src/opentau/scripts/visualize_dataset.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/opentau/scripts/visualize_dataset.py b/src/opentau/scripts/visualize_dataset.py index f2843e5..1098890 100644 --- a/src/opentau/scripts/visualize_dataset.py +++ b/src/opentau/scripts/visualize_dataset.py @@ -174,6 +174,21 @@ def visualize_dataset( urdf: Path | None = None, joint_names: list[str] | None = None, ) -> Path | None: + r""" + Visualize data of a given episode of a LeRobotDataset with rerun. + + Args: + dataset: The dataset to visualize. + episode_index: The index of the episode to visualize. + batch_size: Batch size for loading data. Defaults to 32. + num_workers: Number of workers for data loading. Defaults to 0. + mode: Visualization mode, either "local" or "distant". Defaults to "local". + web_port: Web port for rerun when in "distant" mode. Defaults to 9090. + save: Whether to save the visualization as a .rrd file instead of spawning a viewer. Defaults to False. + output_dir: Directory to save the .rrd file if `save` is True. Required if `save` is True. Defaults to None. + urdf: Path to a URDF file to load and visualize alongside the dataset. Defaults to None. + joint_names: List of joint names for each state dimension, in order. Used for associating state dimensions with URDF joints. If not provided, state names from dataset metadata will be used. Defaults to None. + """ if save: assert output_dir is not None, ( "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."