diff --git a/breaching/attacks/optimization_based_attack.py b/breaching/attacks/optimization_based_attack.py index 9cf96e8..94a8209 100644 --- a/breaching/attacks/optimization_based_attack.py +++ b/breaching/attacks/optimization_based_attack.py @@ -117,7 +117,7 @@ def _run_trial(self, rec_model, shared_data, labels, stats, trial, initial_data= if self.cfg.optim.boxed: candidate.data = torch.max(torch.min(candidate, (1 - self.dm) / self.ds), -self.dm / self.ds) if objective_value < minimal_value_so_far: - minimal_value_so_far = objective_value.detach() + minimal_value_so_far = objective_value best_candidate = candidate.detach().clone() if iteration + 1 == self.cfg.optim.max_iterations or iteration % self.cfg.optim.callback == 0: @@ -184,7 +184,7 @@ def closure(): pass self.current_task_loss = total_task_loss # Side-effect this because of L-BFGS closure limitations :< - return total_objective + return total_objective.detach() return closure