From c225602c1bc3c345ef64c904d4f985af9d5441d2 Mon Sep 17 00:00:00 2001 From: Petr Baudis Date: Sat, 28 Dec 2024 18:46:40 +0100 Subject: [PATCH 1/4] fix(zeno): Generate unique ids in case of multiple filters --- scripts/zeno_visualize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/zeno_visualize.py b/scripts/zeno_visualize.py index 4bc7e03bf8..f040e1ce8f 100644 --- a/scripts/zeno_visualize.py +++ b/scripts/zeno_visualize.py @@ -168,7 +168,7 @@ def generate_dataset( Returns: pd.Dataframe: A dataframe that is ready to be uploaded to Zeno. """ - ids = [x["doc_id"] for x in data] + ids = [x["doc_id"] for x in data] if not config.get("filter_list") else [f"{x['doc_id']}.{x['filter']}" for x in data] labels = [x["target"] for x in data] instance = [""] * len(ids) @@ -190,6 +190,7 @@ def generate_dataset( return pd.DataFrame( { "id": ids, + "doc_id": [x["doc_id"] for x in data], "data": instance, "input_len": [len(x) for x in instance], "labels": labels, @@ -208,8 +209,11 @@ def generate_system_df(data, config): Returns: pd.Dataframe: A dataframe that is ready to be uploaded to Zeno as a system. """ - ids = [x["doc_id"] for x in data] + ids = [x["doc_id"] for x in data] if not config.get("filter_list") else [f"{x['doc_id']}.{x['filter']}" for x in data] system_dict = {"id": ids} + system_dict["doc_id"] = [x["doc_id"] for x in data] + if config.get("filter_list"): + system_dict["filter"] = [x["filter"] for x in data] system_dict["output"] = [""] * len(ids) if config["output_type"] == "loglikelihood": From 0bd64c252c76f19bec244898748bd4d23e30724d Mon Sep 17 00:00:00 2001 From: Petr Baudis Date: Sat, 28 Dec 2024 18:47:17 +0100 Subject: [PATCH 2/4] fix(zeno): Report even non-aggregable metrics, just not as metrics --- scripts/zeno_visualize.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/scripts/zeno_visualize.py b/scripts/zeno_visualize.py index f040e1ce8f..995d67edda 100644 --- a/scripts/zeno_visualize.py +++ b/scripts/zeno_visualize.py @@ -109,13 +109,14 @@ def main(): if model_index == 0: # Only need to assemble data for the first model metrics = [] for metric in config["metric_list"]: - metrics.append( - ZenoMetric( - name=metric["metric"], - type="mean", - columns=[metric["metric"]], + if metric.get("aggregation") == "mean": + metrics.append( + ZenoMetric( + name=metric["metric"], + type="mean", + columns=[metric["metric"]], + ) ) - ) project = client.create_project( name=args.project_name + (f"_{task}" if len(tasks) > 1 else ""), view="text-classification", @@ -232,11 +233,7 @@ def generate_system_df(data, config): system_dict["output"] = [str(x["filtered_resps"][0]) for x in data] system_dict["output_length"] = [len(str(x["filtered_resps"][0])) for x in data] - metrics = {} - for metric in config["metric_list"]: - if "aggregation" in metric and metric["aggregation"] == "mean": - metrics[metric["metric"]] = [x[metric["metric"]] for x in data] - + metrics = {metric["metric"]: [x[metric["metric"]] for x in data] for metric in config["metric_list"]} system_dict.update(metrics) system_df = pd.DataFrame(system_dict) return system_df From 5cca68f04ca2416bc824626d441a0cd094a63d70 Mon Sep 17 00:00:00 2001 From: Petr Baudis Date: Sun, 29 Dec 2024 17:26:43 +0100 Subject: [PATCH 3/4] Add a basic support for --multiple-choice-generate --- lm_eval/__main__.py | 12 ++++ lm_eval/api/task.py | 86 ++++++++++++++++++--------- lm_eval/evaluator.py | 11 +++- lm_eval/loggers/evaluation_tracker.py | 3 + scripts/zeno_visualize.py | 16 ++--- 5 files changed, 93 insertions(+), 35 deletions(-) diff --git a/lm_eval/__main__.py b/lm_eval/__main__.py index ab68781939..989cd680ba 100644 --- a/lm_eval/__main__.py +++ b/lm_eval/__main__.py @@ -187,6 +187,17 @@ def setup_parser() -> argparse.ArgumentParser: default=False, help="If True, uses the fewshot as a multi-turn conversation", ) + parser.add_argument( + "--multiple_choice_generate", + action="store_true", + default=False, + help=( + "If True, multiple choice problems are not evaluated based on lowest logprob continuation, " + "but asking the model to generate the choice letter. This departs from the traditional evaluation " + "methodology, but allows evaluation with popular chat-completion APIs and evaluates each multiple choice " + "problem only once rather than #choice times." + ), + ) parser.add_argument( "--show_config", action="store_true", @@ -396,6 +407,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: system_instruction=args.system_instruction, apply_chat_template=args.apply_chat_template, fewshot_as_multiturn=args.fewshot_as_multiturn, + multiple_choice_generate=args.multiple_choice_generate, gen_kwargs=args.gen_kwargs, task_manager=task_manager, verbosity=args.verbosity, diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index 555cb4330d..82c3d6cbcf 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -27,7 +27,7 @@ from lm_eval import utils from lm_eval.api import samplers from lm_eval.api.instance import Instance, OutputType -from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity +from lm_eval.api.metrics import bits_per_byte, exact_match_fn, mean, weighted_perplexity from lm_eval.api.registry import ( AGGREGATION_REGISTRY, DEFAULT_METRIC_REGISTRY, @@ -80,6 +80,7 @@ class TaskConfig(dict): use_prompt: Optional[str] = None description: str = "" target_delimiter: str = " " + choice_delimiter: str = " / " fewshot_delimiter: str = "\n\n" fewshot_config: Optional[dict] = None # runtime configuration options @@ -111,16 +112,15 @@ def __post_init__(self) -> None: if "until" not in self.generation_kwargs: self.generation_kwargs["until"] = [self.fewshot_delimiter] else: - if self.output_type == "generate_until": - # ensure that we greedily generate in absence of explicit arguments otherwise - self.generation_kwargs = { - "until": ( - None - if self.fewshot_delimiter is None - else [self.fewshot_delimiter] - ), - "do_sample": False, - } + # ensure that we greedily generate in absence of explicit arguments otherwise + self.generation_kwargs = { + "until": ( + None + if self.fewshot_delimiter is None + else [self.fewshot_delimiter] + ), + "do_sample": False, + } def __getitem__(self, item): return getattr(self, item) @@ -380,6 +380,7 @@ def build_all_requests( system_instruction: Optional[str] = None, apply_chat_template: bool = False, fewshot_as_multiturn: bool = False, + multiple_choice_generate: bool = False, chat_template: Optional[Callable] = None, tokenizer_name: str = "", ) -> None: @@ -391,6 +392,7 @@ def build_all_requests( cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" cache_key += "-chat_template" if apply_chat_template else "" cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" + cache_key += "-multiple_choice_generate" if multiple_choice_generate else "" cache_key += ( f"-system_prompt_hash{utils.hash_string(system_instruction)}" if system_instruction is not None @@ -435,12 +437,19 @@ def build_all_requests( total=num_docs, ): # sample fewshot context #TODO: need to offset doc_id by rank now! + doc_system_instruction = system_instruction or "" + if multiple_choice_generate: + if doc_system_instruction: + doc_system_instruction += " " + doc_system_instruction += "Please answer with the letter of the correct answer." + fewshot_ctx = self.fewshot_context( doc, 0 if self.config.num_fewshot is None else self.config.num_fewshot, - system_instruction, + doc_system_instruction, apply_chat_template, fewshot_as_multiturn, + multiple_choice_generate, chat_template, ) @@ -450,6 +459,7 @@ def build_all_requests( ctx=fewshot_ctx, metadata=(self.config["task"], doc_id, self.config.repeats), apply_chat_template=apply_chat_template, + multiple_choice_generate=multiple_choice_generate, ) if not isinstance(inst, list): @@ -1024,6 +1034,7 @@ def fewshot_context( system_instruction: Optional[str] = None, apply_chat_template: bool = False, fewshot_as_multiturn: bool = False, + multiple_choice_generate: bool = False, chat_template: Optional[Callable] = None, ) -> str: """Returns a fewshot context string that is made up of a prepended description @@ -1039,6 +1050,8 @@ def fewshot_context( Whether to apply the chat template to the fewshot context. :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param multiple_choice_generate: bool + Whether to generate multiple choice answer from scratch rather than pick by logprobs. :param chat_template: callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. :returns: str @@ -1085,6 +1098,12 @@ def fewshot_context( labeled_examples += self.sampler.get_context(doc, num_fewshot) example = self.doc_to_text(doc) + if self.config.doc_to_choice is not None and multiple_choice_generate: + if not isinstance(example, str): + raise NotImplementedError("--multiple_choice_generate is implemented only for simple text docs") + example += self.config.target_delimiter + example += "(" + self.config.choice_delimiter.join(self.doc_to_choice(doc)) + ")" + if apply_chat_template: if self.multiple_input: return chat_template(labeled_examples) @@ -1300,17 +1319,24 @@ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: return None def construct_requests( - self, doc: dict, ctx: str, **kwargs + self, doc: dict, ctx: str, multiple_choice_generate: bool, **kwargs ) -> Union[List[Instance], Instance]: apply_chat_template = kwargs.pop("apply_chat_template", False) aux_arguments = None - if self.OUTPUT_TYPE == "loglikelihood": + self.multiple_choice_generate = multiple_choice_generate + output_type = self.OUTPUT_TYPE + if output_type == "multiple_choice" and multiple_choice_generate: + output_type = "generate_until" + if self.multiple_input: + raise NotImplementedError("The \"multiple input\" mode of multiple_choice tasks is not implemented for --multiple_choice_generate.") + + if output_type == "loglikelihood": arguments = (ctx, self.doc_to_target(doc)) - elif self.OUTPUT_TYPE == "loglikelihood_rolling": + elif output_type == "loglikelihood_rolling": arguments = (self.doc_to_target(doc),) - elif self.OUTPUT_TYPE == "multiple_choice": + elif output_type == "multiple_choice": choices = self.doc_to_choice(doc) target_delimiter = self.config.target_delimiter if apply_chat_template: @@ -1337,7 +1363,7 @@ def construct_requests( arguments.extend(aux_arguments) - elif self.OUTPUT_TYPE == "generate_until": + elif output_type == "generate_until": arguments = (ctx, deepcopy(self.config.generation_kwargs)) multimodal_arg = {} @@ -1355,7 +1381,7 @@ def construct_requests( else: arguments = arguments + (multimodal_arg,) - if self.OUTPUT_TYPE == "multiple_choice": + if output_type == "multiple_choice": request_list = [ Instance( request_type="loglikelihood", @@ -1370,7 +1396,7 @@ def construct_requests( return request_list return Instance( - request_type=self.OUTPUT_TYPE, + request_type=output_type, doc=doc, arguments=arguments, idx=0, @@ -1411,7 +1437,7 @@ def process_results(self, doc, results): else {} ), } - elif self.OUTPUT_TYPE == "multiple_choice": + elif self.OUTPUT_TYPE == "multiple_choice" and not self.multiple_choice_generate: lls, is_greedy = zip(*results) # retrieve choices in List[str] form, to compute choice lengths, etc. @@ -1492,7 +1518,7 @@ def process_results(self, doc, results): acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 result_dict["acc_mutual_info"] = acc_mutual_info - elif self.OUTPUT_TYPE == "generate_until": + elif self.OUTPUT_TYPE == "generate_until" or (self.OUTPUT_TYPE == "multiple_choice" and self.multiple_choice_generate): gold = self.doc_to_target(doc) result = results[0] if self.config.doc_to_choice is not None: @@ -1511,6 +1537,12 @@ def process_results(self, doc, results): gold = type(result)(gold) for metric in self._metric_fn_list.keys(): + metric_fn = self._metric_fn_list[metric] + metric_result_key = metric + if self.OUTPUT_TYPE == "multiple_choice" and self.multiple_choice_generate: + metric_fn = exact_match_fn + metric_result_key = "exact_match" + if self.multiple_target: # in the case where we have multiple targets, # return true if any are true @@ -1522,7 +1554,7 @@ def process_results(self, doc, results): gold = [gold] if metric == "exact_match": result = [result for _ in range(len(gold))] - scores = self._metric_fn_list[metric]( + scores = metric_fn( references=gold, predictions=result, **self._metric_fn_kwargs[metric], @@ -1531,7 +1563,7 @@ def process_results(self, doc, results): else: for gold_option in gold: try: - result_score = self._metric_fn_list[metric]( + result_score = metric_fn( references=[gold_option], predictions=[result], **self._metric_fn_kwargs[metric], @@ -1539,7 +1571,7 @@ def process_results(self, doc, results): except ( TypeError ): # TODO: this is hacky and I don't want to do it - result_score = self._metric_fn_list[metric]( + result_score = metric_fn( [gold_option, result] ) if isinstance(result_score, dict): @@ -1552,16 +1584,16 @@ def process_results(self, doc, results): result_score = 0.0 else: try: - result_score = self._metric_fn_list[metric]( + result_score = metric_fn( references=[gold], predictions=[result], **self._metric_fn_kwargs[metric], ) except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics - result_score = self._metric_fn_list[metric]([gold, result]) + result_score = metric_fn([gold, result]) if isinstance(result_score, dict): # TODO: this handles the case where HF evaluate returns a dict. - result_score = result_score[metric] + result_score = result_score[metric_result_key] result_dict[metric] = result_score else: raise ValueError( diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index e7dd3043cb..d5e3792a81 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -66,6 +66,7 @@ def simple_evaluate( system_instruction: Optional[str] = None, apply_chat_template: Union[bool, str] = False, fewshot_as_multiturn: bool = False, + multiple_choice_generate: bool = False, gen_kwargs: Optional[str] = None, task_manager: Optional[TaskManager] = None, verbosity: str = "INFO", @@ -119,6 +120,8 @@ def simple_evaluate( Defaults to False (no chat template applied). :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param multiple_choice_generate: bool + Whether to generate multiple choice answer from scratch rather than pick by logprobs. :param gen_kwargs: str String arguments for model generation Ignored for all tasks with loglikelihood output_type @@ -246,7 +249,7 @@ def _adjust_config(task_dict): } else: - if task_obj.get_config("output_type") == "generate_until": + if task_obj.get_config("output_type") == "generate_until" or multiple_choice_generate: if gen_kwargs is not None: task_obj.set_config( key="generation_kwargs", value=gen_kwargs, update=True @@ -298,6 +301,7 @@ def _adjust_config(task_dict): if apply_chat_template else None, fewshot_as_multiturn=fewshot_as_multiturn, + multiple_choice_generate=multiple_choice_generate, ) results = evaluate( @@ -312,6 +316,7 @@ def _adjust_config(task_dict): system_instruction=system_instruction, apply_chat_template=apply_chat_template, fewshot_as_multiturn=fewshot_as_multiturn, + multiple_choice_generate=multiple_choice_generate, verbosity=verbosity, ) @@ -371,6 +376,7 @@ def evaluate( system_instruction: Optional[str] = None, apply_chat_template: Union[bool, str] = False, fewshot_as_multiturn: bool = False, + multiple_choice_generate: bool = False, verbosity: str = "INFO", ): """Instantiate and evaluate a model on a list of tasks. @@ -396,6 +402,8 @@ def evaluate( Defaults to False (no chat template applied). :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param multiple_choice_generate: bool + Whether to generate multiple choice answer from scratch rather than pick by logprobs. :return Dictionary of results """ @@ -457,6 +465,7 @@ def evaluate( system_instruction=system_instruction, apply_chat_template=bool(apply_chat_template), fewshot_as_multiturn=fewshot_as_multiturn, + multiple_choice_generate=multiple_choice_generate, chat_template=getattr(lm, "apply_chat_template") if apply_chat_template else None, diff --git a/lm_eval/loggers/evaluation_tracker.py b/lm_eval/loggers/evaluation_tracker.py index 067b047b59..70e7b7e1c9 100644 --- a/lm_eval/loggers/evaluation_tracker.py +++ b/lm_eval/loggers/evaluation_tracker.py @@ -51,6 +51,7 @@ class GeneralConfigTracker: system_instruction: str = None system_instruction_sha: str = None fewshot_as_multiturn: bool = None + multiple_choice_generate: bool = None chat_template: str = None chat_template_sha: str = None start_time: float = None @@ -84,6 +85,7 @@ def log_experiment_args( system_instruction: str, chat_template: str, fewshot_as_multiturn: bool, + multiple_choice_generate: bool, ) -> None: """Logs model parameters and job ID.""" self.model_source = model_source @@ -96,6 +98,7 @@ def log_experiment_args( self.chat_template = chat_template self.chat_template_sha = hash_string(chat_template) if chat_template else None self.fewshot_as_multiturn = fewshot_as_multiturn + self.multiple_choice_generate = multiple_choice_generate def log_end_time(self) -> None: """Logs the end time of the evaluation and calculates the total evaluation time.""" diff --git a/scripts/zeno_visualize.py b/scripts/zeno_visualize.py index 995d67edda..903acd6735 100644 --- a/scripts/zeno_visualize.py +++ b/scripts/zeno_visualize.py @@ -84,12 +84,13 @@ def main(): latest_sample_results = get_latest_filename( [Path(f).name for f in model_sample_filenames if task in f] ) + results = json.load( + open(Path(args.data_path, model, latest_results), encoding="utf-8") + ) model_args = re.sub( r"[\"<>:/\|\\?\*\[\]]+", "__", - json.load( - open(Path(args.data_path, model, latest_results), encoding="utf-8") - )["config"]["model_args"], + results["config"]["model_args"], ) print(model_args) data = [] @@ -105,6 +106,7 @@ def main(): open(Path(args.data_path, model, latest_results), encoding="utf-8") )["configs"] config = configs[task] + config["multiple_choice_generate"] = results.get("multiple_choice_generate", False) if model_index == 0: # Only need to assemble data for the first model metrics = [] @@ -176,7 +178,7 @@ def generate_dataset( if config["output_type"] == "loglikelihood": instance = [x["arguments"]["gen_args_0"]["arg_0"] for x in data] labels = [x["arguments"]["gen_args_0"]["arg_1"] for x in data] - elif config["output_type"] == "multiple_choice": + elif config["output_type"] == "multiple_choice" and not config["multiple_choice_generate"]: instance = [ x["arguments"]["gen_args_0"]["arg_0"] + "\n\n" @@ -185,7 +187,7 @@ def generate_dataset( ] elif config["output_type"] == "loglikelihood_rolling": instance = [x["arguments"]["gen_args_0"]["arg_0"] for x in data] - elif config["output_type"] == "generate_until": + elif config["output_type"] == "generate_until" or config["multiple_choice_generate"]: instance = [x["arguments"]["gen_args_0"]["arg_0"] for x in data] return pd.DataFrame( @@ -222,14 +224,14 @@ def generate_system_df(data, config): "correct" if x["filtered_resps"][0][1] is True else "incorrect" for x in data ] - elif config["output_type"] == "multiple_choice": + elif config["output_type"] == "multiple_choice" and not config["multiple_choice_generate"]: system_dict["output"] = [ ", ".join([str(y[0]) for y in x["filtered_resps"]]) for x in data ] system_dict["num_answers"] = [len(x["filtered_resps"]) for x in data] elif config["output_type"] == "loglikelihood_rolling": system_dict["output"] = [str(x["filtered_resps"][0]) for x in data] - elif config["output_type"] == "generate_until": + elif config["output_type"] == "generate_until" or config["multiple_choice_generate"]: system_dict["output"] = [str(x["filtered_resps"][0]) for x in data] system_dict["output_length"] = [len(str(x["filtered_resps"][0])) for x in data] From d9e49af7e1450dd3481d7406f5da635628a368ed Mon Sep 17 00:00:00 2001 From: Petr Baudis Date: Mon, 30 Dec 2024 00:55:04 +0100 Subject: [PATCH 4/4] Add support for --multiple_choice_generate abcd --- docs/interface.md | 2 ++ lm_eval/__main__.py | 8 ++++++-- lm_eval/api/task.py | 33 +++++++++++++++++++++++++-------- lm_eval/evaluator.py | 8 ++++---- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/docs/interface.md b/docs/interface.md index cea1aab027..4fded81883 100644 --- a/docs/interface.md +++ b/docs/interface.md @@ -54,6 +54,8 @@ This mode supports a number of command-line arguments, the details of which can - `--fewshot_as_multiturn` : If this flag is on, the Fewshot examples are treated as a multi-turn conversation. Questions are provided as user content and answers are provided as assistant responses. Requires `--num_fewshot` to be set to be greater than 0, and `--apply_chat_template` to be on. +- `--multiple_choice_generate` : If True, multiple choice problems are not evaluated based on lowest logprob continuation, but asking the model to generate the choice letter. This departs from the traditional evaluation methodology, but allows evaluation with popular chat-completion APIs and evaluates each multiple choice problem only once rather than #choice times. Without additional argument, choices must be reproduced verbatim by the model; with additional argument 'abcd' (RECOMMENDED), choices will be lettered and the model has to produce only the corresponding letter. + - `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results. * `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42. diff --git a/lm_eval/__main__.py b/lm_eval/__main__.py index 989cd680ba..381618b272 100644 --- a/lm_eval/__main__.py +++ b/lm_eval/__main__.py @@ -189,13 +189,17 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--multiple_choice_generate", - action="store_true", + type=str, + nargs="?", + const=True, default=False, help=( "If True, multiple choice problems are not evaluated based on lowest logprob continuation, " "but asking the model to generate the choice letter. This departs from the traditional evaluation " "methodology, but allows evaluation with popular chat-completion APIs and evaluates each multiple choice " - "problem only once rather than #choice times." + "problem only once rather than #choice times. Without additional argument, choices must be reproduced " + "verbatim by the model; with additional argument 'abcd', choices will be lettered and the model has to " + "produce only the corresponding letter." ), ) parser.add_argument( diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index 82c3d6cbcf..ae30908d21 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -81,6 +81,7 @@ class TaskConfig(dict): description: str = "" target_delimiter: str = " " choice_delimiter: str = " / " + option_delimiter: str = "\n" fewshot_delimiter: str = "\n\n" fewshot_config: Optional[dict] = None # runtime configuration options @@ -380,7 +381,7 @@ def build_all_requests( system_instruction: Optional[str] = None, apply_chat_template: bool = False, fewshot_as_multiturn: bool = False, - multiple_choice_generate: bool = False, + multiple_choice_generate: Union[bool, str] = False, chat_template: Optional[Callable] = None, tokenizer_name: str = "", ) -> None: @@ -438,10 +439,13 @@ def build_all_requests( ): # sample fewshot context #TODO: need to offset doc_id by rank now! doc_system_instruction = system_instruction or "" - if multiple_choice_generate: + if self.OUTPUT_TYPE == "multiple_choice" and multiple_choice_generate: if doc_system_instruction: doc_system_instruction += " " - doc_system_instruction += "Please answer with the letter of the correct answer." + if multiple_choice_generate == "abcd": + doc_system_instruction += "Please include \"ANSWER: \" in your response with the letter of the correct last answer." + else: + doc_system_instruction += "Please answer with the letter of the correct last answer." fewshot_ctx = self.fewshot_context( doc, @@ -1034,7 +1038,7 @@ def fewshot_context( system_instruction: Optional[str] = None, apply_chat_template: bool = False, fewshot_as_multiturn: bool = False, - multiple_choice_generate: bool = False, + multiple_choice_generate: Union[bool, str] = False, chat_template: Optional[Callable] = None, ) -> str: """Returns a fewshot context string that is made up of a prepended description @@ -1050,7 +1054,7 @@ def fewshot_context( Whether to apply the chat template to the fewshot context. :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. - :param multiple_choice_generate: bool + :param multiple_choice_generate: Union[bool, str] Whether to generate multiple choice answer from scratch rather than pick by logprobs. :param chat_template: callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. @@ -1101,8 +1105,13 @@ def fewshot_context( if self.config.doc_to_choice is not None and multiple_choice_generate: if not isinstance(example, str): raise NotImplementedError("--multiple_choice_generate is implemented only for simple text docs") - example += self.config.target_delimiter - example += "(" + self.config.choice_delimiter.join(self.doc_to_choice(doc)) + ")" + if multiple_choice_generate == "abcd": + choices = self.doc_to_choice(doc) + for label, choice in zip(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")[:len(choices)], choices): + example += f"{self.config.option_delimiter}({label}) {choice}" + else: + example += self.config.target_delimiter + example += "(" + self.config.choice_delimiter.join(self.doc_to_choice(doc)) + ")" if apply_chat_template: if self.multiple_input: @@ -1319,7 +1328,7 @@ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: return None def construct_requests( - self, doc: dict, ctx: str, multiple_choice_generate: bool, **kwargs + self, doc: dict, ctx: str, multiple_choice_generate: Union[bool, str], **kwargs ) -> Union[List[Instance], Instance]: apply_chat_template = kwargs.pop("apply_chat_template", False) @@ -1526,6 +1535,14 @@ def process_results(self, doc, results): # it assumes that doc_to_target returns a number. choices = self.doc_to_choice(doc) gold = choices[gold] + if self.multiple_choice_generate == "abcd": + try: + result_label = re.findall(r"ANSWER: ([A-Z])", result)[-1] + result_i = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ").index(result_label) + result = choices[result_i] + except (AttributeError, ValueError, IndexError): + eval_logger.warning(f"[{self}] LLM did not pick a valid result ('{result}')") + result = choices[0] # XXX guess "randomly" # we expect multiple_targets to be a list. elif self.multiple_target: gold = list(gold) diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index d5e3792a81..378edc7376 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -66,7 +66,7 @@ def simple_evaluate( system_instruction: Optional[str] = None, apply_chat_template: Union[bool, str] = False, fewshot_as_multiturn: bool = False, - multiple_choice_generate: bool = False, + multiple_choice_generate: Union[bool, str] = False, gen_kwargs: Optional[str] = None, task_manager: Optional[TaskManager] = None, verbosity: str = "INFO", @@ -120,7 +120,7 @@ def simple_evaluate( Defaults to False (no chat template applied). :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. - :param multiple_choice_generate: bool + :param multiple_choice_generate: Union[bool, str] Whether to generate multiple choice answer from scratch rather than pick by logprobs. :param gen_kwargs: str String arguments for model generation @@ -376,7 +376,7 @@ def evaluate( system_instruction: Optional[str] = None, apply_chat_template: Union[bool, str] = False, fewshot_as_multiturn: bool = False, - multiple_choice_generate: bool = False, + multiple_choice_generate: Union[bool, str] = False, verbosity: str = "INFO", ): """Instantiate and evaluate a model on a list of tasks. @@ -402,7 +402,7 @@ def evaluate( Defaults to False (no chat template applied). :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. - :param multiple_choice_generate: bool + :param multiple_choice_generate: Union[bool, str] Whether to generate multiple choice answer from scratch rather than pick by logprobs. :return Dictionary of results