22 changes: 12 additions & 10 deletions internnav/evaluator/vln_multi_evaluator.py
@@ -1,7 +1,10 @@
import sys
from enum import Enum
from pathlib import Path
from time import time

import numpy as np

from internnav.configs.evaluator import EvalCfg
from internnav.evaluator.base import Evaluator
from internnav.evaluator.utils.common import set_seed_model
@@ -28,22 +31,23 @@ def transform_action_batch(actions, flash=False):
if 'ideal_flag' in action.keys():
ideal_flag = action['ideal_flag']
if flash:
assert ideal_flag == True
assert ideal_flag is True
else:
ideal_flag = False
if not ideal_flag:
transformed_actions.append({'h1': {'vln_dp_move_by_speed': action['action'][0]}})
continue
a = action['action']
if a == 0 or a == [0] or a==[[0]]:
if a == 0 or a == [0] or a == [[0]]:
transformed_actions.append({'h1': {'stop': []}})
elif a == -1 or a == [-1] or a == [[-1]]:
transformed_actions.append({'h1': {'stand_still': []}})
else:
move = f"move_by_{'discrete' if not flash else 'flash'}"
transformed_actions.append({'h1': {move: a}}) # discrete e.g. [3]
transformed_actions.append({'h1': {move: a}}) # discrete e.g. [3]
return transformed_actions


@Evaluator.register('vln_multi')
class VlnMultiEvaluator(Evaluator):
def __init__(self, config: EvalCfg):
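Stepping back from the hunk above: the full mapping implemented by `transform_action_batch` can be condensed into a small standalone sketch (the `'h1'` key and the command names are taken from the diff; this is a reference restatement for reviewers, not the repository implementation):

```python
def map_action(a, flash=False, ideal=True):
    """Condensed restatement of the mapping in transform_action_batch."""
    if not ideal:
        # Non-ideal actions are forwarded as continuous speed commands.
        return {'h1': {'vln_dp_move_by_speed': a[0]}}
    if a in (0, [0], [[0]]):
        return {'h1': {'stop': []}}
    if a in (-1, [-1], [[-1]]):
        return {'h1': {'stand_still': []}}
    move = 'move_by_flash' if flash else 'move_by_discrete'
    return {'h1': {move: a}}  # discrete, e.g. [3]


assert map_action(0) == {'h1': {'stop': []}}
assert map_action([3]) == {'h1': {'move_by_discrete': [3]}}
```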
@@ -61,6 +65,9 @@ def __init__(self, config: EvalCfg):
)
# generate episode
episodes = generate_episode(self.dataloader, config)
if len(episodes) == 0:
log.info("No more episodes to evaluate")
sys.exit(0)
config.task.task_settings.update({'episodes': episodes})
self.env_num = config.task.task_settings['env_num']
self.proc_num = (
@@ -88,7 +95,6 @@ def __init__(self, config: EvalCfg):
self.data_collector = DataCollector(self.dataloader.lmdb_path)
self.robot_flash = config.task.robot_flash


@property
def ignore_obs_attr(self):
return [
@@ -223,15 +229,11 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
log.info(f'env{reset_env_ids}: states switch to WARM UP.')
# modify original reset_info
reset_infos = np.array(reset_infos)
reset_infos[reset_env_ids] = (
new_reset_infos if len(new_reset_infos) > 0 else None
)
reset_infos[reset_env_ids] = new_reset_infos if len(new_reset_infos) > 0 else None
self.runner_status[
np.vectorize(lambda x: x)(reset_infos) == None # noqa: E711
] = runner_status_code.TERMINATED
log.info(
f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.'
)
log.info(f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.')
reset_infos = reset_infos.tolist()

if np.logical_and.reduce(self.runner_status == runner_status_code.TERMINATED):
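One note on the `== None` masking idiom that this hunk reflows: as far as I can tell, `np.vectorize(lambda x: x)` is there to force `reset_infos` into a flat object array so that the comparison with `None` happens element-wise and yields a boolean mask of the terminated environments. A standalone illustration, not repository code:

```python
import numpy as np

reset_infos = [{'episode': 1}, None, {'episode': 2}, None]

# The identity vectorize produces a flat object array, so == None is applied
# element-wise and returns a mask of the entries that are None.
mask = np.vectorize(lambda x: x)(reset_infos) == None  # noqa: E711
print(mask)  # [False  True False  True]

# An equivalent, arguably more explicit spelling:
mask_alt = np.array([info is None for info in reset_infos])
assert (mask == mask_alt).all()
```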
30 changes: 15 additions & 15 deletions internnav/evaluator/vln_pe_evaluator.py
@@ -1,10 +1,13 @@
import sys
from enum import Enum
from pathlib import Path
from time import time

import numpy as np

from internnav.configs.evaluator import EvalCfg
from internnav.evaluator.base import Evaluator
from internnav.evaluator.utils.common import set_seed_model, obs_to_image
from internnav.evaluator.utils.common import set_seed_model
from internnav.evaluator.utils.config import get_lmdb_path
from internnav.evaluator.utils.data_collector import DataCollector
from internnav.evaluator.utils.dataset import ResultLogger, split_data
@@ -56,6 +59,9 @@ def __init__(self, config: EvalCfg):

# generate episode
episodes = generate_episode(self.dataloader, config)
if len(episodes) == 0:
log.info("No more episodes to evaluate. Episodes are saved in data/sample_episodes/")
sys.exit(0)
config.task.task_settings.update({'episodes': episodes})
self.env_num = config.task.task_settings['env_num']
self.proc_num = (
@@ -211,7 +217,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):

# need this status to reset
reset_env_ids = np.where(self.runner_status == runner_status_code.NOT_RESET)[0].tolist()

if len(reset_env_ids) > 0:
log.debug(f'env{reset_env_ids}: start new episode!')
obs, new_reset_infos = self.env.reset(reset_env_ids)
@@ -225,9 +231,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
self.runner_status[
np.vectorize(lambda x: x)(reset_infos) == None # noqa: E711
] = runner_status_code.TERMINATED
log.debug(
f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.'
)
log.debug(f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.')
reset_infos = reset_infos.tolist()

if np.logical_and.reduce(self.runner_status == runner_status_code.TERMINATED):
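A minor readability note on the all-terminated check in the line above: for a 1-D boolean mask, `np.logical_and.reduce(mask)` behaves the same as `np.all(mask)`, so the latter could be used interchangeably if a more familiar spelling is preferred. A tiny standalone check (the numeric status value is a placeholder, not the real enum):

```python
import numpy as np

TERMINATED = 2  # placeholder; the real value comes from runner_status_code
runner_status = np.array([TERMINATED, TERMINATED, TERMINATED])

assert np.logical_and.reduce(runner_status == TERMINATED) == np.all(runner_status == TERMINATED)
```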
@@ -241,8 +245,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
)
if self.vis_output:
self.visualize_util.trace_start(
trajectory_id=self.now_path_key(reset_info),
reference_path=reset_info.data['reference_path']
trajectory_id=self.now_path_key(reset_info), reference_path=reset_info.data['reference_path']
)
return False, reset_infos

@@ -258,8 +261,7 @@ def eval(self):
)
if self.vis_output:
self.visualize_util.trace_start(
trajectory_id=self.now_path_key(info),
reference_path=info.data['reference_path']
trajectory_id=self.now_path_key(info), reference_path=info.data['reference_path']
)
log.info('start new episode!')

@@ -281,18 +283,16 @@ def eval(self):
env_term, reset_info = self.terminate_ops(obs, reset_info, terminated)
if env_term:
break

# save step obs
if self.vis_output:
for ob, info, act in zip(obs, reset_info, action):
if info is None or not 'rgb' in ob or ob['fail_reason']:
if info is None or 'rgb' not in ob or ob['fail_reason']:
continue
self.visualize_util.save_observation(
trajectory_id=self.now_path_key(info),
obs=ob,
action=act[self.robot_name]
trajectory_id=self.now_path_key(info), obs=ob, action=act[self.robot_name]
)

self.env.close()
progress_log_multi_util.report()

12 changes: 6 additions & 6 deletions tests/function_test/e2e_test.py
@@ -54,13 +54,13 @@ def test_server():
common_body(start_command)


@pytest.mark.gpu
def test_challenge():
start_command = 'python ./tests/function_test/test_challenge.py'
@pytest.mark.ray
def test_evaluator():
start_command = 'python ./tests/function_test/test_evaluator.py'
common_body(start_command)


@pytest.mark.ray
def test_challenge_ray():
start_command = 'python ./tests/function_test/test_challenge_ray.py'
@pytest.mark.gpu
def test_challenge():
start_command = 'python ./tests/function_test/test_challenge.py'
common_body(start_command)
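Since this file keeps the custom `gpu` and `ray` markers, with the Ray-backed test now called `test_evaluator`, it may be worth double-checking that both markers are declared somewhere; otherwise pytest emits `PytestUnknownMarkWarning`. A hypothetical `conftest.py` registration, only needed if the repository does not already declare them in `pytest.ini` or `pyproject.toml`:

```python
# Hypothetical conftest.py snippet; skip if the markers are already registered.
def pytest_configure(config):
    config.addinivalue_line("markers", "gpu: tests that need a GPU-backed simulator")
    config.addinivalue_line("markers", "ray: tests that exercise the Ray-based evaluator")
```

Either group can then be selected on its own, e.g. `pytest -m ray tests/function_test/e2e_test.py`.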
39 changes: 16 additions & 23 deletions tests/function_test/test_challenge.py
@@ -5,11 +5,11 @@
'''

import importlib.util
import subprocess
import sys
import time

import numpy as np
from test_server import start_server, stop_server

from internnav.configs.evaluator.default_config import get_config
from internnav.evaluator import Evaluator
@@ -66,39 +66,32 @@ def load_eval_cfg(config_path, attr_name='eval_cfg'):
evaluator.env.close()


def start_server():
server_cmd = [
sys.executable,
"internnav/agent/utils/server.py",
"--config",
"scripts/eval/configs/challenge_cfg.py",
]
def start_evaluator():
from multiprocessing import get_context

proc = subprocess.Popen(
server_cmd,
stdout=None,
stderr=None,
)
return proc
ctx = get_context("spawn") # Use 'spawn' to avoid issues on some platforms
p = ctx.Process(target=main)
p.start()
p.join()
assert p.exitcode == 0
print("Evaluator process completed successfully.")


if __name__ == '__main__':
try:
proc = start_server()
time.sleep(3)
main()
start_evaluator()

except Exception as e:
print(f'exception is {e}')
import traceback

traceback.print_exc()
sys.exit(1)

except SystemExit as e:
print(f"Caught SystemExit from env.close(): code={e.code}", flush=True)

finally:
if proc and proc.poll() is None:
print("Shutting down server...")
proc.terminate()
try:
proc.wait(timeout=10)
except subprocess.TimeoutExpired:
print("Force killing server...")
proc.kill()
stop_server(proc)
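The replacement of the old in-file server bootstrap with `start_evaluator()` boils down to running the evaluation entry point in a freshly spawned child process and failing the test if the child exits non-zero (the in-diff comment notes that `spawn` avoids issues on some platforms). A standalone sketch of the same pattern, with illustrative names only:

```python
from multiprocessing import get_context


def entry_point():
    # Stand-in for the real main(); whatever should run isolated goes here.
    print("evaluating...")


def run_in_spawned_process(target=entry_point):
    ctx = get_context("spawn")  # 'spawn' starts a clean interpreter for the child
    p = ctx.Process(target=target)
    p.start()
    p.join()
    assert p.exitcode == 0, f"child process failed with exit code {p.exitcode}"


if __name__ == "__main__":
    run_in_spawned_process()
```

Because the child exits with code 0 on success, the `assert p.exitcode == 0` in `start_evaluator()` doubles as the pass/fail signal for the test.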