[Fix] Implement better wd_ban_list handling #282
Conversation
Hi, thanks for the contribution!

> `wd_ban_list` logic is only checking for the actual, fully-qualified parameter names

Yes. It was originally intended to exclude parameters whose names appear on the blacklist. So, as you mentioned above, if you have a layer norm layer called 'asdf' and don't put the exact parameter name into `wd_ban_list`, for example, it won't be excluded. I added `LayerNorm.bias` and `LayerNorm.weight` to the default `wd_ban_list` to align with the usage in the Transformers library.

Your idea of also adding module names (e.g. `LayerNorm`) to the exclusion criteria sounds good to me, because we usually ban based on the type of module (see the sketch below).
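To make that concrete, here is a minimal sketch of the name-based matching being described. The substring-style check mirrors the common Transformers pattern and is an assumption, not the library's literal code; the layer name 'asdf' is the hypothetical example from above:

```python
import torch.nn as nn

# Hypothetical model: the layer norm is stored under the attribute name 'asdf'.
model = nn.Sequential()
model.add_module('linear', nn.Linear(8, 8))
model.add_module('asdf', nn.LayerNorm(8))

wd_ban_list = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

# Name-based matching: the layer norm's parameters are named 'asdf.weight'
# and 'asdf.bias', so 'asdf.weight' matches no entry and keeps its weight
# decay; only the plain 'bias' entry catches 'asdf.bias'.
for name, _ in model.named_parameters():
    banned = any(ban in name for ban in wd_ban_list)
    print(f"{name}: {'banned' if banned else 'gets weight decay'}")
```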
Your code looks good to me! Could you please run `make format` & `make check` by any chance? Or I can handle it later.
I just pushed a new commit with a few fixes. However, there is one error I was not able to fix: if you run [...]. I'm not super familiar with [...].

It's okay, I can handle the lint stuff. Anyway, thanks for the contributions!
Problem (Why?)

The `wd_ban_list` argument for `get_optimizer_parameters()` is somewhat misleading. Looking at it, you would expect any of the default arguments' name formats to work correctly. However, that is not the case. From this list, only `bias` is "detected" and "banned" correctly. Neither `LayerNorm.bias` nor `LayerNorm.weight` is detected, so neither of those parameters has its `weight_decay` set to 0. I even tested `LayerNorm`, and that doesn't work, either. The snippet below reproduces the problem.
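A minimal reproduction, assuming the `get_optimizer_parameters(model, weight_decay, wd_ban_list)` signature and the param-group return format described in this PR (check your installed version):

```python
import torch.nn as nn
from pytorch_optimizer import get_optimizer_parameters

# The LayerNorm lives under the attribute name 'norm', so its parameters are
# named 'norm.weight' and 'norm.bias'. No parameter name ever contains the
# string 'LayerNorm', so the default entries 'LayerNorm.bias' and
# 'LayerNorm.weight' cannot match anything.
model = nn.Sequential()
model.add_module('fc', nn.Linear(8, 8))
model.add_module('norm', nn.LayerNorm(8))

groups = get_optimizer_parameters(model, weight_decay=1e-2)
for group in groups:
    print(group['weight_decay'], len(group['params']))
# 'norm.weight' lands in the weight-decay group, even though LayerNorm
# weights are exactly what the defaults appear to exclude.
```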
Solution (What/How?)

The reason this fails is that the `wd_ban_list` logic only checks the actual, fully-qualified parameter names; it does NOT check the class name of each `nn.Module`, as `pytorch_optimizer`'s default arguments and tests would imply.

I implemented a more complete method for handling the `wd_ban_list`. Now, we check both the "true names" and the `nn.Module` class names, roughly as sketched below.
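Conceptually, the combined check looks something like this (a hypothetical helper for illustration, not the literal patch):

```python
import torch.nn as nn

def build_banned_names(model: nn.Module, wd_ban_list) -> set:
    """Collect parameter names excluded by full-name OR module-class matching."""
    banned = set()
    for module_name, module in model.named_modules():
        for param_name, _ in module.named_parameters(recurse=False):
            full_name = f'{module_name}.{param_name}' if module_name else param_name
            if any(ban in full_name for ban in wd_ban_list):
                banned.add(full_name)  # "true name" match
            elif module.__class__.__name__ in wd_ban_list:
                banned.add(full_name)  # module-class match, e.g. 'LayerNorm'
    return banned

model = nn.Sequential()
model.add_module('fc', nn.Linear(8, 8))
model.add_module('norm', nn.LayerNorm(8))

print(sorted(build_banned_names(model, ('bias', 'LayerNorm'))))
# -> ['fc.bias', 'norm.bias', 'norm.weight']
```

This keeps the existing name-based behavior intact while letting an entry like `LayerNorm` ban every parameter of any module of that type, regardless of what the attribute is called.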
Notes

I've been using this patch in my own code for several weeks now; it seems to work great! Let me know if there is anything you would change.