From b3930d3509a2718ebb7e53e6c9a72f1dfbb581c8 Mon Sep 17 00:00:00 2001 From: cpondoc Date: Mon, 20 Jan 2025 20:57:07 -0800 Subject: [PATCH] Fix package tests that are failing --- tests/package_test.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/package_test.py b/tests/package_test.py index f2b282e..f928161 100644 --- a/tests/package_test.py +++ b/tests/package_test.py @@ -50,15 +50,15 @@ def test_constructor(): env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs) # Make basic checks for the constructor - assert env.num_rows == num_rows - assert env.num_cols == num_cols - np.testing.assert_array_equal(env.populated_areas, populated_areas) - np.testing.assert_array_equal(env.paths, paths) + assert env.unwrapped.num_rows == num_rows + assert env.unwrapped.num_cols == num_cols + np.testing.assert_array_equal(env.unwrapped.populated_areas, populated_areas) + np.testing.assert_array_equal(env.unwrapped.paths, paths) # Special check for paths to populated areas for key in paths_to_pops: np.testing.assert_array_equal( - np.array(env.paths_to_pops[key]), + np.array(env.unwrapped.paths_to_pops[key]), np.array(paths_to_pops[key]), ) @@ -112,15 +112,15 @@ def test_reset(): # Check that reset makes it all the same env.reset() - assert env.num_rows == num_rows - assert env.num_cols == num_cols - np.testing.assert_array_equal(env.populated_areas, populated_areas) - np.testing.assert_array_equal(env.paths, paths) + assert env.unwrapped.num_rows == num_rows + assert env.unwrapped.num_cols == num_cols + np.testing.assert_array_equal(env.unwrapped.populated_areas, populated_areas) + np.testing.assert_array_equal(env.unwrapped.paths, paths) # Special check for paths to populated areas for key in paths_to_pops: np.testing.assert_array_equal( - np.array(env.paths_to_pops[key]), + np.array(env.unwrapped.paths_to_pops[key]), np.array(paths_to_pops[key]), ) @@ -242,6 +242,6 @@ def test_generate_gif(mocker): env.render() # Generate the gif, check that it exists, and then remove it - env.generate_gif() + env.unwrapped.generate_gif() assert os.path.exists("training.gif") os.remove("training.gif")