From 058675078c888138ff5c41c3892183d48b230daa Mon Sep 17 00:00:00 2001 From: Saibo-creator <53392976+Saibo-creator@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:01:46 -0700 Subject: [PATCH] test: Added new test file to validate case where model'embedding has mismatch with tokenizer's vocab (#85) --- tests/test_hf_generation/test_generation.py | 2 +- .../test_generation_w_expanded_emb.py | 121 ++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 tests/test_hf_generation/test_generation_w_expanded_emb.py diff --git a/tests/test_hf_generation/test_generation.py b/tests/test_hf_generation/test_generation.py index de6d244..86698a5 100644 --- a/tests/test_hf_generation/test_generation.py +++ b/tests/test_hf_generation/test_generation.py @@ -6,7 +6,7 @@ MODEL_IDS = [ "hf-internal-testing/tiny-random-GPTJForCausalLM", "JackFram/llama-68m", - # "hf-internal-testing/tiny-random-PhiForCausalLM", + "hf-internal-testing/tiny-random-PhiForCausalLM", "hf-internal-testing/tiny-random-gpt2", # "hf-internal-testing/tiny-random-BlenderbotForCausalLM", ] diff --git a/tests/test_hf_generation/test_generation_w_expanded_emb.py b/tests/test_hf_generation/test_generation_w_expanded_emb.py new file mode 100644 index 0000000..1b8fb5a --- /dev/null +++ b/tests/test_hf_generation/test_generation_w_expanded_emb.py @@ -0,0 +1,121 @@ +from unittest import TestCase +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers_cfg.token_grammar_recognizer import IncrementalTokenRecognizer +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor + +MODEL_IDS = [ + "JackFram/llama-68m", +] + + +def check_parentheses(generation): + stack = [] + for char in generation: + if char == "(": + stack.append(char) + elif char == ")": + if not stack: + return False + stack.pop() + return not stack + + +class TestGreedyDecoding(TestCase): + @classmethod + def setUpClass(cls): + cls.models = {} + cls.tokenizers = {} + for model_id in MODEL_IDS: + cls.models[model_id] = AutoModelForCausalLM.from_pretrained(model_id) + cls.tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id) + cls.tokenizers[model_id].pad_token = cls.tokenizers[model_id].eos_token + # we expand the embedding layer to simulate the case where the model has a larger embedding layer than the tokenizer + cls.models[model_id].resize_token_embeddings( + 10 + len(cls.tokenizers[model_id]) + ) + + def test_generate_only_number(self): + # test greedy decoding with grammar constraints + grammar_str = """ + root ::= [0-9]+ + """ + + for model_id in MODEL_IDS: + model = self.models[model_id] + tokenizer = self.tokenizers[model_id] + + grammar = IncrementalTokenRecognizer( + grammar_str, start_rule_name="root", tokenizer=tokenizer + ) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + + prefix = "This is a valid number:" + + input_ids = tokenizer( + [prefix], add_special_tokens=False, return_tensors="pt", padding=True + )["input_ids"] + + output = model.generate( + input_ids, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + eos_token_id=tokenizer.eos_token_id, + num_beams=1, + max_new_tokens=40, + top_p=0.92, + top_k=5, + logits_processor=[grammar_processor], + repetition_penalty=100.0, + early_stopping=True, + ) + + generations = tokenizer.batch_decode( + output[:, input_ids.shape[1] :], skip_special_tokens=True + ) + self.assertTrue( + generations[0].isdigit(), f"generations: {generations} is not a number" + ) + + def test_generate_balanced_parenthesis(self): + # test greedy decoding with grammar constraints + grammar_str = """ + root ::= "(" root ")" | "" + """ + + for model_id in MODEL_IDS: + model = self.models[model_id] + tokenizer = self.tokenizers[model_id] + + grammar = IncrementalTokenRecognizer( + grammar_str, start_rule_name="root", tokenizer=tokenizer + ) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + + prefix = "This is a valid json:" + + input_ids = tokenizer( + [prefix], add_special_tokens=False, return_tensors="pt", padding=True + )["input_ids"] + + output = model.generate( + input_ids, + do_sample=True, + pad_token_id=tokenizer.eos_token_id, + eos_token_id=tokenizer.eos_token_id, + num_beams=1, + max_new_tokens=40, + top_p=0.92, + top_k=5, + logits_processor=[grammar_processor], + repetition_penalty=100.0, + early_stopping=True, + ) + + generation: str = tokenizer.batch_decode( + output[:, input_ids.shape[1] :], skip_special_tokens=True + )[0] + + self.assertTrue( + check_parentheses(generation), + f"generations: {generation} is not a balanced parenthesis", + )