Skip to content

Commit

Permalink
add weight decay
Browse files Browse the repository at this point in the history
Signed-off-by: sichu <sichu@nvidia.com>
  • Loading branch information
sichu2023 committed Jan 2, 2025
1 parent 4cd7cd8 commit 8a7c3e9
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def main(
overlap_param_gather: bool = False, # TODO waiting for a NeMo fix
average_in_collective: bool = True,
grad_reduce_in_fp32: bool = False,
weight_decay: float = 0.01,
) -> None:
"""Train an ESM2 model on UR data.
Expand Down Expand Up @@ -155,6 +156,7 @@ def main(
overlap_param_gather (bool): overlap parameter gather
average_in_collective (bool): average in collective
grad_reduce_in_fp32 (bool): gradient reduction in fp32
weight_decay (float): optimizer weight decay (L2 regularization) coefficient
"""
# Create the result directory if it does not exist.
result_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -283,7 +285,7 @@ def main(
lr=lr,
optimizer="adam",
use_distributed_optimizer=True,
weight_decay=0.01,
weight_decay=weight_decay,
adam_beta1=0.9,
adam_beta2=0.98,
),
Expand Down Expand Up @@ -387,6 +389,7 @@ def train_esm2_entrypoint():
overlap_param_gather=args.overlap_param_gather,
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
weight_decay=args.weight_decay,
)


Expand Down Expand Up @@ -694,6 +697,13 @@ def get_parser():
default=4 * 1280,
help="FFN hidden size of the model. Default is 4 * 1280.",
)
parser.add_argument(
"--weight-decay",
type=float,
required=False,
default=0.01,
help="Optimizer weight decay (L2 regularization) coefficient. Default is 0.01.",
)
# DDP config
parser.add_argument(
"--no-overlap-grad-reduce",
Expand Down

0 comments on commit 8a7c3e9

Please sign in to comment.