diff --git a/d2go/optimizer/build.py b/d2go/optimizer/build.py
index 11a5d64e..4d0ae34e 100644
--- a/d2go/optimizer/build.py
+++ b/d2go/optimizer/build.py
@@ -326,6 +326,126 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     )
 
 
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def nadam_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor NAdam optimizer that works significantly faster.
+    This version is expected to become the default NAdam implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.NAdam)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def radam_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor RAdam optimizer that works significantly faster.
+    This version is expected to become the default RAdam implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.RAdam)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def rmsprop_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor RMSprop optimizer that works significantly faster.
+    This version is expected to become the default RMSprop implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.RMSprop)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def rprop_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor Rprop optimizer that works significantly faster.
+    This version is expected to become the default Rprop implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Rprop)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def asgd_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor ASGD optimizer that works significantly faster.
+    This version is expected to become the default ASGD implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.ASGD)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def adamax_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor Adamax optimizer that works significantly faster.
+    This version is expected to become the default Adamax implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adamax)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def adadelta_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor Adadelta optimizer that works significantly faster.
+    This version is expected to become the default Adadelta implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adadelta)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
+@D2GO_OPTIM_MAPPER_REGISTRY.register()
+def adagrad_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build a multi_tensor Adagrad optimizer that works significantly faster.
+    This version is expected to become the default Adagrad implementation
+    by the end of H1'21. To benefit from the speedup, the number of
+    parameter groups needs to be reduced using `reduce_param_groups`.
+    """
+    params = get_optimizer_param_groups(model, cfg)
+    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adagrad)(
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
+    )
+
+
 def build_optimizer_mapper(cfg, model):
     name = cfg.SOLVER.OPTIMIZER
     optimizer = D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)