Wandb step handling bugfix and feature (#2580)
sjmielke authored Dec 20, 2024 · 1 parent 6ccd520 · commit b86aa21
Showing 3 changed files with 12 additions and 8 deletions.
docs/interface.md (1 addition, 1 deletion)

```diff
@@ -58,7 +58,7 @@ This mode supports a number of command-line arguments, the details of which can
 
 * `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42.
 
-* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```
+* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```. Also allows for the passing of the step to log things at (passed to `wandb.run.log`), e.g., `--wandb_args step=123`.
 
 * `--hf_hub_log_args` : Logs evaluation results to Hugging Face Hub. Accepts a string with the arguments separated by commas. Available arguments:
   * `hub_results_org` - organization name on Hugging Face Hub, e.g., `EleutherAI`. If not provided, the results will be pushed to the owner of the Hugging Face token,
```
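For illustration, here is a minimal sketch of how such a comma-separated argument string becomes keyword arguments, with the reserved `step` key set aside for later `wandb.run.log` calls. The `coerce` helper below is a simplified, hypothetical stand-in for the harness's `handle_arg_string`, not its actual implementation:

```python
def coerce(value: str):
    # simplified stand-in for lm_eval.utils.handle_arg_string
    # (assumption: the real helper also handles bools, floats, etc.)
    if value == "None":
        return None
    if value.isdigit():
        return int(value)
    return value

def parse_wandb_args(args_string: str) -> dict:
    # split "k1=v1,k2=v2,..." into a dict, splitting each pair on the first "="
    pairs = [arg.split("=", 1) for arg in args_string.split(",") if arg]
    return {k: coerce(v) for k, v in pairs}

kwargs = parse_wandb_args("project=test-project,name=test-run,step=123")
step = kwargs.pop("step", None)  # reserved for wandb.run.log, not wandb.init
print(kwargs, step)  # {'project': 'test-project', 'name': 'test-run'} 123
```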
lm_eval/loggers/wandb_logger.py (9 additions, 6 deletions)

```diff
@@ -48,6 +48,9 @@ def __init__(self, **kwargs) -> None:
 
         self.wandb_args: Dict[str, Any] = kwargs
 
+        # pop the step key from the args to save for all logging calls
+        self.step = self.wandb_args.pop("step", None)
+
         # initialize a W&B run
         if wandb.run is None:
             self.run = wandb.init(**self.wandb_args)
```
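The pop is needed because `step` is not a `wandb.init` argument: the logger removes it from the argument dict before initialization and then threads it through every logging call. A minimal standalone sketch of the same pattern (assuming you are logged in to W&B; the project name and metric are illustrative):

```python
import wandb

kwargs = {"project": "test-project", "step": 123}
step = kwargs.pop("step", None)  # wandb.init would reject this key
run = wandb.init(**kwargs)

# Every subsequent log call pins its data to the requested step;
# step=None falls back to wandb's auto-incrementing step counter.
run.log({"evaluation/accuracy": 0.8}, step=step)
run.finish()
```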
```diff
@@ -152,11 +155,11 @@ def make_table(columns: List[str], key: str = "results"):
 
         # log the complete eval result to W&B Table
         table = make_table(["Tasks"] + columns, "results")
-        self.run.log({"evaluation/eval_results": table})
+        self.run.log({"evaluation/eval_results": table}, step=self.step)
 
         if "groups" in self.results.keys():
             table = make_table(["Groups"] + columns, "groups")
-            self.run.log({"evaluation/group_eval_results": table})
+            self.run.log({"evaluation/group_eval_results": table}, step=self.step)
 
     def _log_results_as_artifact(self) -> None:
         """Log results as JSON artifact to W&B."""
```
```diff
@@ -174,13 +177,13 @@ def log_eval_result(self) -> None:
         """Log evaluation results to W&B."""
         # Log configs to wandb
         configs = self._get_config()
-        self.run.config.update(configs)
+        self.run.config.update(configs, allow_val_change=self.step is not None)
 
         wandb_summary, self.wandb_results = self._sanitize_results_dict()
         # update wandb.run.summary with items that were removed
         self.run.summary.update(wandb_summary)
         # Log the evaluation metrics to wandb
-        self.run.log(self.wandb_results)
+        self.run.log(self.wandb_results, step=self.step)
         # Log the evaluation metrics as W&B Table
         self._log_results_as_table()
         # Log the results dict as json to W&B Artifacts
```
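A likely reason for the `allow_val_change=self.step is not None` guard: when an explicit step is supplied, the harness may be logging into a run whose config keys were already populated (e.g., evaluation interleaved with training), and `wandb` refuses to overwrite existing config values unless `allow_val_change=True` is passed to `config.update`.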
```diff
@@ -329,7 +332,7 @@ def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
 
             # log the samples as a W&B Table
             df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
-            self.run.log({f"{task_name}_eval_results": df})
+            self.run.log({f"{task_name}_eval_results": df}, step=self.step)
 
             # log the samples as a json file as W&B Artifact
             self._log_samples_as_artifact(eval_preds, task_name)
@@ -348,4 +351,4 @@ def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
                 # log the samples as a json file as W&B Artifact
                 self._log_samples_as_artifact(eval_preds, task_name)
 
-            self.run.log({f"{group}_eval_results": grouped_df})
+            self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
```
lm_eval/utils.py (2 additions, 1 deletion)

```diff
@@ -104,7 +104,8 @@ def simple_parse_args_string(args_string):
         return {}
     arg_list = [arg for arg in args_string.split(",") if arg]
     args_dict = {
-        k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]
+        kv[0]: handle_arg_string("=".join(kv[1:]))
+        for kv in [arg.split("=") for arg in arg_list]
     }
     return args_dict
 
```
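The one-line bugfix is easiest to see with a value that itself contains `=`. The old comprehension unpacked each `arg.split("=")` into exactly two names, so a three-field result raised `ValueError`; the new version keeps the first field as the key and rejoins the rest. A self-contained before/after sketch (type coercion via `handle_arg_string` omitted):

```python
def parse_old(args_string: str) -> dict:
    arg_list = [arg for arg in args_string.split(",") if arg]
    # crashes when a value contains "=": three fields cannot unpack into k, v
    return {k: v for k, v in [arg.split("=") for arg in arg_list]}

def parse_new(args_string: str) -> dict:
    arg_list = [arg for arg in args_string.split(",") if arg]
    # key is everything before the first "=", value is everything after it
    return {
        kv[0]: "=".join(kv[1:])
        for kv in [arg.split("=") for arg in arg_list]
    }

print(parse_new("project=test,name=exp=v2"))  # {'project': 'test', 'name': 'exp=v2'}
try:
    parse_old("project=test,name=exp=v2")
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)
```

Splitting with `arg.split("=", 1)` would have been an equivalent fix.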

