Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions src/opentau/scripts/visualize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,23 @@ def visualize_dataset(
save: bool = False,
output_dir: Path | None = None,
urdf: Path | None = None,
joint_names: list[str] | None = None,
) -> Path | None:
r"""
Visualize data of a given episode of a LeRobotDataset with rerun.

Args:
dataset: The dataset to visualize.
episode_index: The index of the episode to visualize.
batch_size: Batch size for loading data. Defaults to 32.
num_workers: Number of workers for data loading. Defaults to 0.
mode: Visualization mode, either "local" or "distant". Defaults to "local".
web_port: Web port for rerun when in "distant" mode. Defaults to 9090.
save: Whether to save the visualization as a .rrd file instead of spawning a viewer. Defaults to False.
output_dir: Directory to save the .rrd file if `save` is True. Required if `save` is True. Defaults to None.
urdf: Path to a URDF file to load and visualize alongside the dataset. Defaults to None.
joint_names: List of joint names for each state dimension, in order. Used for associating state dimensions with URDF joints. If not provided, state names from dataset metadata will be used. Defaults to None.
"""
if save:
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
Expand Down Expand Up @@ -202,16 +218,15 @@ def visualize_dataset(
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc.collect()

urdf_joints = {}
if joint_names is None:
joint_names = dataset.meta.info.get("features", {}).get("observation.state", {}).get("names", [])

if urdf:
rr.log_file_from_path(urdf, static=True)
urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf)
urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"]
print(
"Assuming the dataset state dimensions correspond to URDF joints in order:\n",
"\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)),
)
else:
urdf_joints = []
urdf_joints = {jnt.name: jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"}
print("Assuming joints are ordered as: ", joint_names)

if mode == "distant":
rr.serve_web_viewer(open_browser=False, web_port=web_port)
Expand All @@ -236,12 +251,12 @@ def visualize_dataset(

# display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", _rr_scalar(val.item()))
# Assuming the state dimensions correspond to URDF joints in order.
# TODO(shuheng): allow overriding with a mapping from state dim to joint name.
if dim_idx < len(urdf_joints):
joint = urdf_joints[dim_idx]
states = batch["observation.state"][i]
for dim_idx, val in enumerate(states):
jnt_name = joint_names[dim_idx] if dim_idx < len(joint_names) else str(dim_idx)
rr.log(f"state/{jnt_name}", _rr_scalar(val.item()))
if jnt_name in urdf_joints:
joint = urdf_joints[jnt_name]
transform = joint.compute_transform(float(val))
rr.log("URDF", transform)

Expand Down Expand Up @@ -364,6 +379,16 @@ def parse_args() -> dict:
"which will be used if this argument is not provided."
),
)
parser.add_argument(
"--joint-names",
type=str,
default="",
help=(
"If provided, a comma-separated list of joint names for each state dimension, in order. "
"This is used to associate state dimensions with URDF joints for visualization. "
"If not provided, the script will use the state names in dataset metadata (info.json)"
),
)

args = parser.parse_args()
return vars(args)
Expand All @@ -375,6 +400,8 @@ def main():
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
urdf_package_dir = kwargs.pop("urdf_package_dir")
joint_names = kwargs.pop("joint_names")
kwargs["joint_names"] = None if not joint_names else joint_names.split(",")
if urdf_package_dir:
os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix()

Expand Down
Loading