Refactor Config Dataclass to Dynamically Compute Intermediate Dimension #140
Open
githubshaurya wants to merge 1 commit into huggingface:main from
Conversation
Currently, the configuration classes in models/config.py compute certain derived fields once, at class-definition time, rather than when you actually create a config object.
For example:
vit_inter_dim is defined as 4 * vit_hidden_dim right in the class body, so it always uses the default vit_hidden_dim of 768, even if you pass a different value when you instantiate the class (see the sketch after this list).
Similarly, lm_vocab_size and the training-schedule fields (eval_interval and stats_log_interval) are computed statically based on default hyperparameters.
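For the ViT case, the pattern looks roughly like this. This is a minimal sketch: the class name VLMConfig and the surrounding layout are assumptions; only vit_hidden_dim, vit_inter_dim, and the defaults 768 / 4x come from this PR.

```python
from dataclasses import dataclass

@dataclass
class VLMConfig:  # hypothetical name standing in for the classes in models/config.py
    vit_hidden_dim: int = 768
    # Evaluated once, when the class body runs, so this is always 4 * 768 = 3072,
    # regardless of the vit_hidden_dim passed at instantiation:
    vit_inter_dim: int = 4 * vit_hidden_dim
```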
The problem
Say you bump vit_hidden_dim from 768 to 1024 to give the Vision Transformer's MLP more representational power, expecting the MLP to expand to 4096. Instead, vit_inter_dim stays stuck at 3072, and the mismatched shapes crash at runtime.
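Under the sketch above, the failure mode reproduces like this:

```python
cfg = VLMConfig(vit_hidden_dim=1024)
print(cfg.vit_hidden_dim)  # 1024, as requested
print(cfg.vit_inter_dim)   # still 3072, not 4096 -> shape mismatch downstream
```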
Solution
I switched those dependent fields to field(init=False) and moved their computation into a __post_init__ method. Now the derived values are computed at instantiation, from the actual arguments you pass in.
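A minimal sketch of the fix, reusing the same hypothetical VLMConfig:

```python
from dataclasses import dataclass, field

@dataclass
class VLMConfig:
    vit_hidden_dim: int = 768
    # Excluded from the generated __init__; derived below from the instance's value.
    vit_inter_dim: int = field(init=False)

    def __post_init__(self):
        # Runs right after __init__, so vit_hidden_dim is whatever the caller passed.
        self.vit_inter_dim = 4 * self.vit_hidden_dim

cfg = VLMConfig(vit_hidden_dim=1024)
assert cfg.vit_inter_dim == 4096  # now tracks the actual input
```

The same treatment applies to lm_vocab_size and the training-schedule fields (eval_interval and stats_log_interval).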