diff --git a/ergodic_search/erg_metric.py b/ergodic_search/erg_metric.py index 62efb47..f81a1c6 100644 --- a/ergodic_search/erg_metric.py +++ b/ergodic_search/erg_metric.py @@ -150,9 +150,13 @@ def update_pdf(self, pdf, fourier_freqs=None, freq_wts=None): else: if fourier_freqs is not None: + if not isinstance(fourier_freqs, torch.Tensor): + fourier_freqs = torch.tensor(fourier_freqs) self.fourier_freqs = fourier_freqs if freq_wts is not None: + if not isinstance(freq_wts, torch.Tensor): + freq_wts = torch.tensor(freq_wts) self.freq_wts = freq_wts self.pdf = pdf @@ -248,6 +252,8 @@ def get_lambdak(self, k, freq_wts=None): # however MOES uses (-4/2) and that seems to produce better results, at least for this implementation lambdak = (1. + torch.linalg.norm(k / torch.pi, dim=1)**2)**(-4./2.) else: + if not isinstance(freq_wts, torch.Tensor): + freq_wts = torch.tensor(freq_wts) lambdak = freq_wts return lambdak diff --git a/ergodic_search/erg_planner.py b/ergodic_search/erg_planner.py index 5deeb8a..19a7bd8 100644 --- a/ergodic_search/erg_planner.py +++ b/ergodic_search/erg_planner.py @@ -34,6 +34,7 @@ def ErgArgs(): parser.add_argument('--debug', action='store_true', help='Whether to print loss components for debugging') parser.add_argument('--outpath', type=str, help='File path to save images to, None displays them in a window', default=None) parser.add_argument('--replan_type', type=str, default='full', help='Type of replanning to perform (accepts partial or full)') + parser.add_argument('--seed', type=str, default=687456, help='Seed to set for PyTorch stochastic optimization') args = parser.parse_args() print(args) @@ -63,6 +64,9 @@ def __init__(self, args, pdf=None, init_controls=None, dyn_model=None, fourier_f self.args = args self.pdf = pdf + # set seed + torch.random.manual_seed(args.seed) + # get device self.device = torch.device("cuda") if args.gpu else torch.device("cpu") diff --git a/example_replanning.py b/example_replanning.py index d7c772e..28265b3 100644 --- a/example_replanning.py +++ b/example_replanning.py @@ -100,6 +100,6 @@ def create_map(dim): planner.update_pdf(new_map) # "take a step" along the trajectory - # this will increment the controls such that the planner will start at the first point in the trajectory and + # this will increment the controls such that the planner will start at the first point in the trajectory planner.take_step()