Skip to content

Commit

Permalink
Add the passkey_retrieval_test method
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 5, 2024
1 parent c706553 commit 815c717
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,55 @@ def generate_passkey_prompt(passkey, context_length):
)

return prompt


def passkey_retrieval_test(model, tokenizer, max_length, num_trials=10):
"""
Perform the passkey retrieval test on the model.
Args:
model: The LongRoPE model to evaluate.
tokenizer: Tokenizer for encoding/decoding text.
max_length: Maximum sequence length to test.
num_trials: Number of trials to run for each context length.
Returns:
dict: A dictionary of accuracies for each tested context length.
"""
model.eval()
accuracies = {}

for length in [
4096,
8192,
16384,
32768,
65536,
131072,
262144,
524288,
1048576,
2097152,
]:
if length > max_length:
break

correct_retrievals = 0

for _ in range(num_trials):
passkey = "".join([str(random.randint(0, 9)) for _ in range(5)])
prompt = generate_passkey_prompt(passkey, length)

input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
output = model(input_ids)
generated_ids = output.argmax(dim=-1)

generated_text = tokenizer.decode(generated_ids[0])
if passkey in generated_text:
correct_retrievals += 1

accuracies[length] = correct_retrievals / num_trials

return accuracies

0 comments on commit 815c717

Please sign in to comment.