diff --git a/README.md b/README.md index d750e9ca..288c3afb 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ The toolbox supports the most comprehensive 6 datasets \& benchmarks and 10+ pop The toolbox supports the most advanced high-quality navigation dataset, InternData-N1, which includes 3k+ scenes and 830k VLN data covering diverse embodiments and scenes, and the first dual-system navigation foundation model with leading performance on all the benchmarks and zero-shot generalization capability in the real world, InternVLA-N1. ## πŸ”₯ News - +- [2025/09] Real-world deployment code of InternVLA-N1 is released. - [2025/07] We are hosting πŸ†IROS 2025 Grand Challenge, stay tuned at [official website](https://internrobotics.shlab.org.cn/challenge/2025/). - [2025/07] InternNav v0.1.1 released. diff --git a/internnav/agent/internvla_n1_agent_realworld.py b/internnav/agent/internvla_n1_agent_realworld.py new file mode 100644 index 00000000..cd123ef7 --- /dev/null +++ b/internnav/agent/internvla_n1_agent_realworld.py @@ -0,0 +1,250 @@ +import copy +import itertools +import os +import re +import sys +import time +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from collections import OrderedDict + +from PIL import Image +from transformers import AutoProcessor + +from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM +from internnav.model.utils.vln_utils import S2Output, split_and_clean, traj_to_actions + +DEFAULT_IMAGE_TOKEN = "" + + +class InternVLAN1AsyncAgent: + def __init__(self, args): + self.device = torch.device(args.device) + self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S") + self.model = InternVLAN1ForCausalLM.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map={"": self.device}, + ) + self.model.eval() + self.model.to(self.device) + + self.processor = AutoProcessor.from_pretrained(args.model_path) + self.processor.tokenizer.padding_side = 'left' + + self.resize_w = args.resize_w + self.resize_h = args.resize_h + self.num_history = args.num_history + + prompt = "You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task." + answer = "" + self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}] + self.conjunctions = [ + 'you can see ', + 'in front of you is ', + 'there is ', + 'you can spot ', + 'you are toward the ', + 'ahead of you is ', + 'in your sight is ', + ] + + self.actions2idx = OrderedDict( + { + 'STOP': [0], + "↑": [1], + "←": [2], + "β†’": [3], + "↓": [5], + } + ) + + self.rgb_list = [] + self.depth_list = [] + self.pose_list = [] + self.episode_idx = 0 + self.conversation_history = [] + self.llm_output = "" + self.past_key_values = None + self.last_s2_idx = -100 + + # output + self.output_action = None + self.output_latent = None + self.output_pixel = None + self.pixel_goal_rgb = None + self.pixel_goal_depth = None + + def reset(self): + self.rgb_list = [] + self.depth_list = [] + self.pose_list = [] + self.episode_idx = 0 + self.conversation_history = [] + self.llm_output = "" + self.past_key_values = None + + self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S") + os.makedirs(self.save_dir, exist_ok=True) + + def parse_actions(self, output): + action_patterns = '|'.join(re.escape(action) for action in self.actions2idx) + regex = re.compile(action_patterns) + matches = regex.findall(output) + actions = [self.actions2idx[match] for match in matches] + actions = itertools.chain.from_iterable(actions) + return list(actions) + + def step_no_infer(self, rgb, depth, pose): + image = Image.fromarray(rgb).convert('RGB') + image = image.resize((self.resize_w, self.resize_h)) + self.rgb_list.append(image) + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}.jpg") + self.episode_idx += 1 + + def trajectory_tovw(self, trajectory, kp=1.0): + subgoal = trajectory[-1] + linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2] + linear_vel = np.clip(linear_vel, 0, 0.5) + angular_vel = np.clip(angular_vel, -0.5, 0.5) + return linear_vel, angular_vel + + def step(self, rgb, depth, pose, instruction, intrinsic, look_down=False): + dual_sys_output = S2Output() + PLAN_STEP_GAP = 8 + no_output_flag = self.output_action is None and self.output_latent is None + if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag: + self.output_action, self.output_latent, self.output_pixel = self.step_s2( + rgb, depth, pose, instruction, intrinsic, look_down + ) + self.last_s2_idx = self.episode_idx + dual_sys_output.output_pixel = self.output_pixel + self.pixel_goal_rgb = copy.deepcopy(rgb) + self.pixel_goal_depth = copy.deepcopy(depth) + else: + self.step_no_infer(rgb, depth, pose) + + if self.output_action is not None: + dual_sys_output.output_action = copy.deepcopy(self.output_action) + self.output_action = None + elif self.output_latent is not None: + processed_pixel_rgb = np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255 + processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224))) + processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255 + processed_depth = np.array(Image.fromarray(depth).resize((224, 224))) + rgbs = ( + torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]) + .unsqueeze(0) + .to(self.device) + ) + depths = ( + torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]) + .unsqueeze(0) + .unsqueeze(-1) + .to(self.device) + ) + trajectories = self.step_s1(self.output_latent, rgbs, depths) + + dual_sys_output.output_action = traj_to_actions(trajectories) + + return dual_sys_output + + def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down=False): + image = Image.fromarray(rgb).convert('RGB') + if not look_down: + image = image.resize((self.resize_w, self.resize_h)) + self.rgb_list.append(image) + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}.jpg") + else: + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}_look_down.jpg") + if not look_down: + self.conversation_history = [] + self.past_key_values = None + + sources = copy.deepcopy(self.conversation) + sources[0]["value"] = sources[0]["value"].replace('.', instruction) + cur_images = self.rgb_list[-1:] + if self.episode_idx == 0: + history_id = [] + else: + history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist() + placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id) + sources[0]["value"] += f' These are your historical observations: {placeholder}.' + + history_id = sorted(history_id) + self.input_images = [self.rgb_list[i] for i in history_id] + cur_images + input_img_id = 0 + self.episode_idx += 1 + else: + self.input_images.append(image) + input_img_id = -1 + assert self.llm_output != "", "Last llm_output should not be empty when look down" + sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}] + self.conversation_history.append( + {'role': 'assistant', 'content': [{'type': 'text', 'text': self.llm_output}]} + ) + + prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN + sources[0]["value"] += f" {prompt}." + prompt_instruction = copy.deepcopy(sources[0]["value"]) + parts = split_and_clean(prompt_instruction) + + content = [] + for i in range(len(parts)): + if parts[i] == "": + content.append({"type": "image", "image": self.input_images[input_img_id]}) + input_img_id += 1 + else: + content.append({"type": "text", "text": parts[i]}) + + self.conversation_history.append({'role': 'user', 'content': content}) + + text = self.processor.apply_chat_template(self.conversation_history, tokenize=False, add_generation_prompt=True) + + inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device) + t0 = time.time() + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=128, + do_sample=False, + use_cache=True, + past_key_values=self.past_key_values, + return_dict_in_generate=True, + raw_input_ids=copy.deepcopy(inputs.input_ids), + ) + output_ids = outputs.sequences + + t1 = time.time() + self.llm_output = self.processor.tokenizer.decode( + output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True + ) + with open(f"{self.save_dir}/llm_output_{self.episode_idx: 04d}.txt", 'w') as f: + f.write(self.llm_output) + self.last_output_ids = copy.deepcopy(output_ids[0]) + self.past_key_values = copy.deepcopy(outputs.past_key_values) + print(f"output {self.episode_idx} {self.llm_output} cost: {t1 - t0}s") + if bool(re.search(r'\d', self.llm_output)): + coord = [int(c) for c in re.findall(r'\d+', self.llm_output)] + pixel_goal = [int(coord[1]), int(coord[0])] + image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0) + pixel_values = inputs.pixel_values + t0 = time.time() + with torch.no_grad(): + traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw) + return None, traj_latents, pixel_goal + + else: + action_seq = self.parse_actions(self.llm_output) + return action_seq, None, None + + def step_s1(self, latent, rgb, depth): + all_trajs = self.model.generate_traj(latent, rgb, depth, use_async=True) + return all_trajs diff --git a/internnav/model/basemodel/internvla_n1/internvla_n1.py b/internnav/model/basemodel/internvla_n1/internvla_n1.py index 41b9b97a..40e01dcd 100644 --- a/internnav/model/basemodel/internvla_n1/internvla_n1.py +++ b/internnav/model/basemodel/internvla_n1/internvla_n1.py @@ -1,20 +1,17 @@ from abc import ABC, abstractmethod -import torch -import torch.nn as nn -from .navdp import NavDP_Policy_DPT_CriticSum_DAT - -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import torch +import torch.nn as nn from transformers import ( - Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel, ) from transformers.modeling_outputs import CausalLMOutputWithPast -import torch.nn as nn -from internnav.model.utils.vln_utils import * +from .navdp import NavDP_Policy_DPT_CriticSum_DAT + def build_navdp(navdp_cfg): navdp_version = getattr(navdp_cfg, "navdp_version", 0.0) @@ -22,45 +19,44 @@ def build_navdp(navdp_cfg): memory_size = 2 else: memory_size = 3 - - navdp = NavDP_Policy_DPT_CriticSum_DAT(memory_size=memory_size, - navdp_pretrained=navdp_cfg.navdp_pretrained, - navdp_version=navdp_version) + + navdp = NavDP_Policy_DPT_CriticSum_DAT( + memory_size=memory_size, navdp_pretrained=navdp_cfg.navdp_pretrained, navdp_version=navdp_version + ) navdp.load_model() return navdp -class InternVLAN1MetaModel: +class InternVLAN1MetaModel: def __init__(self, config): super(InternVLAN1MetaModel, self).__init__(config) if hasattr(config, "navdp"): self.latent_queries = nn.Parameter(torch.randn(1, config.n_query, config.hidden_size)) self.navdp = build_navdp(config) - + def initialize_vision_modules(self, model_args): if getattr(self, 'navdp', None) is None: self.config.navdp = model_args.navdp self.config.navdp_pretrained = model_args.navdp_pretrained self.navdp = build_navdp(model_args) - + self.config.n_query = model_args.n_query if getattr(self, 'latent_queries', None) is None: print("random initiation the latent_queries !!!") self.latent_queries = nn.Parameter(torch.randn(1, self.config.n_query, self.config.hidden_size)) - + class InternVLAN1MetaForCausalLM(ABC): - @abstractmethod def get_model(self): pass - + def get_navdp(self): return self.get_model().navdp - + def get_mm_projector(self): return self.get_model().mm_projector - + def get_n_query(self): return self.get_model().config.n_query @@ -68,11 +64,11 @@ def get_n_query(self): TRAJ_START_TOKEN_INDEX = 151665 IMAGE_TOKEN_INDEX = 151655 TRAJ_TOKEN_INDEX = 151667 - + class InternVLAN1ModelConfig(Qwen2_5_VLConfig): model_type = "internvla_n1" - + def __init__(self, **kwargs): super().__init__(**kwargs) self.model_cfg = kwargs.get('model_cfg', None) @@ -80,6 +76,7 @@ def __init__(self, **kwargs): class InternVLAN1Model(InternVLAN1MetaModel, Qwen2_5_VLModel): config_class = InternVLAN1ModelConfig + def __init__(self, config: Qwen2_5_VLConfig): super(InternVLAN1Model, self).__init__(config) @@ -90,17 +87,58 @@ class InternVLAN1ForCausalLM(Qwen2_5_VLForConditionalGeneration, InternVLAN1Meta def __init__(self, config): Qwen2_5_VLForConditionalGeneration.__init__(self, config) config.model_type == "internvla_n1" - + self.model = InternVLAN1Model(config) - self.rope_deltas = None + self.rope_deltas = None self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() - - + def get_model(self): return self.model + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + # add for QwenVL kv cache + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + + return model_inputs + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -121,6 +159,7 @@ def forward( rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + raw_input_ids: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -169,10 +208,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + if pixel_values is not None and n_image_tokens > 0: pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + image_embeds = image_embeds[-n_image_tokens:] n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: raise ValueError( @@ -206,7 +246,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) n_traj_tokens = (input_ids == TRAJ_TOKEN_INDEX).sum().item() - traj_idx = (input_ids == TRAJ_TOKEN_INDEX) + traj_idx = input_ids == TRAJ_TOKEN_INDEX latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1) H = latent_queries.shape[-1] latent_queries = latent_queries.contiguous().view(-1, H) @@ -232,13 +272,25 @@ def forward( attention_mask, ) self.rope_deltas = rope_deltas + elif n_image_tokens > 0: # using only for kv cache + attention_mask = attention_mask[:, : raw_input_ids.shape[1]] + position_ids, rope_deltas = self.get_rope_index( + raw_input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 + ) + position_ids = position_ids[:, :, -input_ids.shape[1] :] + self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) @@ -246,7 +298,7 @@ def forward( delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - + outputs = self.model( input_ids=None, position_ids=position_ids, @@ -272,40 +324,41 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - + def generate_latents(self, input_ids, pixel_values, image_grid_thw): input_ids.to(self.get_model().device) input_ids = torch.cat([input_ids, torch.tensor([[TRAJ_START_TOKEN_INDEX]]).to(input_ids.device)], dim=1) text_embeds = self.get_model().embed_tokens(input_ids) latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1) - image_idx = (input_ids == IMAGE_TOKEN_INDEX) - N_QUERY = self.get_n_query() + image_idx = input_ids == IMAGE_TOKEN_INDEX + N_QUERY = self.get_n_query() input_ids = torch.cat([input_ids, torch.tensor([[TRAJ_TOKEN_INDEX] * N_QUERY]).to(input_ids.device)], dim=1) pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).unsqueeze(0) - text_embeds[image_idx] = image_embeds.to(text_embeds.device)[:image_idx.sum(), :] + text_embeds[image_idx] = image_embeds.to(text_embeds.device)[: image_idx.sum(), :] text_embeds = torch.cat([text_embeds, latent_queries], dim=1) - position_ids, _ = self.get_rope_index( - input_ids, - image_grid_thw - ) + position_ids, _ = self.get_rope_index(input_ids, image_grid_thw) outputs = self.model( inputs_embeds=text_embeds, - position_ids = position_ids, + position_ids=position_ids, # attention_mask=attention_mask, output_hidden_states=True, return_dict=True, ) - hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:] + hidden_states = outputs.hidden_states[-1][:, -N_QUERY:, :] return hidden_states - + def generate_traj(self, traj_latents, images_dp=None, depths_dp=None, use_async=False): if use_async: - all_trajs = self.model.navdp.predict_pointgoal_action_async(traj_latents.to(self.get_model().device), images_dp, depths_dp, vlm_mask=None) + all_trajs = self.model.navdp.predict_pointgoal_action_async( + traj_latents.to(self.get_model().device), images_dp, depths_dp, vlm_mask=None + ) else: - all_trajs = self.model.navdp.predict_pointgoal_action(traj_latents.to(self.get_model().device), vlm_mask=None) - return all_trajs \ No newline at end of file + all_trajs = self.model.navdp.predict_pointgoal_action( + traj_latents.to(self.get_model().device), vlm_mask=None + ) + return all_trajs diff --git a/scripts/realworld/controllers.py b/scripts/realworld/controllers.py new file mode 100644 index 00000000..c9d495a8 --- /dev/null +++ b/scripts/realworld/controllers.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python + +import math +import os +import sys + +import casadi as ca +import numpy as np + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from scipy.interpolate import interp1d + + +class Mpc_controller: + def __init__(self, global_planed_traj, N=20, desired_v=0.3, v_max=0.4, w_max=0.4, ref_gap=4): + """Initialize the MPC controller. + + Args: + global_planed_traj (np.ndarray): The global planned trajectory, shape (n, 2). + N (int): Prediction horizon. + desired_v (float): Desired linear velocity. + v_max (float): Maximum linear velocity. + w_max (float): Maximum angular velocity. + ref_gap (int): Gap between reference points in the prediction horizon. + """ + self.N, self.desired_v, self.ref_gap, self.T = N, desired_v, ref_gap, 0.1 + self.ref_traj = self.make_ref_denser(global_planed_traj) + self.ref_traj_len = N // ref_gap + 1 + + # setup mpc problem + opti = ca.Opti() + opt_controls = opti.variable(N, 2) + v, w = opt_controls[:, 0], opt_controls[:, 1] + + opt_states = opti.variable(N + 1, 3) + # x, y, theta = opt_states[:, 0], opt_states[:, 1], opt_states[:, 2] + + # parameters + opt_x0 = opti.parameter(3) + opt_xs = opti.parameter(3 * self.ref_traj_len) # the intermidia state may also be the parameter + + # system dynamics for mobile manipulator + f = lambda x_, u_: ca.vertcat(*[u_[0] * ca.cos(x_[2]), u_[0] * ca.sin(x_[2]), u_[1]]) # noqa + + # init_condition + opti.subject_to(opt_states[0, :] == opt_x0.T) + for i in range(N): + x_next = opt_states[i, :] + f(opt_states[i, :], opt_controls[i, :]).T * self.T + opti.subject_to(opt_states[i + 1, :] == x_next) + + # define the cost function + Q = np.diag([10.0, 10.0, 0.0]) + R = np.diag([0.05, 0.2]) + obj = 0 + for i in range(N): + obj = obj + ca.mtimes([opt_controls[i, :], R, opt_controls[i, :].T]) + if i % ref_gap == 0: + nn = i // ref_gap + obj = obj + ca.mtimes( + [ + (opt_states[i, :] - opt_xs[nn * 3 : nn * 3 + 3].T), + Q, + (opt_states[i, :] - opt_xs[nn * 3 : nn * 3 + 3].T).T, + ] + ) + + opti.minimize(obj) + + # boundary and control conditions + opti.subject_to(opti.bounded(0, v, v_max)) + opti.subject_to(opti.bounded(-w_max, w, w_max)) + + opts_setting = { + 'ipopt.max_iter': 100, + 'ipopt.print_level': 0, + 'print_time': 0, + 'ipopt.acceptable_tol': 1e-8, + 'ipopt.acceptable_obj_change_tol': 1e-6, + } + opti.solver('ipopt', opts_setting) + # opts_setting = { 'qpsol':'osqp','hessian_approximation':'limited-memory','max_iter':200,'convexify_strategy':'regularize','beta':0.5,'c1':1e-4,'tol_du':1e-3,'tol_pr':1e-6} + # opti.solver('sqpmethod',opts_setting) + + self.opti = opti + self.opt_xs = opt_xs + self.opt_x0 = opt_x0 + self.opt_controls = opt_controls + self.opt_states = opt_states + self.last_opt_x_states = None + self.last_opt_u_controls = None + + def make_ref_denser(self, ref_traj, ratio=50): + x_orig = np.arange(len(ref_traj)) + new_x = np.linspace(0, len(ref_traj) - 1, num=len(ref_traj) * ratio) + + interp_func_x = interp1d(x_orig, ref_traj[:, 0], kind='linear') + interp_func_y = interp1d(x_orig, ref_traj[:, 1], kind='linear') + + uniform_x = interp_func_x(new_x) + uniform_y = interp_func_y(new_x) + ref_traj = np.stack((uniform_x, uniform_y), axis=1) + + return ref_traj + + def update_ref_traj(self, global_planed_traj): + self.ref_traj = self.make_ref_denser(global_planed_traj) + self.ref_traj_len = self.N // self.ref_gap + 1 + + def solve(self, x0): + ref_traj = self.find_reference_traj(x0, self.ref_traj) + # fake a yaw angle + ref_traj = np.concatenate((ref_traj, np.zeros((ref_traj.shape[0], 1))), axis=1).reshape(-1, 1) + + self.opti.set_value(self.opt_xs, ref_traj.reshape(-1, 1)) + u0 = np.zeros((self.N, 2)) if self.last_opt_u_controls is None else self.last_opt_u_controls + x00 = np.zeros((self.N + 1, 3)) if self.last_opt_x_states is None else self.last_opt_x_states + + self.opti.set_value(self.opt_x0, x0) + self.opti.set_initial(self.opt_controls, u0) + self.opti.set_initial(self.opt_states, x00) + + sol = self.opti.solve() + + self.last_opt_u_controls = sol.value(self.opt_controls) + self.last_opt_x_states = sol.value(self.opt_states) + + return self.last_opt_u_controls, self.last_opt_x_states + + def reset(self): + self.last_opt_x_states = None + self.last_opt_u_controls = None + + def find_reference_traj(self, x0, global_planed_traj): + ref_traj_pts = [] + # find the nearest point in global_planed_traj + nearest_idx = np.argmin(np.linalg.norm(global_planed_traj - x0[:2].reshape((1, 2)), axis=1)) + desire_arc_length = self.desired_v * self.ref_gap * self.T + cum_dist = np.cumsum(np.linalg.norm(np.diff(global_planed_traj, axis=0), axis=1)) + + # select the reference points from the nearest point to the end of global_planed_traj + for i in range(nearest_idx, len(global_planed_traj) - 1): + if cum_dist[i] - cum_dist[nearest_idx] >= desire_arc_length * len(ref_traj_pts): + ref_traj_pts.append(global_planed_traj[i, :]) + if len(ref_traj_pts) == self.ref_traj_len: + break + # if the target is reached before the reference trajectory is complete, add the last point of global_planed_traj + while len(ref_traj_pts) < self.ref_traj_len: + ref_traj_pts.append(global_planed_traj[-1, :]) + return np.array(ref_traj_pts) + + +class PID_controller: + def __init__(self, Kp_trans=1.0, Kd_trans=0.1, Kp_yaw=1.0, Kd_yaw=1.0, max_v=1.0, max_w=1.2): + """Initialize the PID controller. + + Args: + Kp_trans (float): Proportional gain for translational error. + Kd_trans (float): Derivative gain for translational error. + Kp_yaw (float): Proportional gain for yaw error. + Kd_yaw (float): Derivative gain for yaw error. + max_v (float): Maximum linear velocity. + max_w (float): Maximum angular velocity. + """ + self.Kp_trans = Kp_trans + self.Kd_trans = Kd_trans + self.Kp_yaw = Kp_yaw + self.Kd_yaw = Kd_yaw + self.max_v = max_v + self.max_w = max_w + + def solve(self, odom, target, vel=np.zeros(2)): + translation_error, yaw_error = self.calculate_errors(odom, target) + v, w = self.pd_step(translation_error, yaw_error, vel[0], vel[1]) + return v, w, translation_error, yaw_error + + def pd_step(self, translation_error, yaw_error, linear_vel, angular_vel): + translation_error = max(-1.0, min(1.0, translation_error)) + yaw_error = max(-1.0, min(1.0, yaw_error)) + + linear_velocity = self.Kp_trans * translation_error - self.Kd_trans * linear_vel + angular_velocity = self.Kp_yaw * yaw_error - self.Kd_yaw * angular_vel + + linear_velocity = max(-self.max_v, min(self.max_v, linear_velocity)) + angular_velocity = max(-self.max_w, min(self.max_w, angular_velocity)) + + return linear_velocity, angular_velocity + + def calculate_errors(self, odom, target): + + dx = target[0, 3] - odom[0, 3] + dy = target[1, 3] - odom[1, 3] + + odom_yaw = math.atan2(odom[1, 0], odom[0, 0]) + target_yaw = math.atan2(target[1, 0], target[0, 0]) + + translation_error = dx * np.cos(odom_yaw) + dy * np.sin(odom_yaw) + + yaw_error = target_yaw - odom_yaw + yaw_error = (yaw_error + math.pi) % (2 * math.pi) - math.pi + + return translation_error, yaw_error diff --git a/scripts/realworld/http_internvla_client.py b/scripts/realworld/http_internvla_client.py new file mode 100644 index 00000000..f7de8762 --- /dev/null +++ b/scripts/realworld/http_internvla_client.py @@ -0,0 +1,362 @@ +import copy +import io +import json +import math +import threading +import time +from collections import deque +from enum import Enum + +import numpy as np +import rclpy +import requests +from geometry_msgs.msg import Twist +from nav_msgs.msg import Odometry +from PIL import Image as PIL_Image +from sensor_msgs.msg import Image + +frame_data = {} +frame_idx = 0 +# user-specific +from controllers import Mpc_controller, PID_controller +from cv_bridge import CvBridge +from message_filters import ApproximateTimeSynchronizer, Subscriber +from rclpy.node import Node +from rclpy.qos import HistoryPolicy, QoSProfile, ReliabilityPolicy +from thread_utils import ReadWriteLock + + +class ControlMode(Enum): + PID_Mode = 1 + MPC_Mode = 2 + + +# global variable +policy_init = True +mpc = None +pid = PID_controller(Kp_trans=2.0, Kd_trans=0.0, Kp_yaw=1.5, Kd_yaw=0.0, max_v=0.6, max_w=0.5) +http_idx = -1 +first_running_time = 0.0 +last_pixel_goal = None +last_s2_step = -1 +manager = None +current_control_mode = ControlMode.MPC_Mode +trajs_in_world = None + +desired_v, desired_w = 0.0, 0.0 +rgb_depth_rw_lock = ReadWriteLock() +odom_rw_lock = ReadWriteLock() +mpc_rw_lock = ReadWriteLock() + + +def dual_sys_eval(image_bytes, depth_bytes, front_image_bytes, url='http://127.0.0.1:5801/eval_dual'): + global policy_init, http_idx, first_running_time + data = {"reset": policy_init, "idx": http_idx} + json_data = json.dumps(data) + + policy_init = False + files = { + 'image': ('rgb_image', image_bytes, 'image/jpeg'), + 'depth': ('depth_image', depth_bytes, 'image/png'), + } + start = time.time() + response = requests.post(url, files=files, data={'json': json_data}, timeout=100) + print(f"response {response.text}") + http_idx += 1 + if http_idx == 0: + first_running_time = time.time() + print(f"idx: {http_idx} after http {time.time() - start}") + + return json.loads(response.text) + + +def control_thread(): + global desired_v, desired_w + while True: + global current_control_mode + if current_control_mode == ControlMode.MPC_Mode: + odom_rw_lock.acquire_read() + odom = manager.odom.copy() if manager.odom else None + odom_rw_lock.release_read() + if mpc is not None and manager is not None and odom is not None: + local_mpc = mpc + opt_u_controls, opt_x_states = local_mpc.solve(np.array(odom)) + v, w = opt_u_controls[0, 0], opt_u_controls[0, 1] + + desired_v, desired_w = v, w + manager.move(v, 0.0, w) + elif current_control_mode == ControlMode.PID_Mode: + odom_rw_lock.acquire_read() + odom = manager.odom.copy() if manager.odom else None + odom_rw_lock.release_read() + homo_odom = manager.homo_odom.copy() if manager.homo_odom is not None else None + vel = manager.vel.copy() if manager.vel is not None else None + homo_goal = manager.homo_goal.copy() if manager.homo_goal is not None else None + + if homo_odom is not None and vel is not None and homo_goal is not None: + v, w, e_p, e_r = pid.solve(homo_odom, homo_goal, vel) + if v < 0.0: + v = 0.0 + desired_v, desired_w = v, w + manager.move(v, 0.0, w) + + time.sleep(0.1) + + +def planning_thread(): + global trajs_in_world + + while True: + start_time = time.time() + DESIRED_TIME = 0.3 + time.sleep(0.05) + + if not manager.new_image_arrived: + time.sleep(0.01) + continue + manager.new_image_arrived = False + rgb_depth_rw_lock.acquire_read() + rgb_bytes = copy.deepcopy(manager.rgb_bytes) + depth_bytes = copy.deepcopy(manager.depth_bytes) + infer_rgb = copy.deepcopy(manager.rgb_image) + infer_depth = copy.deepcopy(manager.depth_image) + rgb_time = manager.rgb_time + rgb_depth_rw_lock.release_read() + odom_rw_lock.acquire_read() + min_diff = 1e10 + # time_diff = 1e10 + odom_infer = None + for odom in manager.odom_queue: + diff = abs(odom[0] - rgb_time) + if diff < min_diff: + min_diff = diff + odom_infer = copy.deepcopy(odom[1]) + # time_diff = odom[0] - rgb_time + # odom_time = manager.odom_timestamp + odom_rw_lock.release_read() + + if odom_infer is not None and rgb_bytes is not None and depth_bytes is not None: + global frame_data + frame_data[http_idx] = { + 'infer_rgb': copy.deepcopy(infer_rgb), + 'infer_depth': copy.deepcopy(infer_depth), + 'infer_odom': copy.deepcopy(odom_infer), + } + if len(frame_data) > 100: + del frame_data[min(frame_data.keys())] + response = dual_sys_eval(rgb_bytes, depth_bytes, None) + + global current_control_mode + traj_len = 0.0 + if 'trajectory' in response: + trajectory = response['trajectory'] + trajs_in_world = [] + odom = odom_infer + traj_len = np.linalg.norm(trajectory[-1][:2]) + print(f"traj len {traj_len}") + for i, traj in enumerate(trajectory): + if i < 3: + continue + x_, y_, yaw_ = odom[0], odom[1], odom[2] + + w_T_b = np.array( + [ + [np.cos(yaw_), -np.sin(yaw_), 0, x_], + [np.sin(yaw_), np.cos(yaw_), 0, y_], + [0.0, 0.0, 1.0, 0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + w_P = (w_T_b @ (np.array([traj[0], traj[1], 0.0, 1.0])).T)[:2] + trajs_in_world.append(w_P) + trajs_in_world = np.array(trajs_in_world) + print(f"{time.time()} update traj") + + manager.last_trajs_in_world = trajs_in_world + mpc_rw_lock.acquire_write() + global mpc + if mpc is None: + mpc = Mpc_controller(np.array(trajs_in_world)) + else: + mpc.update_ref_traj(np.array(trajs_in_world)) + manager.request_cnt += 1 + mpc_rw_lock.release_write() + current_control_mode = ControlMode.MPC_Mode + elif 'discrete_action' in response: + actions = response['discrete_action'] + if actions != [5] and actions != [9]: + manager.incremental_change_goal(actions) + current_control_mode = ControlMode.PID_Mode + else: + print( + f"skip planning. odom_infer: {odom_infer is not None} rgb_bytes: {rgb_bytes is not None} depth_bytes: {depth_bytes is not None}" + ) + time.sleep(0.1) + + time.sleep(max(0, DESIRED_TIME - (time.time() - start_time))) + + +class Go2Manager(Node): + def __init__(self): + super().__init__('go2_manager') + + rgb_down_sub = Subscriber(self, Image, "/camera/camera/color/image_raw") + depth_down_sub = Subscriber(self, Image, "/camera/camera/aligned_depth_to_color/image_raw") + + qos_profile = QoSProfile(reliability=ReliabilityPolicy.BEST_EFFORT, history=HistoryPolicy.KEEP_LAST, depth=10) + + self.syncronizer = ApproximateTimeSynchronizer([rgb_down_sub, depth_down_sub], 1, 0.1) + self.syncronizer.registerCallback(self.rgb_depth_down_callback) + self.odom_sub = self.create_subscription(Odometry, "/odom_bridge", self.odom_callback, qos_profile) + + # publisher + self.control_pub = self.create_publisher(Twist, '/cmd_vel_bridge', 5) + + # class member variable + self.cv_bridge = CvBridge() + self.rgb_image = None + self.rgb_bytes = None + self.depth_image = None + self.depth_bytes = None + self.rgb_forward_image = None + self.rgb_forward_bytes = None + self.new_image_arrived = False + self.new_vis_image_arrived = False + self.rgb_time = 0.0 + + self.odom = None + self.linear_vel = 0.0 + self.angular_vel = 0.0 + self.request_cnt = 0 + self.odom_cnt = 0 + self.odom_queue = deque(maxlen=50) + self.odom_timestamp = 0.0 + + self.last_s2_step = -1 + self.last_trajs_in_world = None + self.last_all_trajs_in_world = None + self.homo_odom = None + self.homo_goal = None + self.vel = None + + def rgb_forward_callback(self, rgb_msg): + raw_image = self.cv_bridge.imgmsg_to_cv2(rgb_msg, 'rgb8')[:, :, :] + self.rgb_forward_image = raw_image + image = PIL_Image.fromarray(self.rgb_forward_image) + image_bytes = io.BytesIO() + image.save(image_bytes, format='JPEG') + image_bytes.seek(0) + self.rgb_forward_bytes = image_bytes + self.new_vis_image_arrived = True + self.new_image_arrived = True + + def rgb_depth_down_callback(self, rgb_msg, depth_msg): + raw_image = self.cv_bridge.imgmsg_to_cv2(rgb_msg, 'rgb8')[:, :, :] + self.rgb_image = raw_image + image = PIL_Image.fromarray(self.rgb_image) + image_bytes = io.BytesIO() + image.save(image_bytes, format='JPEG') + image_bytes.seek(0) + + raw_depth = self.cv_bridge.imgmsg_to_cv2(depth_msg, '16UC1') + raw_depth[np.isnan(raw_depth)] = 0 + raw_depth[np.isinf(raw_depth)] = 0 + self.depth_image = raw_depth / 1000.0 + self.depth_image -= 0.0 + self.depth_image[np.where(self.depth_image < 0)] = 0 + depth = (np.clip(self.depth_image * 10000.0, 0, 65535)).astype(np.uint16) + depth = PIL_Image.fromarray(depth) + depth_bytes = io.BytesIO() + depth.save(depth_bytes, format='PNG') + depth_bytes.seek(0) + + rgb_depth_rw_lock.acquire_write() + self.rgb_bytes = image_bytes + + self.rgb_time = rgb_msg.header.stamp.sec + rgb_msg.header.stamp.nanosec / 1.0e9 + self.last_rgb_time = self.rgb_time + + self.depth_bytes = depth_bytes + self.depth_time = depth_msg.header.stamp.sec + depth_msg.header.stamp.nanosec / 1.0e9 + self.last_depth_time = self.depth_time + + rgb_depth_rw_lock.release_write() + + self.new_vis_image_arrived = True + self.new_image_arrived = True + + def odom_callback(self, msg): + self.odom_cnt += 1 + odom_rw_lock.acquire_write() + zz = msg.pose.pose.orientation.z + ww = msg.pose.pose.orientation.w + yaw = math.atan2(2 * zz * ww, 1 - 2 * zz * zz) + self.odom = [msg.pose.pose.position.x, msg.pose.pose.position.y, yaw] + self.odom_queue.append((time.time(), copy.deepcopy(self.odom))) + self.odom_timestamp = time.time() + self.linear_vel = msg.twist.twist.linear.x + self.angular_vel = msg.twist.twist.angular.z + odom_rw_lock.release_write() + + R0 = np.array([[np.cos(yaw), -np.sin(yaw)], [np.sin(yaw), np.cos(yaw)]]) + self.homo_odom = np.eye(4) + self.homo_odom[:2, :2] = R0 + self.homo_odom[:2, 3] = [msg.pose.pose.position.x, msg.pose.pose.position.y] + self.vel = [msg.twist.twist.linear.x, msg.twist.twist.angular.z] + + if self.odom_cnt == 1: + self.homo_goal = self.homo_odom.copy() + + def incremental_change_goal(self, actions): + if self.homo_goal is None: + raise ValueError("Please initialize homo_goal before change it!") + homo_goal = self.homo_odom.copy() + for each_action in actions: + if each_action == 0: + pass + elif each_action == 1: + yaw = math.atan2(homo_goal[1, 0], homo_goal[0, 0]) + homo_goal[0, 3] += 0.25 * np.cos(yaw) + homo_goal[1, 3] += 0.25 * np.sin(yaw) + elif each_action == 2: + angle = math.radians(15) + rotation_matrix = np.array( + [[math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], [0, 0, 1]] + ) + homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) + elif each_action == 3: + angle = -math.radians(15.0) + rotation_matrix = np.array( + [[math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], [0, 0, 1]] + ) + homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) + self.homo_goal = homo_goal + + def move(self, vx, vy, vyaw): + request = Twist() + request.linear.x = vx + request.linear.y = 0.0 + request.angular.z = vyaw + + self.control_pub.publish(request) + + +if __name__ == '__main__': + control_thread_instance = threading.Thread(target=control_thread) + planning_thread_instance = threading.Thread(target=planning_thread) + control_thread_instance.daemon = True + planning_thread_instance.daemon = True + rclpy.init() + + try: + manager = Go2Manager() + + control_thread_instance.start() + planning_thread_instance.start() + + rclpy.spin(manager) + except KeyboardInterrupt: + pass + finally: + manager.destroy_node() + rclpy.shutdown() diff --git a/scripts/realworld/http_internvla_server.py b/scripts/realworld/http_internvla_server.py new file mode 100644 index 00000000..c93bbe83 --- /dev/null +++ b/scripts/realworld/http_internvla_server.py @@ -0,0 +1,94 @@ +import argparse +import json +import os +import time +from datetime import datetime + +import numpy as np +from flask import Flask, jsonify, request +from PIL import Image + +from internnav.agent.internvla_n1_agent_realworld import InternVLAN1AsyncAgent + +app = Flask(__name__) +idx = 0 +start_time = time.time() +output_dir = '' + + +@app.route("/eval_dual", methods=['POST']) +def eval_dual(): + global idx, output_dir, start_time + start_time = time.time() + + image_file = request.files['image'] + depth_file = request.files['depth'] + json_data = request.form['json'] + data = json.loads(json_data) + + image = Image.open(image_file.stream) + image = image.convert('RGB') + image = np.asarray(image) + + depth = Image.open(depth_file.stream) + depth = depth.convert('I') + depth = np.asarray(depth) + depth = depth.astype(np.float32) / 10000.0 + print(f"read http data cost {time.time() - start_time}") + + camera_pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + instruction = "Turn around and walk out of this office. Turn towards your slight right at the chair. Move forward to the walkway and go near the red bin. You can see an open door on your right side, go inside the open door. Stop at the computer monitor" + policy_init = data['reset'] + if policy_init: + start_time = time.time() + idx = 0 + output_dir = 'output/runs' + datetime.now().strftime('%m-%d-%H%M') + os.makedirs(output_dir, exist_ok=True) + print("init reset model!!!") + agent.reset() + + idx += 1 + + look_down = False + t0 = time.time() + dual_sys_output = {} + + dual_sys_output = agent.step( + image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down + ) + if dual_sys_output.output_action is not None and dual_sys_output.output_action == [5]: + look_down = True + dual_sys_output = agent.step( + image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down + ) + + json_output = {} + if dual_sys_output.output_action is not None: + json_output['discrete_action'] = dual_sys_output.output_action + if dual_sys_output.output_pixel is not None: + json_output['pixel_goal'] = dual_sys_output.output_pixel + + t1 = time.time() + generate_time = t1 - t0 + print(f"dual sys step {generate_time}") + print(f"json_output {json_output}") + return jsonify(json_output) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--model_path", type=str, default="/path/to/InternVLA-N1") + parser.add_argument("--resize_w", type=int, default=384) + parser.add_argument("--resize_h", type=int, default=384) + parser.add_argument("--num_history", type=int, default=8) + args = parser.parse_args() + + args.camera_intrinsic = np.array( + [[386.5, 0.0, 328.9, 0.0], [0.0, 386.5, 244, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + ) + agent = InternVLAN1AsyncAgent(args) + agent.reset() + + app.run(host='0.0.0.0', port=5801) diff --git a/scripts/realworld/thread_utils.py b/scripts/realworld/thread_utils.py new file mode 100644 index 00000000..2bd02c5a --- /dev/null +++ b/scripts/realworld/thread_utils.py @@ -0,0 +1,28 @@ +import threading + + +class ReadWriteLock: + def __init__(self): + self._read_ready = threading.Condition(threading.Lock()) + self._readers = 0 + + def acquire_read(self): + with self._read_ready: + self._read_ready.wait_for(lambda: self._readers >= 0) + self._readers += 1 + + def release_read(self): + with self._read_ready: + self._readers -= 1 + if self._readers == 0: + self._read_ready.notify_all() + + def acquire_write(self): + with self._read_ready: + self._read_ready.wait_for(lambda: self._readers == 0) + self._readers = -1 + + def release_write(self): + with self._read_ready: + self._readers = 0 + self._read_ready.notify_all() diff --git a/setup.cfg b/setup.cfg index 73140fbd..74e1c305 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ skip_glob = internutopia/*, internutopia_extension/* # than "BA" [codespell] quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,jetbot,wth,coverted,descrete +ignore-words-list = patten,nd,ty,mot,hist,formating,jetbot,wth,coverted,descrete,thw,ro skip = *.js *.txt