diff --git a/main.py b/main.py index ad69227..951b6f7 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ def main(): torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) # Make the a directory corresponding to this run for saving results, checkpoints etc. @@ -68,4 +69,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()