From b0de752f086784a904f611ab7e2fa50b8f49ca99 Mon Sep 17 00:00:00 2001 From: yangyuqiang Date: Wed, 24 Sep 2025 21:08:33 +0800 Subject: [PATCH 1/6] [feat] Add real-world InternVLA-N1 server code --- .../agent/internvla_n1_agent_realworld.py | 237 ++++++++++++++++++ scripts/realworld/http_internvla_server.py | 91 +++++++ 2 files changed, 328 insertions(+) create mode 100644 internnav/agent/internvla_n1_agent_realworld.py create mode 100644 scripts/realworld/http_internvla_server.py diff --git a/internnav/agent/internvla_n1_agent_realworld.py b/internnav/agent/internvla_n1_agent_realworld.py new file mode 100644 index 00000000..0d2b801a --- /dev/null +++ b/internnav/agent/internvla_n1_agent_realworld.py @@ -0,0 +1,237 @@ +import copy +import itertools +import os +import re +import time +import torch +import sys +import numpy as np +from datetime import datetime +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from PIL import Image, ImageFile, ImageDraw, ImageFont +from internnav.model.utils.vln_utils import split_and_clean, S2Output, traj_to_actions +from collections import OrderedDict + +from transformers import ( + AutoTokenizer, + AutoProcessor, +) +from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM + + +DEFAULT_IMAGE_TOKEN = "" +class InternVLAN1AsyncAgent: + def __init__(self, args): + self.device = torch.device(args.device) + self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S") + self.model = InternVLAN1ForCausalLM.from_pretrained( + args.model_path, torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", device_map={"": self.device} + ) + self.model.eval() + self.model.to(self.device) + + self.processor = AutoProcessor.from_pretrained(args.model_path) + self.processor.tokenizer.padding_side = 'left' + + self.resize_w = args.resize_w + self.resize_h = args.resize_h + self.num_history = args.num_history + + prompt = f"You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task." 
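Editor's note — for a quick sanity check outside the Flask server added later in this patch, the agent can be constructed directly from an argparse-style namespace. A minimal sketch follows; the checkpoint path is a placeholder, and the field names simply mirror the server's argparse defaults:

    # Minimal standalone construction of the realworld agent (sketch, not part of the patch).
    # model_path is a placeholder; resize_w/resize_h/num_history mirror the server defaults.
    from argparse import Namespace
    from internnav.agent.internvla_n1_agent_realworld import InternVLAN1AsyncAgent

    args = Namespace(
        device="cuda:0",
        model_path="/path/to/InternVLA-N1",   # placeholder checkpoint path
        resize_w=384,
        resize_h=384,
        num_history=8,
    )
    agent = InternVLAN1AsyncAgent(args)
    agent.reset()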
+ answer = "" + self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}] + self.conjunctions = [ + 'you can see ', + 'in front of you is ', + 'there is ', + 'you can spot ', + 'you are toward the ', + 'ahead of you is ', + 'in your sight is ' + ] + + self.actions2idx = OrderedDict({ + 'STOP': [0], + "↑": [1], + "←": [2], + "→": [3], + "↓": [5], + }) + + self.rgb_list = [] + self.depth_list = [] + self.pose_list = [] + self.episode_idx = 0 + self.conversation_history = [] + self.llm_output = "" + self.past_key_values = None + self.last_s2_idx = -100 + + # output + self.output_action = None + self.output_latent = None + self.output_pixel = None + self.pixel_goal_rgb = None + self.pixel_goal_depth = None + + def reset(self): + self.rgb_list = [] + self.depth_list = [] + self.pose_list = [] + self.episode_idx = 0 + self.conversation_history = [] + self.llm_output = "" + self.past_key_values = None + + self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S") + os.makedirs(self.save_dir, exist_ok=True) + def parse_actions(self, output): + action_patterns = '|'.join(re.escape(action) for action in self.actions2idx) + regex = re.compile(action_patterns) + matches = regex.findall(output) + actions = [self.actions2idx[match] for match in matches] + actions = itertools.chain.from_iterable(actions) + return list(actions) + + def step_no_infer(self, rgb, depth, pose): + image = Image.fromarray(rgb).convert('RGB') + raw_image_size = image.size + image = image.resize((self.resize_w, self.resize_h)) + self.rgb_list.append(image) + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg") + self.episode_idx += 1 + + def trajectory_tovw(self, trajectory, kp = 1.0): + subgoal = trajectory[-1] + linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2] + linear_vel = np.clip(linear_vel, 0, 0.5) + angular_vel = np.clip(angular_vel, -0.5, 0.5) + return linear_vel, angular_vel + + def step(self, rgb, depth, pose, instruction, intrinsic, look_down = False): + dual_sys_output = S2Output() + PLAN_STEP_GAP = 8 + no_output_flag = (self.output_action is None and self.output_latent is None) + if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag: + self.output_action, self.output_latent, self.output_pixel = self.step_s2(rgb, depth, pose, instruction, intrinsic, look_down) + self.last_s2_idx = self.episode_idx + dual_sys_output.output_pixel = self.output_pixel + self.pixel_goal_rgb = copy.deepcopy(rgb) + self.pixel_goal_depth = copy.deepcopy(depth) + else: + self.step_no_infer(rgb, depth, pose) + + + if self.output_action is not None: + dual_sys_output.output_action = copy.deepcopy(self.output_action) + self.output_action = None + elif self.output_latent is not None: + processed_pixel_rgb = np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255 + processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224))) + processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255 + processed_depth = np.array(Image.fromarray(depth).resize((224, 224))) + rgbs = torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]).unsqueeze(0).to(self.device) + depths = torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]).unsqueeze(0).unsqueeze(-1).to(self.device) + trajectories = self.step_s1(self.output_latent, rgbs, depths) + + dual_sys_output.output_action = traj_to_actions(trajectories) + + return 
dual_sys_output + + def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False): + image = Image.fromarray(rgb).convert('RGB') + raw_image_size = image.size + if not look_down: + image = image.resize((self.resize_w, self.resize_h)) + self.rgb_list.append(image) + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg") + else: + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}_look_down.jpg") + if not look_down: + self.conversation_history = [] + self.past_key_values = None + + sources = copy.deepcopy(self.conversation) + sources[0]["value"] = sources[0]["value"].replace('.', instruction) + cur_images = self.rgb_list[-1:] + if self.episode_idx == 0: + history_id = [] + else: + history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist() + placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id) + sources[0]["value"] += f' These are your historical observations: {placeholder}.' + + history_id = sorted(history_id) + self.input_images = [self.rgb_list[i] for i in history_id] + cur_images + input_img_id = 0 + self.episode_idx += 1 + else: + self.input_images.append(image) + input_img_id = -1 + assert self.llm_output != "", "Last llm_output should not be empty when look down" + sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}] + self.conversation_history.append({ 'role': 'assistant', 'content': [{ 'type': 'text', 'text': self.llm_output}]}) + + prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN + sources[0]["value"] += f" {prompt}." + prompt_instruction = copy.deepcopy(sources[0]["value"]) + parts = split_and_clean(prompt_instruction) + + content = [] + for i in range (len(parts)): + if parts[i] == "": + content.append({"type": "image", "image": self.input_images[input_img_id]}) + input_img_id +=1 + else: + content.append({"type": "text", "text": parts[i]}) + + self.conversation_history.append({'role': 'user', 'content': content}) + + text = self.processor.apply_chat_template( + self.conversation_history, tokenize=False, add_generation_prompt=True + ) + + inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device) + t0 = time.time() + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=128, + do_sample=False, + # use_cache=True, + # past_key_values=self.past_key_values, + return_dict_in_generate=True, + # raw_input_ids=copy.deepcopy(inputs.input_ids), + ) + output_ids = outputs.sequences + + t1 = time.time() + self.llm_output = self.processor.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + with open(f"{self.save_dir}/llm_output_{self.episode_idx:04d}.txt", 'w') as f: + f.write(self.llm_output) + self.last_output_ids = copy.deepcopy(output_ids[0]) + self.past_key_values = copy.deepcopy(outputs.past_key_values) + print(f"output {self.episode_idx} {self.llm_output} cost:{t1-t0}s") + if bool(re.search(r'\d', self.llm_output)): + coord = [int(c) for c in re.findall(r'\d+', self.llm_output)] + pixel_goal = [int(coord[1]), int(coord[0])] + image_grid_thw = torch.cat( + [thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0 + ) + pixel_values = inputs.pixel_values + t0 = time.time() + with torch.no_grad(): + traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw) + return None, traj_latents, pixel_goal + + else: + action_seq = self.parse_actions(self.llm_output) + return action_seq, None, None + + def step_s1(self, latent, rgb, depth): + all_trajs 
= self.model.generate_traj(latent, rgb, depth, use_async=True) + return all_trajs diff --git a/scripts/realworld/http_internvla_server.py b/scripts/realworld/http_internvla_server.py new file mode 100644 index 00000000..740b6691 --- /dev/null +++ b/scripts/realworld/http_internvla_server.py @@ -0,0 +1,91 @@ +import numpy as np +import argparse +import os +import json +import sys +import time +sys.path.append('/home/pjlab/yq_ws/InternNav') + +from flask import Flask, request, jsonify +from PIL import Image +from datetime import datetime +from internnav.agent.internvla_n1_agent_realworld import InternVLAN1AsyncAgent +from internnav.model.utils.vln_utils import S2Output +app = Flask(__name__) +idx = 0 +start_time = time.time() +output_dir = '' +@app.route("/eval_dual",methods=['POST']) +def eval_dual(): + global idx, output_dir, start_time + start_time = time.time() + + image_file = request.files['image'] + depth_file = request.files['depth'] + json_data = request.form['json'] + data = json.loads(json_data) + + image = Image.open(image_file.stream) + image = image.convert('RGB') + image = np.asarray(image) + + depth = Image.open(depth_file.stream) + depth = depth.convert('I') + depth = np.asarray(depth) + depth = depth.astype(np.float32) / 10000.0 + print(f"read http data cost {time.time() - start_time}") + + camera_pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + instruction = "Turn around and walk out of this office. Turn towards your slight right at the chair. Move forward to the walkway and go near the red bin. You can see an open door on your right side, go inside the open door. Stop at the computer monitor" + policy_init = data['reset'] + if policy_init: + start_time = time.time() + idx = 0 + output_dir = 'output/runs' + datetime.now().strftime('%m-%d-%H%M') + os.makedirs(output_dir, exist_ok=True) + print("init reset model!!!") + agent.reset() + + idx += 1 + + look_down = False + t0 = time.time() + dual_sys_output = {} + + dual_sys_output = agent.step(image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down) + if dual_sys_output.output_action is not None and dual_sys_output.output_action == [5]: + look_down = True + dual_sys_output = agent.step(image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down) + + json_output = {} + if dual_sys_output.output_action is not None: + json_output['discrete_action'] = dual_sys_output.output_action + if dual_sys_output.output_pixel is not None: + json_output['pixel_goal'] = dual_sys_output.output_pixel + + t1 = time.time() + generate_time = t1 - t0 + print(f"dual sys step {generate_time}") + print(f"json_output {json_output}") + return jsonify(json_output) + + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--model_path", type=str, default="/path/to/InternVLA-N1") + parser.add_argument("--resize_w", type=int, default=384) + parser.add_argument("--resize_h", type=int, default=384) + parser.add_argument("--num_history", type=int, default=8) + args = parser.parse_args() + + args.camera_intrinsic = np.array([[386.5 , 0. , 328.9, 0. ], + [ 0. , 386.5 , 244, 0. ], + [ 0. , 0. , 1. , 0. ], + [ 0. , 0. , 0. , 1. 
]]) + agent = InternVLAN1AsyncAgent(args) + agent.reset() + + app.run(host='0.0.0.0', port=5801) \ No newline at end of file From d1cbf714207dcbf2ee78458b8f67628b73b11289 Mon Sep 17 00:00:00 2001 From: yangyuqiang Date: Wed, 24 Sep 2025 21:15:02 +0800 Subject: [PATCH 2/6] [feat] Add kv cache for InternVLA-N1 realworld deployment --- .../agent/internvla_n1_agent_realworld.py | 6 +- .../basemodel/internvla_n1/internvla_n1.py | 65 ++++++++++++++++++- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/internnav/agent/internvla_n1_agent_realworld.py b/internnav/agent/internvla_n1_agent_realworld.py index 0d2b801a..c86940fa 100644 --- a/internnav/agent/internvla_n1_agent_realworld.py +++ b/internnav/agent/internvla_n1_agent_realworld.py @@ -202,10 +202,10 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False): **inputs, max_new_tokens=128, do_sample=False, - # use_cache=True, - # past_key_values=self.past_key_values, + use_cache=True, + past_key_values=self.past_key_values, return_dict_in_generate=True, - # raw_input_ids=copy.deepcopy(inputs.input_ids), + raw_input_ids=copy.deepcopy(inputs.input_ids), ) output_ids = outputs.sequences diff --git a/internnav/model/basemodel/internvla_n1/internvla_n1.py b/internnav/model/basemodel/internvla_n1/internvla_n1.py index 41b9b97a..9e6f2450 100644 --- a/internnav/model/basemodel/internvla_n1/internvla_n1.py +++ b/internnav/model/basemodel/internvla_n1/internvla_n1.py @@ -101,6 +101,49 @@ def __init__(self, config): def get_model(self): return self.model + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + # add for QwenVL kv cache + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + + return model_inputs + + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -121,6 +164,7 @@ def forward( rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + raw_input_ids: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -169,10 +213,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + if pixel_values is not None and n_image_tokens > 0: pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == 
self.config.image_token_id).sum().item() + image_embeds = image_embeds[-n_image_tokens:] n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: raise ValueError( @@ -232,6 +277,22 @@ def forward( attention_mask, ) self.rope_deltas = rope_deltas + elif n_image_tokens > 0: # using only for kv cache + attention_mask = attention_mask[:, :raw_input_ids.shape[1]] + position_ids, rope_deltas = self.get_rope_index( + raw_input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = position_ids[:, :,-input_ids.shape[1]:] + self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape From cb3a53ab721d7c1a1ae3ba158fe606667e055295 Mon Sep 17 00:00:00 2001 From: yangyuqiang Date: Wed, 24 Sep 2025 21:49:13 +0800 Subject: [PATCH 3/6] [feat] 1. Add realworld deployment code on robot. 2. Add mpc and pid controller. 3. InternVLA-N1 client --- scripts/realworld/controllers.py | 166 ++++++++++ scripts/realworld/http_internvla_client.py | 366 +++++++++++++++++++++ scripts/realworld/thread_utils.py | 27 ++ 3 files changed, 559 insertions(+) create mode 100644 scripts/realworld/controllers.py create mode 100644 scripts/realworld/http_internvla_client.py create mode 100644 scripts/realworld/thread_utils.py diff --git a/scripts/realworld/controllers.py b/scripts/realworld/controllers.py new file mode 100644 index 00000000..9c9b443a --- /dev/null +++ b/scripts/realworld/controllers.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# coding=utf-8 + +import casadi as ca +import numpy as np +import time +import math +import os +import sys +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from scipy.interpolate import interp1d + +class Mpc_controller: + def __init__(self, global_planed_traj, N = 20, desired_v = 0.3, v_max = 0.4, w_max = 0.4, ref_gap = 4): + self.N, self.desired_v, self.ref_gap, self.T = N, desired_v, ref_gap, 0.1 + self.ref_traj = self.make_ref_denser(global_planed_traj) + self.ref_traj_len = N // ref_gap + 1 + + # setup mpc problem + opti = ca.Opti() + opt_controls = opti.variable(N, 2) + v, w = opt_controls[:, 0], opt_controls[:, 1] + + opt_states = opti.variable(N+1, 3) + x, y, theta = opt_states[:, 0], opt_states[:, 1], opt_states[:, 2] + + # parameters + opt_x0 = opti.parameter(3) + opt_xs = opti.parameter(3 * self.ref_traj_len) # the intermidia state may also be the parameter + + # system dynamics for mobile manipulator + f = lambda x_, u_: ca.vertcat(*[u_[0]*ca.cos(x_[2]), u_[0]*ca.sin(x_[2]), u_[1]]) + + # init_condition + opti.subject_to(opt_states[0, :] == opt_x0.T) + for i in range(N): + x_next = opt_states[i, :] + f(opt_states[i, :], opt_controls[i, :]).T*self.T + opti.subject_to(opt_states[i+1, :]==x_next) + + # define the cost function + Q = np.diag([10.0,10.0,0.0]) + R = np.diag([0.05,0.2]) + obj = 0 + for i in range(N): + obj = obj +ca.mtimes([opt_controls[i, :], R, opt_controls[i, :].T]) + if i % ref_gap == 0: + nn = i // ref_gap + obj = obj + ca.mtimes([(opt_states[i, :]-opt_xs[nn*3:nn*3+3].T), Q, (opt_states[i, :]-opt_xs[nn*3:nn*3+3].T).T]) + + opti.minimize(obj) + + # boundrary and control conditions + opti.subject_to(opti.bounded(0, v, v_max)) + opti.subject_to(opti.bounded(-w_max, w, w_max)) + + opts_setting = {'ipopt.max_iter':100, 'ipopt.print_level':0, 'print_time':0, 
'ipopt.acceptable_tol':1e-8, 'ipopt.acceptable_obj_change_tol':1e-6} + opti.solver('ipopt', opts_setting) + # opts_setting = { 'qpsol':'osqp','hessian_approximation':'limited-memory','max_iter':200,'convexify_strategy':'regularize','beta':0.5,'c1':1e-4,'tol_du':1e-3,'tol_pr':1e-6} + # opti.solver('sqpmethod',opts_setting) + + self.opti = opti + self.opt_xs = opt_xs + self.opt_x0 = opt_x0 + self.opt_controls = opt_controls + self.opt_states = opt_states + self.last_opt_x_states = None + self.last_opt_u_controls = None + def make_ref_denser(self, ref_traj, ratio = 50): + x_orig = np.arange(len(ref_traj)) + new_x = np.linspace(0, len(ref_traj) - 1, num=len(ref_traj) * ratio) + + interp_func_x = interp1d(x_orig, ref_traj[:, 0], kind='linear') + interp_func_y = interp1d(x_orig, ref_traj[:, 1], kind='linear') + + uniform_x = interp_func_x(new_x) + uniform_y = interp_func_y(new_x) + ref_traj = np.stack((uniform_x, uniform_y), axis=1) + + return ref_traj + + def update_ref_traj(self, global_planed_traj): + self.ref_traj = self.make_ref_denser(global_planed_traj) + self.ref_traj_len = self.N // self.ref_gap + 1 + + def solve(self, x0): + ref_traj = self.find_reference_traj(x0, self.ref_traj) + # fake a yaw angle + ref_traj = np.concatenate((ref_traj, np.zeros((ref_traj.shape[0], 1))), axis=1).reshape(-1, 1) + + self.opti.set_value(self.opt_xs, ref_traj.reshape(-1, 1)) + u0 = np.zeros((self.N, 2)) if self.last_opt_u_controls is None else self.last_opt_u_controls + x00 = np.zeros((self.N+1, 3)) if self.last_opt_x_states is None else self.last_opt_x_states + + self.opti.set_value(self.opt_x0, x0) + self.opti.set_initial(self.opt_controls, u0) + self.opti.set_initial(self.opt_states, x00) + + sol = self.opti.solve() + + self.last_opt_u_controls = sol.value(self.opt_controls) + self.last_opt_x_states = sol.value(self.opt_states) + + return self.last_opt_u_controls, self.last_opt_x_states + def reset(self): + self.last_opt_x_states = None + self.last_opt_u_controls = None + + def find_reference_traj(self, x0, global_planed_traj): + ref_traj_pts = [] + # find the nearest point in global_planed_traj + nearest_idx = np.argmin(np.linalg.norm(global_planed_traj - x0[:2].reshape((1, 2)), axis=1)) + desire_arc_length = self.desired_v * self.ref_gap * self.T + cum_dist = np.cumsum(np.linalg.norm(np.diff(global_planed_traj, axis=0), axis=1)) + + # select the reference points from the nearest point to the end of global_planed_traj + for i in range(nearest_idx, len(global_planed_traj) - 1): + if cum_dist[i] - cum_dist[nearest_idx] >= desire_arc_length * len(ref_traj_pts): + ref_traj_pts.append(global_planed_traj[i, :]) + if len(ref_traj_pts) == self.ref_traj_len: + break + # if the target is reached before the reference trajectory is complete, add the last point of global_planed_traj + while len(ref_traj_pts) < self.ref_traj_len: + ref_traj_pts.append(global_planed_traj[-1, :]) + return np.array(ref_traj_pts) + + +class PID_controller: + def __init__(self, Kp_trans=1.0, Kd_trans=0.1, Kp_yaw=1.0, Kd_yaw=1.0, max_v=1.0, max_w=1.2): + self.Kp_trans = Kp_trans + self.Kd_trans = Kd_trans + self.Kp_yaw = Kp_yaw + self.Kd_yaw = Kd_yaw + self.max_v = max_v + self.max_w = max_w + + def solve(self, odom, target, vel=np.zeros(2)): + translation_error, yaw_error = self.calculate_errors(odom, target) + v, w = self.pd_step(translation_error, yaw_error, vel[0], vel[1]) + return v, w, translation_error, yaw_error + + def pd_step(self, translation_error, yaw_error, linear_vel, angular_vel): + translation_error = max(-1.0, min(1.0, 
translation_error)) + yaw_error = max(-1.0, min(1.0, yaw_error)) + + linear_velocity = self.Kp_trans * translation_error - self.Kd_trans * linear_vel + angular_velocity = self.Kp_yaw * yaw_error - self.Kd_yaw * angular_vel + + linear_velocity = max(-self.max_v, min(self.max_v, linear_velocity)) + angular_velocity = max(-self.max_w, min(self.max_w, angular_velocity)) + + return linear_velocity, angular_velocity + + def calculate_errors(self, odom, target): + + dx = target[0, 3] - odom[0, 3] + dy = target[1, 3] - odom[1, 3] + + odom_yaw = math.atan2(odom[1, 0], odom[0, 0]) + target_yaw = math.atan2(target[1, 0], target[0, 0]) + + translation_error = dx * np.cos(odom_yaw) + dy * np.sin(odom_yaw) + + yaw_error = target_yaw - odom_yaw + yaw_error = (yaw_error + math.pi) % (2 * math.pi) - math.pi + + return translation_error, yaw_error \ No newline at end of file diff --git a/scripts/realworld/http_internvla_client.py b/scripts/realworld/http_internvla_client.py new file mode 100644 index 00000000..2704f198 --- /dev/null +++ b/scripts/realworld/http_internvla_client.py @@ -0,0 +1,366 @@ +import rclpy +import sys +import threading +import io +import json +import copy +import requests +import time +import numpy as np +import math +from enum import Enum +from collections import deque + +from PIL import Image as PIL_Image +from sensor_msgs.msg import Image +from nav_msgs.msg import Odometry, Path +from geometry_msgs.msg import PoseStamped +from geometry_msgs.msg import Twist + +frame_data = {} +frame_idx = 0 +from std_msgs.msg import Bool +from cv_bridge import CvBridge +from rclpy.node import Node +from message_filters import Subscriber, ApproximateTimeSynchronizer +from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy + +# user-specific +from controllers import * +from thread_utils import * + +class ControlMode(Enum): + PID_Mode = 1 + MPC_Mode = 2 + +# global variable +policy_init = True +mpc = None +pid = PID_controller(Kp_trans=2.0, Kd_trans=0.0, Kp_yaw=1.5, Kd_yaw=0.0, max_v=0.6, max_w=0.5) +http_idx = -1 +first_running_time = 0.0 +last_pixel_goal = None +last_s2_step = -1 +manager = None +current_control_mode = ControlMode.MPC_Mode +trajs_in_world = None + +desired_v, desired_w = 0.0, 0.0 +rgb_depth_rw_lock = ReadWriteLock() +odom_rw_lock = ReadWriteLock() +mpc_rw_lock = ReadWriteLock() + + +def dual_sys_eval(image_bytes, depth_bytes, front_image_bytes, url='http://127.0.0.1:5801/eval_dual'): + global policy_init, http_idx, first_running_time + data = {"reset":policy_init, "idx":http_idx} + json_data = json.dumps(data) + + policy_init = False + files = {'image': ('rgb_image', image_bytes, 'image/jpeg'), + 'depth': ('depth_image', depth_bytes, 'image/png'), + } + start = time.time() + response = requests.post(url, files=files, data={'json': json_data}, timeout=100) + print(f"response {response.text}") + http_idx += 1 + if http_idx == 0: + first_running_time = time.time() + print(f"idx:{http_idx} after http {time.time() - start}") + + return json.loads(response.text) + + +def control_thread(): + global desired_v, desired_w + while True: + global current_control_mode + if current_control_mode == ControlMode.MPC_Mode: + odom_rw_lock.acquire_read() + odom = manager.odom.copy() if manager.odom else None + odom_rw_lock.release_read() + if mpc is not None and manager is not None and odom is not None: + local_mpc = mpc + t0 = time.time() + opt_u_controls, opt_x_states = local_mpc.solve(np.array(odom)) + v, w = opt_u_controls[0, 0], opt_u_controls[0, 1] + + desired_v, desired_w = v, w + 
manager.move(v, 0.0, w) + elif current_control_mode == ControlMode.PID_Mode: + odom_rw_lock.acquire_read() + odom = manager.odom.copy() if manager.odom else None + odom_rw_lock.release_read() + homo_odom = manager.homo_odom.copy() if manager.homo_odom is not None else None + vel = manager.vel.copy() if manager.vel is not None else None + homo_goal = manager.homo_goal.copy() if manager.homo_goal is not None else None + + if homo_odom is not None and vel is not None and homo_goal is not None: + v, w, e_p, e_r = pid.solve(homo_odom, homo_goal, vel) + if v < 0.0: + v = 0.0 + desired_v, desired_w = v, w + manager.move(v, 0.0, w) + + time.sleep(0.1) + +def planning_thread(): + global trajs_in_world + + while True: + start_time = time.time() + DISIRED_TIME = 0.3 + time.sleep(0.05) + + if not manager.new_image_arrived: + time.sleep(0.01) + continue + manager.new_image_arrived = False + rgb_depth_rw_lock.acquire_read() + rgb_bytes = copy.deepcopy(manager.rgb_bytes) + depth_bytes = copy.deepcopy(manager.depth_bytes) + infer_rgb = copy.deepcopy(manager.rgb_image) + infer_depth = copy.deepcopy(manager.depth_image) + rgb_time = manager.rgb_time + rgb_depth_rw_lock.release_read() + odom_rw_lock.acquire_read() + min_diff = 1e10 + time_diff = 1e10 + odom_infer = None + for odom in manager.odom_queue: + diff = abs(odom[0] - rgb_time) + if diff < min_diff: + min_diff = diff + odom_infer = copy.deepcopy(odom[1]) + time_diff = odom[0] - rgb_time + odom_time = manager.odom_timestamp + odom_rw_lock.release_read() + + if odom_infer is not None and rgb_bytes is not None and depth_bytes is not None: + global frame_data + frame_data[http_idx] = { + 'infer_rgb': copy.deepcopy(infer_rgb), + 'infer_depth': copy.deepcopy(infer_depth), + 'infer_odom': copy.deepcopy(odom_infer) + } + if len(frame_data) > 100: + del frame_data[min(frame_data.keys())] + response = dual_sys_eval(rgb_bytes, depth_bytes, None) + + + global current_control_mode + traj_len = 0.0 + if 'trajectory' in response: + trajectory = response['trajectory'] + trajs_in_world = [] + odom = odom_infer + traj_len = np.linalg.norm(trajectory[-1][:2]) + print(f"traj len {traj_len}") + for i, traj in enumerate(trajectory): + if i < 3: + continue + x_, y_, yaw_ = odom[0], odom[1], odom[2] + + w_T_b = np.array([[np.cos(yaw_), -np.sin(yaw_), 0, x_], + [np.sin(yaw_), np.cos(yaw_), 0, y_], + [0.0, 0.0, 1.0, 0], + [0.0, 0.0, 0.0, 1.0]]) + w_P = (w_T_b @ (np.array([traj[0], traj[1], 0.0, 1.0])).T)[:2] + trajs_in_world.append(w_P) + trajs_in_world = np.array(trajs_in_world) + print(f"{time.time()} update traj") + + + manager.last_trajs_in_world = trajs_in_world + t0 = time.time() + mpc_rw_lock.acquire_write() + global mpc + if mpc is None: + mpc = Mpc_controller(np.array(trajs_in_world)) + else: + mpc.update_ref_traj(np.array(trajs_in_world)) + manager.request_cnt += 1 + mpc_rw_lock.release_write() + current_control_mode = ControlMode.MPC_Mode + elif 'discrete_action' in response: + actions = response['discrete_action'] + if actions != [5] and actions != [9]: + manager.incremental_change_goal(actions) + current_control_mode = ControlMode.PID_Mode + else: + print(f"skip planning. 
odom_infer:{odom_infer is not None} rgb_bytes:{rgb_bytes is not None} depth_bytes:{depth_bytes is not None}") + time.sleep(0.1) + + time.sleep(max(0, DISIRED_TIME - (time.time() - start_time))) + + +class Go2Manager(Node): + def __init__(self): + super().__init__('go2_manager') + + rgb_down_sub = Subscriber(self, Image, "/camera/camera/color/image_raw") + depth_down_sub = Subscriber(self, Image, "/camera/camera/aligned_depth_to_color/image_raw") + + qos_profile = QoSProfile( + reliability=ReliabilityPolicy.BEST_EFFORT, + history=HistoryPolicy.KEEP_LAST, + depth=10 + ) + + self.syncronizer = ApproximateTimeSynchronizer([rgb_down_sub, depth_down_sub], 1, 0.1) + self.syncronizer.registerCallback(self.rgb_depth_down_callback) + self.odom_sub = self.create_subscription(Odometry, "/odom_bridge", self.odom_callback, qos_profile) + + # publisher + self.control_pub = self.create_publisher(Twist, '/cmd_vel_bridge', 5) + + # class member variable + self.cv_bridge = CvBridge() + self.rgb_image = None + self.rgb_bytes = None + self.depth_image = None + self.depth_bytes = None + self.rgb_forward_image = None + self.rgb_forward_bytes = None + self.new_image_arrived = False + self.new_vis_image_arrived = False + self.rgb_time = 0.0 + + self.odom = None + self.linear_vel = 0.0 + self.angular_vel = 0.0 + self.request_cnt = 0 + self.odom_cnt = 0 + self.odom_queue = deque(maxlen=50) + self.odom_timestamp = 0.0 + + self.last_s2_step = -1 + self.last_trajs_in_world = None + self.last_all_trajs_in_world = None + self.homo_odom = None + self.homo_goal = None + self.vel = None + + def rgb_forward_callback(self, rgb_msg): + raw_image = self.cv_bridge.imgmsg_to_cv2(rgb_msg, 'rgb8')[:, :, :] + self.rgb_forward_image = raw_image + image = PIL_Image.fromarray(self.rgb_forward_image) + image_bytes = io.BytesIO() + image.save(image_bytes, format='JPEG') + image_bytes.seek(0) + self.rgb_forward_bytes = image_bytes + self.new_vis_image_arrived = True + self.new_image_arrived = True + def rgb_depth_down_callback(self, rgb_msg, depth_msg): + t0 = time.time() + raw_image = self.cv_bridge.imgmsg_to_cv2(rgb_msg, 'rgb8')[:, :, :] + self.rgb_image = raw_image + image = PIL_Image.fromarray(self.rgb_image) + image_bytes = io.BytesIO() + image.save(image_bytes, format='JPEG') + image_bytes.seek(0) + + raw_depth = self.cv_bridge.imgmsg_to_cv2(depth_msg, '16UC1') + raw_depth[np.isnan(raw_depth)] = 0 + raw_depth[np.isinf(raw_depth)] = 0 + self.depth_image = raw_depth / 1000.0 + self.depth_image -= 0.0 + self.depth_image[np.where(self.depth_image < 0)] = 0 + depth = (np.clip(self.depth_image * 10000.0, 0, 65535)).astype(np.uint16) + depth = PIL_Image.fromarray(depth) + depth_bytes = io.BytesIO() + depth.save(depth_bytes, format='PNG') + depth_bytes.seek(0) + + rgb_depth_rw_lock.acquire_write() + self.rgb_bytes = image_bytes + + self.rgb_time = rgb_msg.header.stamp.sec + rgb_msg.header.stamp.nanosec / 1.0e9 + self.last_rgb_time = self.rgb_time + + self.depth_bytes = depth_bytes + self.depth_time = depth_msg.header.stamp.sec + depth_msg.header.stamp.nanosec / 1.0e9 + self.last_depth_time = self.depth_time + + rgb_depth_rw_lock.release_write() + + self.new_vis_image_arrived = True + self.new_image_arrived = True + def odom_callback(self, msg): + self.odom_cnt += 1 + odom_rw_lock.acquire_write() + zz = msg.pose.pose.orientation.z + ww = msg.pose.pose.orientation.w + yaw = math.atan2(2 * zz * ww, 1 - 2 * zz * zz) + self.odom = [msg.pose.pose.position.x, msg.pose.pose.position.y, yaw] + self.odom_queue.append((time.time(), 
copy.deepcopy(self.odom))) + self.odom_timestamp = time.time() + self.linear_vel = msg.twist.twist.linear.x + self.angular_vel = msg.twist.twist.angular.z + odom_rw_lock.release_write() + + R0 = np.array([[np.cos(yaw), -np.sin(yaw)], + [np.sin(yaw), np.cos(yaw)]]) + self.homo_odom = np.eye(4) + self.homo_odom[:2, :2] = R0 + self.homo_odom[:2, 3] = [msg.pose.pose.position.x, msg.pose.pose.position.y] + self.vel = [msg.twist.twist.linear.x, msg.twist.twist.angular.z] + + if self.odom_cnt == 1: + self.homo_goal = self.homo_odom.copy() + + def incremental_change_goal(self, actions): + if self.homo_goal is None: + raise ValueError("Please initialize homo_goal before change it!") + homo_goal = self.homo_odom.copy() + for each_action in actions: + if each_action == 0: + pass + elif each_action == 1: + yaw = math.atan2(homo_goal[1, 0], homo_goal[0, 0]) + homo_goal[0, 3] += 0.25 * np.cos(yaw) + homo_goal[1, 3] += 0.25 * np.sin(yaw) + elif each_action == 2: + angle = math.radians(15) + rotation_matrix = np.array([ + [math.cos(angle), -math.sin(angle), 0], + [math.sin(angle), math.cos(angle), 0], + [0, 0, 1] + ]) + homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) + elif each_action == 3: + angle = -math.radians(15.0) + rotation_matrix = np.array([ + [math.cos(angle), -math.sin(angle), 0], + [math.sin(angle), math.cos(angle), 0], + [0, 0, 1] + ]) + homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) + self.homo_goal = homo_goal + def move(self, vx, vy, vyaw): + request = Twist() + request.linear.x = vx + request.linear.y = 0.0 + request.angular.z = vyaw + + self.control_pub.publish(request) + + +if __name__ == '__main__': + control_thread_instance = threading.Thread(target=control_thread) + planning_thread_instance = threading.Thread(target=planning_thread) + + rclpy.init() + + try: + manager = Go2Manager() + + control_thread_instance.start() + planning_thread_instance.start() + + rclpy.spin(manager) + except KeyboardInterrupt: + pass + finally: + manager.destroy_node() + rclpy.shutdown() diff --git a/scripts/realworld/thread_utils.py b/scripts/realworld/thread_utils.py new file mode 100644 index 00000000..67812aa3 --- /dev/null +++ b/scripts/realworld/thread_utils.py @@ -0,0 +1,27 @@ +import threading + +class ReadWriteLock: + def __init__(self): + self._read_ready = threading.Condition(threading.Lock()) + self._readers = 0 + + def acquire_read(self): + with self._read_ready: + self._read_ready.wait_for(lambda: self._readers >= 0) + self._readers += 1 + + def release_read(self): + with self._read_ready: + self._readers -= 1 + if self._readers == 0: + self._read_ready.notify_all() + + def acquire_write(self): + with self._read_ready: + self._read_ready.wait_for(lambda: self._readers == 0) + self._readers = -1 + + def release_write(self): + with self._read_ready: + self._readers = 0 + self._read_ready.notify_all() \ No newline at end of file From 73ee99d5f34803aa164afa324891493a1f0f872c Mon Sep 17 00:00:00 2001 From: yangyuqiang Date: Fri, 26 Sep 2025 13:33:34 +0800 Subject: [PATCH 4/6] [fix] precommit fix --- .../agent/internvla_n1_agent_realworld.py | 185 ++++++++++-------- .../basemodel/internvla_n1/internvla_n1.py | 98 +++++----- scripts/realworld/controllers.py | 101 ++++++---- scripts/realworld/http_internvla_client.py | 167 ++++++++-------- scripts/realworld/http_internvla_server.py | 54 ++--- scripts/realworld/thread_utils.py | 3 +- setup.cfg | 2 +- 7 files changed, 317 insertions(+), 293 deletions(-) diff --git a/internnav/agent/internvla_n1_agent_realworld.py 
b/internnav/agent/internvla_n1_agent_realworld.py index c86940fa..cd123ef7 100644 --- a/internnav/agent/internvla_n1_agent_realworld.py +++ b/internnav/agent/internvla_n1_agent_realworld.py @@ -2,82 +2,86 @@ import itertools import os import re -import time -import torch import sys -import numpy as np +import time from datetime import datetime from pathlib import Path +import numpy as np +import torch + sys.path.append(str(Path(__file__).parent.parent.parent)) -from PIL import Image, ImageFile, ImageDraw, ImageFont -from internnav.model.utils.vln_utils import split_and_clean, S2Output, traj_to_actions from collections import OrderedDict -from transformers import ( - AutoTokenizer, - AutoProcessor, -) -from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM +from PIL import Image +from transformers import AutoProcessor +from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM +from internnav.model.utils.vln_utils import S2Output, split_and_clean, traj_to_actions DEFAULT_IMAGE_TOKEN = "" + + class InternVLAN1AsyncAgent: def __init__(self, args): self.device = torch.device(args.device) self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S") self.model = InternVLAN1ForCausalLM.from_pretrained( - args.model_path, torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", device_map={"": self.device} + args.model_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map={"": self.device}, ) self.model.eval() self.model.to(self.device) - + self.processor = AutoProcessor.from_pretrained(args.model_path) self.processor.tokenizer.padding_side = 'left' - + self.resize_w = args.resize_w self.resize_h = args.resize_h self.num_history = args.num_history - - prompt = f"You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task." + + prompt = "You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task." 
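Editor's note — on the discrete branch of System-2: when the decoded output contains arrow/STOP tokens rather than coordinates, `parse_actions` converts it into action indices via the `actions2idx` table defined just below. A self-contained illustration (the sample output string is invented):

    # Sketch of parse_actions(): map a decoded S2 string to action indices.
    # The mapping mirrors actions2idx; the example string is hypothetical.
    import itertools
    import re
    from collections import OrderedDict

    actions2idx = OrderedDict({'STOP': [0], "↑": [1], "←": [2], "→": [3], "↓": [5]})
    pattern = re.compile('|'.join(re.escape(a) for a in actions2idx))

    llm_output = "↑↑→STOP"                               # hypothetical decoded output
    indices = [actions2idx[m] for m in pattern.findall(llm_output)]
    print(list(itertools.chain.from_iterable(indices)))  # [1, 1, 3, 0]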
answer = "" self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}] self.conjunctions = [ - 'you can see ', - 'in front of you is ', - 'there is ', - 'you can spot ', - 'you are toward the ', - 'ahead of you is ', - 'in your sight is ' - ] - - self.actions2idx = OrderedDict({ - 'STOP': [0], - "↑": [1], - "←": [2], - "→": [3], - "↓": [5], - }) - + 'you can see ', + 'in front of you is ', + 'there is ', + 'you can spot ', + 'you are toward the ', + 'ahead of you is ', + 'in your sight is ', + ] + + self.actions2idx = OrderedDict( + { + 'STOP': [0], + "↑": [1], + "←": [2], + "→": [3], + "↓": [5], + } + ) + self.rgb_list = [] self.depth_list = [] self.pose_list = [] - self.episode_idx = 0 + self.episode_idx = 0 self.conversation_history = [] self.llm_output = "" self.past_key_values = None self.last_s2_idx = -100 - + # output self.output_action = None self.output_latent = None self.output_pixel = None self.pixel_goal_rgb = None self.pixel_goal_depth = None - + def reset(self): self.rgb_list = [] self.depth_list = [] @@ -86,9 +90,10 @@ def reset(self): self.conversation_history = [] self.llm_output = "" self.past_key_values = None - + self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S") os.makedirs(self.save_dir, exist_ok=True) + def parse_actions(self, output): action_patterns = '|'.join(re.escape(action) for action in self.actions2idx) regex = re.compile(action_patterns) @@ -96,65 +101,73 @@ def parse_actions(self, output): actions = [self.actions2idx[match] for match in matches] actions = itertools.chain.from_iterable(actions) return list(actions) - + def step_no_infer(self, rgb, depth, pose): image = Image.fromarray(rgb).convert('RGB') - raw_image_size = image.size image = image.resize((self.resize_w, self.resize_h)) self.rgb_list.append(image) - image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg") + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}.jpg") self.episode_idx += 1 - - def trajectory_tovw(self, trajectory, kp = 1.0): + + def trajectory_tovw(self, trajectory, kp=1.0): subgoal = trajectory[-1] linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2] linear_vel = np.clip(linear_vel, 0, 0.5) angular_vel = np.clip(angular_vel, -0.5, 0.5) return linear_vel, angular_vel - def step(self, rgb, depth, pose, instruction, intrinsic, look_down = False): + def step(self, rgb, depth, pose, instruction, intrinsic, look_down=False): dual_sys_output = S2Output() PLAN_STEP_GAP = 8 - no_output_flag = (self.output_action is None and self.output_latent is None) + no_output_flag = self.output_action is None and self.output_latent is None if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag: - self.output_action, self.output_latent, self.output_pixel = self.step_s2(rgb, depth, pose, instruction, intrinsic, look_down) + self.output_action, self.output_latent, self.output_pixel = self.step_s2( + rgb, depth, pose, instruction, intrinsic, look_down + ) self.last_s2_idx = self.episode_idx dual_sys_output.output_pixel = self.output_pixel self.pixel_goal_rgb = copy.deepcopy(rgb) self.pixel_goal_depth = copy.deepcopy(depth) else: self.step_no_infer(rgb, depth, pose) - - + if self.output_action is not None: dual_sys_output.output_action = copy.deepcopy(self.output_action) - self.output_action = None - elif self.output_latent is not None: + self.output_action = None + elif self.output_latent is not None: processed_pixel_rgb = 
np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255 processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224))) processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255 processed_depth = np.array(Image.fromarray(depth).resize((224, 224))) - rgbs = torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]).unsqueeze(0).to(self.device) - depths = torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]).unsqueeze(0).unsqueeze(-1).to(self.device) + rgbs = ( + torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]) + .unsqueeze(0) + .to(self.device) + ) + depths = ( + torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]) + .unsqueeze(0) + .unsqueeze(-1) + .to(self.device) + ) trajectories = self.step_s1(self.output_latent, rgbs, depths) - + dual_sys_output.output_action = traj_to_actions(trajectories) return dual_sys_output - - def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False): + + def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down=False): image = Image.fromarray(rgb).convert('RGB') - raw_image_size = image.size if not look_down: image = image.resize((self.resize_w, self.resize_h)) self.rgb_list.append(image) - image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg") + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}.jpg") else: - image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}_look_down.jpg") + image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}_look_down.jpg") if not look_down: - self.conversation_history = [] + self.conversation_history = [] self.past_key_values = None - + sources = copy.deepcopy(self.conversation) sources[0]["value"] = sources[0]["value"].replace('.', instruction) cur_images = self.rgb_list[-1:] @@ -164,7 +177,7 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False): history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist() placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id) sources[0]["value"] += f' These are your historical observations: {placeholder}.' - + history_id = sorted(history_id) self.input_images = [self.rgb_list[i] for i in history_id] + cur_images input_img_id = 0 @@ -174,33 +187,33 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False): input_img_id = -1 assert self.llm_output != "", "Last llm_output should not be empty when look down" sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}] - self.conversation_history.append({ 'role': 'assistant', 'content': [{ 'type': 'text', 'text': self.llm_output}]}) - + self.conversation_history.append( + {'role': 'assistant', 'content': [{'type': 'text', 'text': self.llm_output}]} + ) + prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN sources[0]["value"] += f" {prompt}." 
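Editor's note — a few lines further down in `step_s2`, the generated text is parsed either into a pixel goal (when it contains digits) or into discrete actions. The pixel-goal branch takes the first two integers in the output and swaps their order, presumably turning an (x, y) waypoint into (row, col). A tiny sketch with an invented output string:

    # Sketch of the pixel-goal parsing applied to the decoded output.
    # The output string is invented; the swap mirrors pixel_goal = [coord[1], coord[0]].
    import re

    llm_output = "(245, 118)"                  # hypothetical waypoint prediction
    coord = [int(c) for c in re.findall(r'\d+', llm_output)]
    pixel_goal = [coord[1], coord[0]]          # -> [118, 245]
    print(pixel_goal)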
prompt_instruction = copy.deepcopy(sources[0]["value"]) parts = split_and_clean(prompt_instruction) - + content = [] - for i in range (len(parts)): + for i in range(len(parts)): if parts[i] == "": content.append({"type": "image", "image": self.input_images[input_img_id]}) - input_img_id +=1 + input_img_id += 1 else: - content.append({"type": "text", "text": parts[i]}) - + content.append({"type": "text", "text": parts[i]}) + self.conversation_history.append({'role': 'user', 'content': content}) - - text = self.processor.apply_chat_template( - self.conversation_history, tokenize=False, add_generation_prompt=True - ) - + + text = self.processor.apply_chat_template(self.conversation_history, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device) t0 = time.time() with torch.no_grad(): outputs = self.model.generate( - **inputs, - max_new_tokens=128, + **inputs, + max_new_tokens=128, do_sample=False, use_cache=True, past_key_values=self.past_key_values, @@ -208,30 +221,30 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False): raw_input_ids=copy.deepcopy(inputs.input_ids), ) output_ids = outputs.sequences - + t1 = time.time() - self.llm_output = self.processor.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) - with open(f"{self.save_dir}/llm_output_{self.episode_idx:04d}.txt", 'w') as f: + self.llm_output = self.processor.tokenizer.decode( + output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True + ) + with open(f"{self.save_dir}/llm_output_{self.episode_idx: 04d}.txt", 'w') as f: f.write(self.llm_output) self.last_output_ids = copy.deepcopy(output_ids[0]) self.past_key_values = copy.deepcopy(outputs.past_key_values) - print(f"output {self.episode_idx} {self.llm_output} cost:{t1-t0}s") - if bool(re.search(r'\d', self.llm_output)): + print(f"output {self.episode_idx} {self.llm_output} cost: {t1 - t0}s") + if bool(re.search(r'\d', self.llm_output)): coord = [int(c) for c in re.findall(r'\d+', self.llm_output)] pixel_goal = [int(coord[1]), int(coord[0])] - image_grid_thw = torch.cat( - [thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0 - ) + image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0) pixel_values = inputs.pixel_values t0 = time.time() with torch.no_grad(): traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw) return None, traj_latents, pixel_goal - + else: action_seq = self.parse_actions(self.llm_output) - return action_seq, None, None - - def step_s1(self, latent, rgb, depth): + return action_seq, None, None + + def step_s1(self, latent, rgb, depth): all_trajs = self.model.generate_traj(latent, rgb, depth, use_async=True) return all_trajs diff --git a/internnav/model/basemodel/internvla_n1/internvla_n1.py b/internnav/model/basemodel/internvla_n1/internvla_n1.py index 9e6f2450..40e01dcd 100644 --- a/internnav/model/basemodel/internvla_n1/internvla_n1.py +++ b/internnav/model/basemodel/internvla_n1/internvla_n1.py @@ -1,20 +1,17 @@ from abc import ABC, abstractmethod -import torch -import torch.nn as nn -from .navdp import NavDP_Policy_DPT_CriticSum_DAT - -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import torch +import torch.nn as nn from transformers import ( - Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel, ) from 
transformers.modeling_outputs import CausalLMOutputWithPast -import torch.nn as nn -from internnav.model.utils.vln_utils import * +from .navdp import NavDP_Policy_DPT_CriticSum_DAT + def build_navdp(navdp_cfg): navdp_version = getattr(navdp_cfg, "navdp_version", 0.0) @@ -22,45 +19,44 @@ def build_navdp(navdp_cfg): memory_size = 2 else: memory_size = 3 - - navdp = NavDP_Policy_DPT_CriticSum_DAT(memory_size=memory_size, - navdp_pretrained=navdp_cfg.navdp_pretrained, - navdp_version=navdp_version) + + navdp = NavDP_Policy_DPT_CriticSum_DAT( + memory_size=memory_size, navdp_pretrained=navdp_cfg.navdp_pretrained, navdp_version=navdp_version + ) navdp.load_model() return navdp -class InternVLAN1MetaModel: +class InternVLAN1MetaModel: def __init__(self, config): super(InternVLAN1MetaModel, self).__init__(config) if hasattr(config, "navdp"): self.latent_queries = nn.Parameter(torch.randn(1, config.n_query, config.hidden_size)) self.navdp = build_navdp(config) - + def initialize_vision_modules(self, model_args): if getattr(self, 'navdp', None) is None: self.config.navdp = model_args.navdp self.config.navdp_pretrained = model_args.navdp_pretrained self.navdp = build_navdp(model_args) - + self.config.n_query = model_args.n_query if getattr(self, 'latent_queries', None) is None: print("random initiation the latent_queries !!!") self.latent_queries = nn.Parameter(torch.randn(1, self.config.n_query, self.config.hidden_size)) - + class InternVLAN1MetaForCausalLM(ABC): - @abstractmethod def get_model(self): pass - + def get_navdp(self): return self.get_model().navdp - + def get_mm_projector(self): return self.get_model().mm_projector - + def get_n_query(self): return self.get_model().config.n_query @@ -68,11 +64,11 @@ def get_n_query(self): TRAJ_START_TOKEN_INDEX = 151665 IMAGE_TOKEN_INDEX = 151655 TRAJ_TOKEN_INDEX = 151667 - + class InternVLAN1ModelConfig(Qwen2_5_VLConfig): model_type = "internvla_n1" - + def __init__(self, **kwargs): super().__init__(**kwargs) self.model_cfg = kwargs.get('model_cfg', None) @@ -80,6 +76,7 @@ def __init__(self, **kwargs): class InternVLAN1Model(InternVLAN1MetaModel, Qwen2_5_VLModel): config_class = InternVLAN1ModelConfig + def __init__(self, config: Qwen2_5_VLConfig): super(InternVLAN1Model, self).__init__(config) @@ -90,14 +87,13 @@ class InternVLAN1ForCausalLM(Qwen2_5_VLForConditionalGeneration, InternVLAN1Meta def __init__(self, config): Qwen2_5_VLForConditionalGeneration.__init__(self, config) config.model_type == "internvla_n1" - + self.model = InternVLAN1Model(config) - self.rope_deltas = None + self.rope_deltas = None self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() - - + def get_model(self): return self.model @@ -142,8 +138,7 @@ def prepare_inputs_for_generation( model_inputs["pixel_values_videos"] = pixel_values_videos return model_inputs - - + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -251,7 +246,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) n_traj_tokens = (input_ids == TRAJ_TOKEN_INDEX).sum().item() - traj_idx = (input_ids == TRAJ_TOKEN_INDEX) + traj_idx = input_ids == TRAJ_TOKEN_INDEX latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1) H = latent_queries.shape[-1] latent_queries = latent_queries.contiguous().view(-1, H) @@ -277,8 +272,8 @@ def forward( attention_mask, ) self.rope_deltas = rope_deltas - elif n_image_tokens > 0: # using only for kv cache - attention_mask = 
attention_mask[:, :raw_input_ids.shape[1]] + elif n_image_tokens > 0: # using only for kv cache + attention_mask = attention_mask[:, : raw_input_ids.shape[1]] position_ids, rope_deltas = self.get_rope_index( raw_input_ids, image_grid_thw, @@ -287,19 +282,15 @@ def forward( attention_mask, ) delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 ) - position_ids = position_ids[:, :,-input_ids.shape[1]:] + position_ids = position_ids[:, :, -input_ids.shape[1] :] self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) @@ -307,7 +298,7 @@ def forward( delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - + outputs = self.model( input_ids=None, position_ids=position_ids, @@ -333,40 +324,41 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - + def generate_latents(self, input_ids, pixel_values, image_grid_thw): input_ids.to(self.get_model().device) input_ids = torch.cat([input_ids, torch.tensor([[TRAJ_START_TOKEN_INDEX]]).to(input_ids.device)], dim=1) text_embeds = self.get_model().embed_tokens(input_ids) latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1) - image_idx = (input_ids == IMAGE_TOKEN_INDEX) - N_QUERY = self.get_n_query() + image_idx = input_ids == IMAGE_TOKEN_INDEX + N_QUERY = self.get_n_query() input_ids = torch.cat([input_ids, torch.tensor([[TRAJ_TOKEN_INDEX] * N_QUERY]).to(input_ids.device)], dim=1) pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).unsqueeze(0) - text_embeds[image_idx] = image_embeds.to(text_embeds.device)[:image_idx.sum(), :] + text_embeds[image_idx] = image_embeds.to(text_embeds.device)[: image_idx.sum(), :] text_embeds = torch.cat([text_embeds, latent_queries], dim=1) - position_ids, _ = self.get_rope_index( - input_ids, - image_grid_thw - ) + position_ids, _ = self.get_rope_index(input_ids, image_grid_thw) outputs = self.model( inputs_embeds=text_embeds, - position_ids = position_ids, + position_ids=position_ids, # attention_mask=attention_mask, output_hidden_states=True, return_dict=True, ) - hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:] + hidden_states = outputs.hidden_states[-1][:, -N_QUERY:, :] return hidden_states - + def generate_traj(self, traj_latents, images_dp=None, depths_dp=None, use_async=False): if use_async: - all_trajs = self.model.navdp.predict_pointgoal_action_async(traj_latents.to(self.get_model().device), images_dp, depths_dp, vlm_mask=None) + all_trajs = self.model.navdp.predict_pointgoal_action_async( + traj_latents.to(self.get_model().device), images_dp, depths_dp, vlm_mask=None + ) else: - all_trajs = self.model.navdp.predict_pointgoal_action(traj_latents.to(self.get_model().device), vlm_mask=None) - return all_trajs \ No newline at end of file + all_trajs = 
self.model.navdp.predict_pointgoal_action( + traj_latents.to(self.get_model().device), vlm_mask=None + ) + return all_trajs diff --git a/scripts/realworld/controllers.py b/scripts/realworld/controllers.py index 9c9b443a..1cbe2947 100644 --- a/scripts/realworld/controllers.py +++ b/scripts/realworld/controllers.py @@ -1,17 +1,18 @@ #!/usr/bin/env python -# coding=utf-8 -import casadi as ca -import numpy as np -import time import math import os import sys + +import casadi as ca +import numpy as np + sys.path.append(os.path.dirname(os.path.abspath(__file__))) from scipy.interpolate import interp1d + class Mpc_controller: - def __init__(self, global_planed_traj, N = 20, desired_v = 0.3, v_max = 0.4, w_max = 0.4, ref_gap = 4): + def __init__(self, global_planed_traj, N=20, desired_v=0.3, v_max=0.4, w_max=0.4, ref_gap=4): self.N, self.desired_v, self.ref_gap, self.T = N, desired_v, ref_gap, 0.1 self.ref_traj = self.make_ref_denser(global_planed_traj) self.ref_traj_len = N // ref_gap + 1 @@ -21,43 +22,55 @@ def __init__(self, global_planed_traj, N = 20, desired_v = 0.3, v_max = 0.4, w_m opt_controls = opti.variable(N, 2) v, w = opt_controls[:, 0], opt_controls[:, 1] - opt_states = opti.variable(N+1, 3) - x, y, theta = opt_states[:, 0], opt_states[:, 1], opt_states[:, 2] + opt_states = opti.variable(N + 1, 3) + # x, y, theta = opt_states[:, 0], opt_states[:, 1], opt_states[:, 2] - # parameters + # parameters opt_x0 = opti.parameter(3) - opt_xs = opti.parameter(3 * self.ref_traj_len) # the intermidia state may also be the parameter + opt_xs = opti.parameter(3 * self.ref_traj_len) # the intermidia state may also be the parameter # system dynamics for mobile manipulator - f = lambda x_, u_: ca.vertcat(*[u_[0]*ca.cos(x_[2]), u_[0]*ca.sin(x_[2]), u_[1]]) + f = lambda x_, u_: ca.vertcat(*[u_[0] * ca.cos(x_[2]), u_[0] * ca.sin(x_[2]), u_[1]]) # noqa # init_condition opti.subject_to(opt_states[0, :] == opt_x0.T) for i in range(N): - x_next = opt_states[i, :] + f(opt_states[i, :], opt_controls[i, :]).T*self.T - opti.subject_to(opt_states[i+1, :]==x_next) + x_next = opt_states[i, :] + f(opt_states[i, :], opt_controls[i, :]).T * self.T + opti.subject_to(opt_states[i + 1, :] == x_next) # define the cost function - Q = np.diag([10.0,10.0,0.0]) - R = np.diag([0.05,0.2]) - obj = 0 + Q = np.diag([10.0, 10.0, 0.0]) + R = np.diag([0.05, 0.2]) + obj = 0 for i in range(N): - obj = obj +ca.mtimes([opt_controls[i, :], R, opt_controls[i, :].T]) + obj = obj + ca.mtimes([opt_controls[i, :], R, opt_controls[i, :].T]) if i % ref_gap == 0: nn = i // ref_gap - obj = obj + ca.mtimes([(opt_states[i, :]-opt_xs[nn*3:nn*3+3].T), Q, (opt_states[i, :]-opt_xs[nn*3:nn*3+3].T).T]) + obj = obj + ca.mtimes( + [ + (opt_states[i, :] - opt_xs[nn * 3 : nn * 3 + 3].T), + Q, + (opt_states[i, :] - opt_xs[nn * 3 : nn * 3 + 3].T).T, + ] + ) opti.minimize(obj) - # boundrary and control conditions + # boundary and control conditions opti.subject_to(opti.bounded(0, v, v_max)) opti.subject_to(opti.bounded(-w_max, w, w_max)) - - opts_setting = {'ipopt.max_iter':100, 'ipopt.print_level':0, 'print_time':0, 'ipopt.acceptable_tol':1e-8, 'ipopt.acceptable_obj_change_tol':1e-6} + + opts_setting = { + 'ipopt.max_iter': 100, + 'ipopt.print_level': 0, + 'print_time': 0, + 'ipopt.acceptable_tol': 1e-8, + 'ipopt.acceptable_obj_change_tol': 1e-6, + } opti.solver('ipopt', opts_setting) # opts_setting = { 'qpsol':'osqp','hessian_approximation':'limited-memory','max_iter':200,'convexify_strategy':'regularize','beta':0.5,'c1':1e-4,'tol_du':1e-3,'tol_pr':1e-6} # 
opti.solver('sqpmethod',opts_setting) - + self.opti = opti self.opt_xs = opt_xs self.opt_x0 = opt_x0 @@ -65,7 +78,8 @@ def __init__(self, global_planed_traj, N = 20, desired_v = 0.3, v_max = 0.4, w_m self.opt_states = opt_states self.last_opt_x_states = None self.last_opt_u_controls = None - def make_ref_denser(self, ref_traj, ratio = 50): + + def make_ref_denser(self, ref_traj, ratio=50): x_orig = np.arange(len(ref_traj)) new_x = np.linspace(0, len(ref_traj) - 1, num=len(ref_traj) * ratio) @@ -75,21 +89,21 @@ def make_ref_denser(self, ref_traj, ratio = 50): uniform_x = interp_func_x(new_x) uniform_y = interp_func_y(new_x) ref_traj = np.stack((uniform_x, uniform_y), axis=1) - + return ref_traj - + def update_ref_traj(self, global_planed_traj): self.ref_traj = self.make_ref_denser(global_planed_traj) self.ref_traj_len = self.N // self.ref_gap + 1 - + def solve(self, x0): ref_traj = self.find_reference_traj(x0, self.ref_traj) # fake a yaw angle ref_traj = np.concatenate((ref_traj, np.zeros((ref_traj.shape[0], 1))), axis=1).reshape(-1, 1) - - self.opti.set_value(self.opt_xs, ref_traj.reshape(-1, 1)) + + self.opti.set_value(self.opt_xs, ref_traj.reshape(-1, 1)) u0 = np.zeros((self.N, 2)) if self.last_opt_u_controls is None else self.last_opt_u_controls - x00 = np.zeros((self.N+1, 3)) if self.last_opt_x_states is None else self.last_opt_x_states + x00 = np.zeros((self.N + 1, 3)) if self.last_opt_x_states is None else self.last_opt_x_states self.opti.set_value(self.opt_x0, x0) self.opti.set_initial(self.opt_controls, u0) @@ -101,15 +115,16 @@ def solve(self, x0): self.last_opt_x_states = sol.value(self.opt_states) return self.last_opt_u_controls, self.last_opt_x_states + def reset(self): self.last_opt_x_states = None self.last_opt_u_controls = None - + def find_reference_traj(self, x0, global_planed_traj): ref_traj_pts = [] # find the nearest point in global_planed_traj nearest_idx = np.argmin(np.linalg.norm(global_planed_traj - x0[:2].reshape((1, 2)), axis=1)) - desire_arc_length = self.desired_v * self.ref_gap * self.T + desire_arc_length = self.desired_v * self.ref_gap * self.T cum_dist = np.cumsum(np.linalg.norm(np.diff(global_planed_traj, axis=0), axis=1)) # select the reference points from the nearest point to the end of global_planed_traj @@ -118,11 +133,11 @@ def find_reference_traj(self, x0, global_planed_traj): ref_traj_pts.append(global_planed_traj[i, :]) if len(ref_traj_pts) == self.ref_traj_len: break - # if the target is reached before the reference trajectory is complete, add the last point of global_planed_traj + # if the target is reached before the reference trajectory is complete, add the last point of global_planed_traj while len(ref_traj_pts) < self.ref_traj_len: ref_traj_pts.append(global_planed_traj[-1, :]) return np.array(ref_traj_pts) - + class PID_controller: def __init__(self, Kp_trans=1.0, Kd_trans=0.1, Kp_yaw=1.0, Kd_yaw=1.0, max_v=1.0, max_w=1.2): @@ -132,12 +147,12 @@ def __init__(self, Kp_trans=1.0, Kd_trans=0.1, Kp_yaw=1.0, Kd_yaw=1.0, max_v=1.0 self.Kd_yaw = Kd_yaw self.max_v = max_v self.max_w = max_w - + def solve(self, odom, target, vel=np.zeros(2)): translation_error, yaw_error = self.calculate_errors(odom, target) v, w = self.pd_step(translation_error, yaw_error, vel[0], vel[1]) return v, w, translation_error, yaw_error - + def pd_step(self, translation_error, yaw_error, linear_vel, angular_vel): translation_error = max(-1.0, min(1.0, translation_error)) yaw_error = max(-1.0, min(1.0, yaw_error)) @@ -147,20 +162,20 @@ def pd_step(self, translation_error, 
yaw_error, linear_vel, angular_vel): linear_velocity = max(-self.max_v, min(self.max_v, linear_velocity)) angular_velocity = max(-self.max_w, min(self.max_w, angular_velocity)) - + return linear_velocity, angular_velocity - + def calculate_errors(self, odom, target): - - dx = target[0, 3] - odom[0, 3] - dy = target[1, 3] - odom[1, 3] + + dx = target[0, 3] - odom[0, 3] + dy = target[1, 3] - odom[1, 3] odom_yaw = math.atan2(odom[1, 0], odom[0, 0]) target_yaw = math.atan2(target[1, 0], target[0, 0]) - - translation_error = dx * np.cos(odom_yaw) + dy * np.sin(odom_yaw) + + translation_error = dx * np.cos(odom_yaw) + dy * np.sin(odom_yaw) yaw_error = target_yaw - odom_yaw yaw_error = (yaw_error + math.pi) % (2 * math.pi) - math.pi - - return translation_error, yaw_error \ No newline at end of file + + return translation_error, yaw_error diff --git a/scripts/realworld/http_internvla_client.py b/scripts/realworld/http_internvla_client.py index 2704f198..f20adb6c 100644 --- a/scripts/realworld/http_internvla_client.py +++ b/scripts/realworld/http_internvla_client.py @@ -1,38 +1,36 @@ -import rclpy -import sys -import threading +import copy import io import json -import copy -import requests -import time -import numpy as np import math -from enum import Enum +import threading +import time from collections import deque +from enum import Enum +import numpy as np +import rclpy +import requests +from geometry_msgs.msg import Twist +from nav_msgs.msg import Odometry from PIL import Image as PIL_Image from sensor_msgs.msg import Image -from nav_msgs.msg import Odometry, Path -from geometry_msgs.msg import PoseStamped -from geometry_msgs.msg import Twist frame_data = {} frame_idx = 0 -from std_msgs.msg import Bool +# user-specific +from controllers import Mpc_controller, PID_controller from cv_bridge import CvBridge +from message_filters import ApproximateTimeSynchronizer, Subscriber from rclpy.node import Node -from message_filters import Subscriber, ApproximateTimeSynchronizer -from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy +from rclpy.qos import HistoryPolicy, QoSProfile, ReliabilityPolicy +from thread_utils import ReadWriteLock -# user-specific -from controllers import * -from thread_utils import * class ControlMode(Enum): PID_Mode = 1 MPC_Mode = 2 + # global variable policy_init = True mpc = None @@ -53,22 +51,23 @@ class ControlMode(Enum): def dual_sys_eval(image_bytes, depth_bytes, front_image_bytes, url='http://127.0.0.1:5801/eval_dual'): global policy_init, http_idx, first_running_time - data = {"reset":policy_init, "idx":http_idx} + data = {"reset": policy_init, "idx": http_idx} json_data = json.dumps(data) - + policy_init = False - files = {'image': ('rgb_image', image_bytes, 'image/jpeg'), - 'depth': ('depth_image', depth_bytes, 'image/png'), - } + files = { + 'image': ('rgb_image', image_bytes, 'image/jpeg'), + 'depth': ('depth_image', depth_bytes, 'image/png'), + } start = time.time() response = requests.post(url, files=files, data={'json': json_data}, timeout=100) print(f"response {response.text}") http_idx += 1 if http_idx == 0: first_running_time = time.time() - print(f"idx:{http_idx} after http {time.time() - start}") - - return json.loads(response.text) + print(f"idx: {http_idx} after http {time.time() - start}") + + return json.loads(response.text) def control_thread(): @@ -81,7 +80,6 @@ def control_thread(): odom_rw_lock.release_read() if mpc is not None and manager is not None and odom is not None: local_mpc = mpc - t0 = time.time() opt_u_controls, opt_x_states = 
local_mpc.solve(np.array(odom)) v, w = opt_u_controls[0, 0], opt_u_controls[0, 1] @@ -90,9 +88,9 @@ def control_thread(): elif current_control_mode == ControlMode.PID_Mode: odom_rw_lock.acquire_read() odom = manager.odom.copy() if manager.odom else None - odom_rw_lock.release_read() + odom_rw_lock.release_read() homo_odom = manager.homo_odom.copy() if manager.homo_odom is not None else None - vel = manager.vel.copy() if manager.vel is not None else None + vel = manager.vel.copy() if manager.vel is not None else None homo_goal = manager.homo_goal.copy() if manager.homo_goal is not None else None if homo_odom is not None and vel is not None and homo_goal is not None: @@ -101,17 +99,18 @@ def control_thread(): v = 0.0 desired_v, desired_w = v, w manager.move(v, 0.0, w) - - time.sleep(0.1) + + time.sleep(0.1) + def planning_thread(): global trajs_in_world - + while True: start_time = time.time() DISIRED_TIME = 0.3 time.sleep(0.05) - + if not manager.new_image_arrived: time.sleep(0.01) continue @@ -125,29 +124,28 @@ def planning_thread(): rgb_depth_rw_lock.release_read() odom_rw_lock.acquire_read() min_diff = 1e10 - time_diff = 1e10 + # time_diff = 1e10 odom_infer = None for odom in manager.odom_queue: diff = abs(odom[0] - rgb_time) if diff < min_diff: min_diff = diff odom_infer = copy.deepcopy(odom[1]) - time_diff = odom[0] - rgb_time - odom_time = manager.odom_timestamp + # time_diff = odom[0] - rgb_time + # odom_time = manager.odom_timestamp odom_rw_lock.release_read() - + if odom_infer is not None and rgb_bytes is not None and depth_bytes is not None: global frame_data frame_data[http_idx] = { 'infer_rgb': copy.deepcopy(infer_rgb), 'infer_depth': copy.deepcopy(infer_depth), - 'infer_odom': copy.deepcopy(odom_infer) + 'infer_odom': copy.deepcopy(odom_infer), } if len(frame_data) > 100: del frame_data[min(frame_data.keys())] response = dual_sys_eval(rgb_bytes, depth_bytes, None) - global current_control_mode traj_len = 0.0 if 'trajectory' in response: @@ -160,19 +158,21 @@ def planning_thread(): if i < 3: continue x_, y_, yaw_ = odom[0], odom[1], odom[2] - - w_T_b = np.array([[np.cos(yaw_), -np.sin(yaw_), 0, x_], - [np.sin(yaw_), np.cos(yaw_), 0, y_], - [0.0, 0.0, 1.0, 0], - [0.0, 0.0, 0.0, 1.0]]) + + w_T_b = np.array( + [ + [np.cos(yaw_), -np.sin(yaw_), 0, x_], + [np.sin(yaw_), np.cos(yaw_), 0, y_], + [0.0, 0.0, 1.0, 0], + [0.0, 0.0, 0.0, 1.0], + ] + ) w_P = (w_T_b @ (np.array([traj[0], traj[1], 0.0, 1.0])).T)[:2] trajs_in_world.append(w_P) trajs_in_world = np.array(trajs_in_world) print(f"{time.time()} update traj") - - + manager.last_trajs_in_world = trajs_in_world - t0 = time.time() mpc_rw_lock.acquire_write() global mpc if mpc is None: @@ -183,16 +183,18 @@ def planning_thread(): mpc_rw_lock.release_write() current_control_mode = ControlMode.MPC_Mode elif 'discrete_action' in response: - actions = response['discrete_action'] - if actions != [5] and actions != [9]: - manager.incremental_change_goal(actions) - current_control_mode = ControlMode.PID_Mode + actions = response['discrete_action'] + if actions != [5] and actions != [9]: + manager.incremental_change_goal(actions) + current_control_mode = ControlMode.PID_Mode else: - print(f"skip planning. odom_infer:{odom_infer is not None} rgb_bytes:{rgb_bytes is not None} depth_bytes:{depth_bytes is not None}") + print( + f"skip planning. 
odom_infer: {odom_infer is not None} rgb_bytes: {rgb_bytes is not None} depth_bytes: {depth_bytes is not None}" + ) time.sleep(0.1) - + time.sleep(max(0, DISIRED_TIME - (time.time() - start_time))) - + class Go2Manager(Node): def __init__(self): @@ -200,12 +202,8 @@ def __init__(self): rgb_down_sub = Subscriber(self, Image, "/camera/camera/color/image_raw") depth_down_sub = Subscriber(self, Image, "/camera/camera/aligned_depth_to_color/image_raw") - - qos_profile = QoSProfile( - reliability=ReliabilityPolicy.BEST_EFFORT, - history=HistoryPolicy.KEEP_LAST, - depth=10 - ) + + qos_profile = QoSProfile(reliability=ReliabilityPolicy.BEST_EFFORT, history=HistoryPolicy.KEEP_LAST, depth=10) self.syncronizer = ApproximateTimeSynchronizer([rgb_down_sub, depth_down_sub], 1, 0.1) self.syncronizer.registerCallback(self.rgb_depth_down_callback) @@ -225,7 +223,7 @@ def __init__(self): self.new_image_arrived = False self.new_vis_image_arrived = False self.rgb_time = 0.0 - + self.odom = None self.linear_vel = 0.0 self.angular_vel = 0.0 @@ -240,7 +238,7 @@ def __init__(self): self.homo_odom = None self.homo_goal = None self.vel = None - + def rgb_forward_callback(self, rgb_msg): raw_image = self.cv_bridge.imgmsg_to_cv2(rgb_msg, 'rgb8')[:, :, :] self.rgb_forward_image = raw_image @@ -251,15 +249,15 @@ def rgb_forward_callback(self, rgb_msg): self.rgb_forward_bytes = image_bytes self.new_vis_image_arrived = True self.new_image_arrived = True - def rgb_depth_down_callback(self, rgb_msg, depth_msg): - t0 = time.time() + + def rgb_depth_down_callback(self, rgb_msg, depth_msg): raw_image = self.cv_bridge.imgmsg_to_cv2(rgb_msg, 'rgb8')[:, :, :] self.rgb_image = raw_image image = PIL_Image.fromarray(self.rgb_image) image_bytes = io.BytesIO() image.save(image_bytes, format='JPEG') image_bytes.seek(0) - + raw_depth = self.cv_bridge.imgmsg_to_cv2(depth_msg, '16UC1') raw_depth[np.isnan(raw_depth)] = 0 raw_depth[np.isinf(raw_depth)] = 0 @@ -270,22 +268,23 @@ def rgb_depth_down_callback(self, rgb_msg, depth_msg): depth = PIL_Image.fromarray(depth) depth_bytes = io.BytesIO() depth.save(depth_bytes, format='PNG') - depth_bytes.seek(0) - + depth_bytes.seek(0) + rgb_depth_rw_lock.acquire_write() self.rgb_bytes = image_bytes - + self.rgb_time = rgb_msg.header.stamp.sec + rgb_msg.header.stamp.nanosec / 1.0e9 self.last_rgb_time = self.rgb_time - + self.depth_bytes = depth_bytes self.depth_time = depth_msg.header.stamp.sec + depth_msg.header.stamp.nanosec / 1.0e9 self.last_depth_time = self.depth_time - + rgb_depth_rw_lock.release_write() self.new_vis_image_arrived = True - self.new_image_arrived = True + self.new_image_arrived = True + def odom_callback(self, msg): self.odom_cnt += 1 odom_rw_lock.acquire_write() @@ -299,13 +298,12 @@ def odom_callback(self, msg): self.angular_vel = msg.twist.twist.angular.z odom_rw_lock.release_write() - R0 = np.array([[np.cos(yaw), -np.sin(yaw)], - [np.sin(yaw), np.cos(yaw)]]) + R0 = np.array([[np.cos(yaw), -np.sin(yaw)], [np.sin(yaw), np.cos(yaw)]]) self.homo_odom = np.eye(4) self.homo_odom[:2, :2] = R0 self.homo_odom[:2, 3] = [msg.pose.pose.position.x, msg.pose.pose.position.y] self.vel = [msg.twist.twist.linear.x, msg.twist.twist.angular.z] - + if self.odom_cnt == 1: self.homo_goal = self.homo_odom.copy() @@ -322,21 +320,18 @@ def incremental_change_goal(self, actions): homo_goal[1, 3] += 0.25 * np.sin(yaw) elif each_action == 2: angle = math.radians(15) - rotation_matrix = np.array([ - [math.cos(angle), -math.sin(angle), 0], - [math.sin(angle), math.cos(angle), 0], - [0, 0, 1] - ]) + 
rotation_matrix = np.array( + [[math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], [0, 0, 1]] + ) homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) elif each_action == 3: angle = -math.radians(15.0) - rotation_matrix = np.array([ - [math.cos(angle), -math.sin(angle), 0], - [math.sin(angle), math.cos(angle), 0], - [0, 0, 1] - ]) - homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) + rotation_matrix = np.array( + [[math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], [0, 0, 1]] + ) + homo_goal[:3, :3] = np.dot(rotation_matrix, homo_goal[:3, :3]) self.homo_goal = homo_goal + def move(self, vx, vy, vyaw): request = Twist() request.linear.x = vx @@ -349,9 +344,9 @@ def move(self, vx, vy, vyaw): if __name__ == '__main__': control_thread_instance = threading.Thread(target=control_thread) planning_thread_instance = threading.Thread(target=planning_thread) - + rclpy.init() - + try: manager = Go2Manager() diff --git a/scripts/realworld/http_internvla_server.py b/scripts/realworld/http_internvla_server.py index 740b6691..fb13206a 100644 --- a/scripts/realworld/http_internvla_server.py +++ b/scripts/realworld/http_internvla_server.py @@ -1,21 +1,27 @@ -import numpy as np import argparse -import os import json +import os import sys import time + +import numpy as np + sys.path.append('/home/pjlab/yq_ws/InternNav') -from flask import Flask, request, jsonify -from PIL import Image from datetime import datetime + +from flask import Flask, jsonify, request +from PIL import Image + from internnav.agent.internvla_n1_agent_realworld import InternVLAN1AsyncAgent -from internnav.model.utils.vln_utils import S2Output + app = Flask(__name__) idx = 0 start_time = time.time() output_dir = '' -@app.route("/eval_dual",methods=['POST']) + + +@app.route("/eval_dual", methods=['POST']) def eval_dual(): global idx, output_dir, start_time start_time = time.time() @@ -24,17 +30,17 @@ def eval_dual(): depth_file = request.files['depth'] json_data = request.form['json'] data = json.loads(json_data) - + image = Image.open(image_file.stream) image = image.convert('RGB') image = np.asarray(image) - + depth = Image.open(depth_file.stream) depth = depth.convert('I') depth = np.asarray(depth) depth = depth.astype(np.float32) / 10000.0 print(f"read http data cost {time.time() - start_time}") - + camera_pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) instruction = "Turn around and walk out of this office. Turn towards your slight right at the chair. Move forward to the walkway and go near the red bin. You can see an open door on your right side, go inside the open door. 
Stop at the computer monitor" policy_init = data['reset'] @@ -45,24 +51,28 @@ def eval_dual(): os.makedirs(output_dir, exist_ok=True) print("init reset model!!!") agent.reset() - + idx += 1 - + look_down = False t0 = time.time() dual_sys_output = {} - dual_sys_output = agent.step(image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down) + dual_sys_output = agent.step( + image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down + ) if dual_sys_output.output_action is not None and dual_sys_output.output_action == [5]: look_down = True - dual_sys_output = agent.step(image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down) - + dual_sys_output = agent.step( + image, depth, camera_pose, instruction, intrinsic=args.camera_intrinsic, look_down=look_down + ) + json_output = {} if dual_sys_output.output_action is not None: json_output['discrete_action'] = dual_sys_output.output_action if dual_sys_output.output_pixel is not None: json_output['pixel_goal'] = dual_sys_output.output_pixel - + t1 = time.time() generate_time = t1 - t0 print(f"dual sys step {generate_time}") @@ -70,22 +80,20 @@ def eval_dual(): return jsonify(json_output) - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--model_path", type=str, default="/path/to/InternVLA-N1") - parser.add_argument("--resize_w", type=int, default=384) + parser.add_argument("--resize_w", type=int, default=384) parser.add_argument("--resize_h", type=int, default=384) parser.add_argument("--num_history", type=int, default=8) args = parser.parse_args() - - args.camera_intrinsic = np.array([[386.5 , 0. , 328.9, 0. ], - [ 0. , 386.5 , 244, 0. ], - [ 0. , 0. , 1. , 0. ], - [ 0. , 0. , 0. , 1. ]]) + + args.camera_intrinsic = np.array( + [[386.5, 0.0, 328.9, 0.0], [0.0, 386.5, 244, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + ) agent = InternVLAN1AsyncAgent(args) agent.reset() - app.run(host='0.0.0.0', port=5801) \ No newline at end of file + app.run(host='0.0.0.0', port=5801) diff --git a/scripts/realworld/thread_utils.py b/scripts/realworld/thread_utils.py index 67812aa3..2bd02c5a 100644 --- a/scripts/realworld/thread_utils.py +++ b/scripts/realworld/thread_utils.py @@ -1,5 +1,6 @@ import threading + class ReadWriteLock: def __init__(self): self._read_ready = threading.Condition(threading.Lock()) @@ -24,4 +25,4 @@ def acquire_write(self): def release_write(self): with self._read_ready: self._readers = 0 - self._read_ready.notify_all() \ No newline at end of file + self._read_ready.notify_all() diff --git a/setup.cfg b/setup.cfg index 73140fbd..74e1c305 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ skip_glob = internutopia/*, internutopia_extension/* # than "BA" [codespell] quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,jetbot,wth,coverted,descrete +ignore-words-list = patten,nd,ty,mot,hist,formating,jetbot,wth,coverted,descrete,thw,ro skip = *.js *.txt From 9554120f71c9b78080116fb55087698a818a828e Mon Sep 17 00:00:00 2001 From: yangyuqiang Date: Fri, 26 Sep 2025 15:52:26 +0800 Subject: [PATCH 5/6] [fix] optimize the codebase. fix some typo. 
--- scripts/realworld/controllers.py | 20 ++++++++++++++++++++ scripts/realworld/http_internvla_client.py | 7 ++++--- scripts/realworld/http_internvla_server.py | 7 +------ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/scripts/realworld/controllers.py b/scripts/realworld/controllers.py index 1cbe2947..c9d495a8 100644 --- a/scripts/realworld/controllers.py +++ b/scripts/realworld/controllers.py @@ -13,6 +13,16 @@ class Mpc_controller: def __init__(self, global_planed_traj, N=20, desired_v=0.3, v_max=0.4, w_max=0.4, ref_gap=4): + """Initialize the MPC controller. + + Args: + global_planed_traj (np.ndarray): The global planned trajectory, shape (n, 2). + N (int): Prediction horizon. + desired_v (float): Desired linear velocity. + v_max (float): Maximum linear velocity. + w_max (float): Maximum angular velocity. + ref_gap (int): Gap between reference points in the prediction horizon. + """ self.N, self.desired_v, self.ref_gap, self.T = N, desired_v, ref_gap, 0.1 self.ref_traj = self.make_ref_denser(global_planed_traj) self.ref_traj_len = N // ref_gap + 1 @@ -141,6 +151,16 @@ def find_reference_traj(self, x0, global_planed_traj): class PID_controller: def __init__(self, Kp_trans=1.0, Kd_trans=0.1, Kp_yaw=1.0, Kd_yaw=1.0, max_v=1.0, max_w=1.2): + """Initialize the PID controller. + + Args: + Kp_trans (float): Proportional gain for translational error. + Kd_trans (float): Derivative gain for translational error. + Kp_yaw (float): Proportional gain for yaw error. + Kd_yaw (float): Derivative gain for yaw error. + max_v (float): Maximum linear velocity. + max_w (float): Maximum angular velocity. + """ self.Kp_trans = Kp_trans self.Kd_trans = Kd_trans self.Kp_yaw = Kp_yaw diff --git a/scripts/realworld/http_internvla_client.py b/scripts/realworld/http_internvla_client.py index f20adb6c..f7de8762 100644 --- a/scripts/realworld/http_internvla_client.py +++ b/scripts/realworld/http_internvla_client.py @@ -108,7 +108,7 @@ def planning_thread(): while True: start_time = time.time() - DISIRED_TIME = 0.3 + DESIRED_TIME = 0.3 time.sleep(0.05) if not manager.new_image_arrived: @@ -193,7 +193,7 @@ def planning_thread(): ) time.sleep(0.1) - time.sleep(max(0, DISIRED_TIME - (time.time() - start_time))) + time.sleep(max(0, DESIRED_TIME - (time.time() - start_time))) class Go2Manager(Node): @@ -344,7 +344,8 @@ def move(self, vx, vy, vyaw): if __name__ == '__main__': control_thread_instance = threading.Thread(target=control_thread) planning_thread_instance = threading.Thread(target=planning_thread) - + control_thread_instance.daemon = True + planning_thread_instance.daemon = True rclpy.init() try: diff --git a/scripts/realworld/http_internvla_server.py b/scripts/realworld/http_internvla_server.py index fb13206a..c93bbe83 100644 --- a/scripts/realworld/http_internvla_server.py +++ b/scripts/realworld/http_internvla_server.py @@ -1,15 +1,10 @@ import argparse import json import os -import sys import time - -import numpy as np - -sys.path.append('/home/pjlab/yq_ws/InternNav') - from datetime import datetime +import numpy as np from flask import Flask, jsonify, request from PIL import Image From 29fb6517deb51fb90a558038678efa35f4187b48 Mon Sep 17 00:00:00 2001 From: yangyuqiang Date: Sun, 28 Sep 2025 12:03:36 +0800 Subject: [PATCH 6/6] [feat] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d750e9ca..288c3afb 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ The toolbox supports the most comprehensive 6 datasets \& 
benchmarks and 10+ pop The toolbox supports the most advanced high-quality navigation dataset, InternData-N1, which includes 3k+ scenes and 830k VLN data covering diverse embodiments and scenes, and the first dual-system navigation foundation model with leading performance on all the benchmarks and zero-shot generalization capability in the real world, InternVLA-N1. ## 🔥 News - +- [2025/09] Real-world deployment code of InternVLA-N1 is released. - [2025/07] We are hosting 🏆IROS 2025 Grand Challenge, stay tuned at [official website](https://internrobotics.shlab.org.cn/challenge/2025/). - [2025/07] InternNav v0.1.1 released.