From 9cfc774167dd9fe32d6bae31cc99f7fe3a91f73d Mon Sep 17 00:00:00 2001 From: atharvadeore999 <82568039+atharvadeore999@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:34:53 +0530 Subject: [PATCH 1/2] Update eval_percep.py --- evaluation/eval_percep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py index 53377a7..f49bb88 100644 --- a/evaluation/eval_percep.py +++ b/evaluation/eval_percep.py @@ -52,7 +52,7 @@ def load_dreamsim_model(args, device="cuda"): with open(os.path.join(args.eval_checkpoint_cfg), "r") as f: cfg = yaml.load(f, Loader=yaml.Loader) - model_cfg = vars(cfg) + model_cfg = cfg model_cfg['load_dir'] = args.load_dir model = LightningPerceptualModel(**model_cfg) model.load_lora_weights(args.eval_checkpoint) @@ -141,4 +141,4 @@ def run(args, device): args = parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" run(args, device) - \ No newline at end of file + From cf67f02658a9d285cc236a7dd8eb0b6ea7a43008 Mon Sep 17 00:00:00 2001 From: atharvadeore999 <82568039+atharvadeore999@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:46:04 +0530 Subject: [PATCH 2/2] Update train.py --- training/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/training/train.py b/training/train.py index f8105ef..e25aabb 100644 --- a/training/train.py +++ b/training/train.py @@ -184,7 +184,13 @@ def load_lora_weights(self, checkpoint_root, epoch_load=None): if self.save_mode in {'adapter_only', 'all'}: if epoch_load is not None: checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}') - + + with open(os.path.join(checkpoint_root, 'adapter_config.json'), 'r') as f: + adapter_config = json.load(f) + lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules'] + lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys}) + self.perceptual_model = get_peft_model(self.perceptual_model, lora_config) + logging.info(f'Loading adapter weights from {checkpoint_root}') self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device) else: