Skip to content

Commit

Permalink
fix some makefile issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Oct 24, 2024
1 parent ea80942 commit 546531c
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions pytorch_optimizer/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,20 @@ def get_optimizer_parameters(
weight_decay: float,
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
) -> PARAMETERS:
r"""Get optimizer parameters while filtering specified modules.
r"""
Get optimizer parameters while filtering specified modules.
:param model_or_parameter: Union[nn.Module, List]. model or parameters.
:param weight_decay: float. weight_decay.
:param wd_ban_list: List[str]. ban list not to set weight decay.
:returns: PARAMETERS. new parameter list.
"""


fully_qualified_names = []
for module_name, module in model_or_parameter.named_modules():
for param_name, param in module.named_parameters(recurse=False):
for param_name, _param in module.named_parameters(recurse=False):
# Full parameter name includes module and parameter names
full_param_name = (
f"{module_name}.{param_name}" if module_name else param_name
)
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
# Check if any ban list substring is in the parameter name or module name
if (
any(banned in param_name for banned in wd_ban_list)
Expand All @@ -223,14 +222,20 @@ def get_optimizer_parameters(

if isinstance(model_or_parameter, nn.Module):
model_or_parameter = list(model_or_parameter.named_parameters())

return [
{
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in fully_qualified_names)],
'params': [
p
for n, p in model_or_parameter
if p.requires_grad and not any(nd in n for nd in fully_qualified_names)
],
'weight_decay': weight_decay,
},
{
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)],
'params': [
p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)
],
'weight_decay': 0.0,
},
]
Expand Down

0 comments on commit 546531c

Please sign in to comment.