From 7b5ac9fd427dee289e785ccc7c6463c092ea49a0 Mon Sep 17 00:00:00 2001 From: Margaret Hansen Date: Wed, 26 Mar 2025 10:46:43 -0400 Subject: [PATCH 1/4] Fix frequency updates to force conversion to tensor from list --- ergodic_search/erg_metric.py | 4 ++++ example_replanning.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ergodic_search/erg_metric.py b/ergodic_search/erg_metric.py index 62efb47..ceb0abf 100644 --- a/ergodic_search/erg_metric.py +++ b/ergodic_search/erg_metric.py @@ -150,9 +150,13 @@ def update_pdf(self, pdf, fourier_freqs=None, freq_wts=None): else: if fourier_freqs is not None: + if not isinstance(fourier_freqs, torch.Tensor): + fourier_freqs = torch.tensor(fourier_freqs) self.fourier_freqs = fourier_freqs if freq_wts is not None: + if not isinstance(freq_wts, torch.Tensor): + freq_wts = torch.tensor(freq_wts) self.freq_wts = freq_wts self.pdf = pdf diff --git a/example_replanning.py b/example_replanning.py index d7c772e..28265b3 100644 --- a/example_replanning.py +++ b/example_replanning.py @@ -100,6 +100,6 @@ def create_map(dim): planner.update_pdf(new_map) # "take a step" along the trajectory - # this will increment the controls such that the planner will start at the first point in the trajectory and + # this will increment the controls such that the planner will start at the first point in the trajectory planner.take_step() From 159b318aa8dc54aaa1b9f4893112122c224a6f47 Mon Sep 17 00:00:00 2001 From: Margaret Hansen Date: Wed, 26 Mar 2025 10:49:35 -0400 Subject: [PATCH 2/4] Add option to set seed for replicability --- ergodic_search/erg_planner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ergodic_search/erg_planner.py b/ergodic_search/erg_planner.py index 5deeb8a..19a7bd8 100644 --- a/ergodic_search/erg_planner.py +++ b/ergodic_search/erg_planner.py @@ -34,6 +34,7 @@ def ErgArgs(): parser.add_argument('--debug', action='store_true', help='Whether to print loss components for debugging') parser.add_argument('--outpath', type=str, help='File path to save images to, None displays them in a window', default=None) parser.add_argument('--replan_type', type=str, default='full', help='Type of replanning to perform (accepts partial or full)') + parser.add_argument('--seed', type=str, default=687456, help='Seed to set for PyTorch stochastic optimization') args = parser.parse_args() print(args) @@ -63,6 +64,9 @@ def __init__(self, args, pdf=None, init_controls=None, dyn_model=None, fourier_f self.args = args self.pdf = pdf + # set seed + torch.random.manual_seed(args.seed) + # get device self.device = torch.device("cuda") if args.gpu else torch.device("cpu") From 10dc78fcd34e53979cdb9b4f2443f01bab20f042 Mon Sep 17 00:00:00 2001 From: Maggie Hansen Date: Fri, 28 Mar 2025 14:23:15 -0400 Subject: [PATCH 3/4] Force weights to be tensors --- ergodic_search/erg_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ergodic_search/erg_metric.py b/ergodic_search/erg_metric.py index ceb0abf..b38a0b4 100644 --- a/ergodic_search/erg_metric.py +++ b/ergodic_search/erg_metric.py @@ -252,6 +252,6 @@ def get_lambdak(self, k, freq_wts=None): # however MOES uses (-4/2) and that seems to produce better results, at least for this implementation lambdak = (1. + torch.linalg.norm(k / torch.pi, dim=1)**2)**(-4./2.) else: - lambdak = freq_wts + lambdak = torch.tensor(freq_wts) return lambdak From f1f71e203268d6030abee2711cab2dab1e80490a Mon Sep 17 00:00:00 2001 From: Maggie Hansen Date: Fri, 28 Mar 2025 14:24:12 -0400 Subject: [PATCH 4/4] Type checking for tensors --- ergodic_search/erg_metric.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ergodic_search/erg_metric.py b/ergodic_search/erg_metric.py index b38a0b4..f81a1c6 100644 --- a/ergodic_search/erg_metric.py +++ b/ergodic_search/erg_metric.py @@ -252,6 +252,8 @@ def get_lambdak(self, k, freq_wts=None): # however MOES uses (-4/2) and that seems to produce better results, at least for this implementation lambdak = (1. + torch.linalg.norm(k / torch.pi, dim=1)**2)**(-4./2.) else: - lambdak = torch.tensor(freq_wts) + if not isinstance(freq_wts, torch.Tensor): + freq_wts = torch.tensor(freq_wts) + lambdak = freq_wts return lambdak