
Adding support for all torch multi-tensor optimizers
Summary:
Pull Request resolved: #629

This diff adds all of torch's multi-tensor optimizers to d2go, which currently only supports AdamW, Adam, and SGD.
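For reference, each newly registered builder simply instantiates the corresponding class from `torch.optim._multi_tensor` with the config's parameter groups and base learning rate. A minimal sketch of the equivalent call outside of d2go (the toy model and learning rate are placeholders, and availability of `torch.optim._multi_tensor` depends on the installed PyTorch version):

import torch

model = torch.nn.Linear(16, 4)  # placeholder for any torch.nn.Module

# What nadam_mt constructs, minus d2go's param-group and gradient-clipping helpers:
optimizer = torch.optim._multi_tensor.NAdam(params=model.parameters(), lr=1e-3)

# Inside d2go, the same optimizer is instead selected via the config,
# e.g. cfg.SOLVER.OPTIMIZER = "nadam_mt", and built by build_optimizer_mapper(cfg, model).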

Reviewed By: mlopezantequera

Differential Revision: D50498623

fbshipit-source-id: 5a38509354e565dd22256261bf1a688bcdc94951
Matteo Presutto authored and facebook-github-bot committed Oct 23, 2023
1 parent b18c078 commit 7ace1ef
Showing 1 changed file with 120 additions and 0 deletions.
d2go/optimizer/build.py
@@ -326,6 +326,126 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def nadam_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor NAdam optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.NAdam)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def radam_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor RAdam optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.RAdam)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def rmsprop_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor RMSprop optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.RMSprop)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def rprop_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor RMSprop optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Rprop)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def asgd_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor ASGD optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.ASGD)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adamax_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor Adamax optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adamax)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adadelta_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor Adadelta optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adadelta)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adagrad_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build a multi_tensor Adagrad optimizer that works significantly faster.
This version is expected to be the default implementation for adamw
optimizer by end of H1'21. To benefit from the speedup, the number
of parameter groups needs to be reduced using `reduce_param_groups`.
"""
params = get_optimizer_param_groups(model, cfg)
return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adagrad)(
params=params,
lr=cfg.SOLVER.BASE_LR,
)


def build_optimizer_mapper(cfg, model):
    name = cfg.SOLVER.OPTIMIZER
    optimizer = D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)
