Skip to content

Commit

Permalink
streamline cot config keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ggbetz committed Apr 11, 2024
1 parent f9bfe8f commit a636fb5
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/cot_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@
MAX_RETRIALS_PUSH_TO_HUB = 5
RETRIALS_INTERVAL = 30

COT_CONFIG_KEYS = [
"name",
"model",
"dtype",
"tensor_parallel_size",
"max_new_tokens",
"cot_chain",
"n",
"best_of",
"use_beam_search",
"temperature",
"top_p",
"top_k",
"gpu_memory_utilization",
"max_model_len",
"revision",
"swap_space",
]


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
Expand Down Expand Up @@ -159,7 +178,7 @@ def main():
model_kwargs = config_data.pop("modelkwargs", {})
vllm_kwargs = model_kwargs.pop("vllm_kwargs", {})
config_data = {**config_data, **model_kwargs, **vllm_kwargs}
config_data = {k: str(v) for k, v in config_data.items()}
config_data = {k: str(v) for k, v in config_data.items() if k in COT_CONFIG_KEYS}
logging.info(f"Adding config_data: {config_data}")

for task, ds in cot_data.items():
Expand Down

0 comments on commit a636fb5

Please sign in to comment.