From 905fca7ed99108dfb56624ced0057140b2a900ed Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Wed, 23 Oct 2024 00:18:04 -0500 Subject: [PATCH 1/3] implement better logic for detecting weights/modules --- pytorch_optimizer/optimizer/utils.py | 31 +++++++++++++++++++++++++--- tests/test_utils.py | 2 +- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index f44cd393..3c7dbff0 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -205,16 +205,41 @@ def get_optimizer_parameters( :param wd_ban_list: List[str]. ban list not to set weight decay. :returns: PARAMETERS. new parameter list. """ + + def find_fully_qualified_names( + model: nn.Module, + wd_ban_list: List[str] = ("bias", "LayerNorm.weight", "LayerNorm.bias"), + ): + names_without_wd = [] + + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + # Full parameter name includes module and parameter names + full_param_name = ( + f"{module_name}.{param_name}" if module_name else param_name + ) + # Check if any ban list substring is in the parameter name or module name + if ( + any(banned in param_name for banned in wd_ban_list) + or any(banned in module_name for banned in wd_ban_list) + or any(banned in module._get_name() for banned in wd_ban_list) + ): + names_without_wd.append(full_param_name) + + return names_without_wd + + full_names = find_fully_qualified_names(model_or_parameter, wd_ban_list) + if isinstance(model_or_parameter, nn.Module): model_or_parameter = list(model_or_parameter.named_parameters()) - + return [ { - 'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in wd_ban_list)], + 'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in full_names)], 'weight_decay': weight_decay, }, { - 'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in wd_ban_list)], + 'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in full_names)], 'weight_decay': 0.0, }, ] diff --git a/tests/test_utils.py b/tests/test_utils.py index e28f9f08..0a53c133 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -98,7 +98,7 @@ def test_neuron_mean_norm(): def test_get_optimizer_parameters(): model: nn.Module = Example() - wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm'] before_parameters = list(model.named_parameters()) after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list) From ea809429ceafb417f1791dd7dcdab8835e81cfb2 Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Wed, 23 Oct 2024 02:25:49 -0500 Subject: [PATCH 2/3] make it slightly more concise --- pytorch_optimizer/optimizer/utils.py | 41 +++++++++++----------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 3c7dbff0..5329b9f9 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -206,40 +206,31 @@ def get_optimizer_parameters( :returns: PARAMETERS. new parameter list. """ - def find_fully_qualified_names( - model: nn.Module, - wd_ban_list: List[str] = ("bias", "LayerNorm.weight", "LayerNorm.bias"), - ): - names_without_wd = [] - - for module_name, module in model.named_modules(): - for param_name, param in module.named_parameters(recurse=False): - # Full parameter name includes module and parameter names - full_param_name = ( - f"{module_name}.{param_name}" if module_name else param_name - ) - # Check if any ban list substring is in the parameter name or module name - if ( - any(banned in param_name for banned in wd_ban_list) - or any(banned in module_name for banned in wd_ban_list) - or any(banned in module._get_name() for banned in wd_ban_list) - ): - names_without_wd.append(full_param_name) - - return names_without_wd - - full_names = find_fully_qualified_names(model_or_parameter, wd_ban_list) + fully_qualified_names = [] + for module_name, module in model_or_parameter.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + # Full parameter name includes module and parameter names + full_param_name = ( + f"{module_name}.{param_name}" if module_name else param_name + ) + # Check if any ban list substring is in the parameter name or module name + if ( + any(banned in param_name for banned in wd_ban_list) + or any(banned in module_name for banned in wd_ban_list) + or any(banned in module._get_name() for banned in wd_ban_list) + ): + fully_qualified_names.append(full_param_name) if isinstance(model_or_parameter, nn.Module): model_or_parameter = list(model_or_parameter.named_parameters()) return [ { - 'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in full_names)], + 'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in fully_qualified_names)], 'weight_decay': weight_decay, }, { - 'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in full_names)], + 'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)], 'weight_decay': 0.0, }, ] From 546531c98d0b526335d6151faa214269f29dcbe6 Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Thu, 24 Oct 2024 09:56:01 -0500 Subject: [PATCH 3/3] fix some makefile issues --- pytorch_optimizer/optimizer/utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 5329b9f9..2efd78fa 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -198,21 +198,20 @@ def get_optimizer_parameters( weight_decay: float, wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'), ) -> PARAMETERS: - r"""Get optimizer parameters while filtering specified modules. - + r""" + Get optimizer parameters while filtering specified modules. :param model_or_parameter: Union[nn.Module, List]. model or parameters. :param weight_decay: float. weight_decay. :param wd_ban_list: List[str]. ban list not to set weight decay. :returns: PARAMETERS. new parameter list. """ + fully_qualified_names = [] for module_name, module in model_or_parameter.named_modules(): - for param_name, param in module.named_parameters(recurse=False): + for param_name, _param in module.named_parameters(recurse=False): # Full parameter name includes module and parameter names - full_param_name = ( - f"{module_name}.{param_name}" if module_name else param_name - ) + full_param_name = f'{module_name}.{param_name}' if module_name else param_name # Check if any ban list substring is in the parameter name or module name if ( any(banned in param_name for banned in wd_ban_list) @@ -223,14 +222,20 @@ def get_optimizer_parameters( if isinstance(model_or_parameter, nn.Module): model_or_parameter = list(model_or_parameter.named_parameters()) - + return [ { - 'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in fully_qualified_names)], + 'params': [ + p + for n, p in model_or_parameter + if p.requires_grad and not any(nd in n for nd in fully_qualified_names) + ], 'weight_decay': weight_decay, }, { - 'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)], + 'params': [ + p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names) + ], 'weight_decay': 0.0, }, ]