Description
Hi authors — thanks again for open-sourcing the codebase.
I think there may be a bug in `maybe_load_pretrained_model()` in `main.py` related to loading the pretrained weights (especially for ChestCT).
Suspected issue: key renaming leads to `visual.visual.*`
The Hugging Face checkpoint keys appear to be of the form `model.visual.*`.
However, around line 596 the code does:

```python
visual_key = src_key.replace("model.", "visual.")
```

So a key like `model.visual.xxx` becomes `visual.visual.xxx`.
That new key (`visual.visual.*`) will not exist in the model's `state_dict()`, so it is skipped by the `if visual_key in model_state_dict` check (around line 597). As a result, `model_state_dict` receives no updates from the pretrained checkpoint, and the subsequent load (around line 602) simply writes the original, randomly initialized weights back into the model. In other words, the pretrained weights are never actually applied.
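To illustrate, here is a minimal sketch of the suspected behavior. The key name `model.visual.attn.weight` is a hypothetical example of the `model.visual.*` form, not an actual key from the checkpoint:

```python
# Hypothetical checkpoint key of the model.visual.* form
src_key = "model.visual.attn.weight"

# What the code around line 596 reportedly does:
visual_key = src_key.replace("model.", "visual.")

# The result has a doubled prefix, so it will not match any
# entry in model.state_dict() and gets silently skipped.
print(visual_key)  # -> visual.visual.attn.weight
```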
Another concern: hardcoded X-ray branch
Also, around line 533 there appears to be logic hardcoded for X-rays. When loading weights for ChestCT, this branch is skipped entirely, which may further explain why the pretrained weights are not applied for CT.
Could you please confirm whether this is expected? If not, it seems the key mapping should instead strip the `model.` prefix (e.g., `model.visual.*` -> `visual.*`) rather than replacing it with `visual.`.
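For reference, a sketch of the suggested fix, again using the hypothetical key `model.visual.attn.weight`:

```python
# Hypothetical checkpoint key of the model.visual.* form
src_key = "model.visual.attn.weight"

# Strip the "model." prefix instead of replacing it,
# so model.visual.* maps cleanly to visual.*
prefix = "model."
visual_key = src_key[len(prefix):] if src_key.startswith(prefix) else src_key

print(visual_key)  # -> visual.attn.weight
```

This way the renamed key matches the entries in the model's `state_dict()` and the pretrained weights are actually copied over.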