# belief_graph.py — builds a belief graph over mystery suspects by querying an LLM
# (76 lines, 2.67 KB in the original repository listing).
from extractor import model, base_model_name
from sequence_probabilities import make_greedy_tracker
import outlines.text.generate as generate
from outlines.text.generate.regex import choice
from transformers import AutoTokenizer
from math import exp
from load import dataset, suspects, culprit, tagline, story_text
import mystery
import tms_mystery
import common_wandb
# Tokenizer for the base model; pad with EOS since causal LMs lack a pad token.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# The only two answers the constrained generator is allowed to produce.
bool_choices = ['True', 'False']
# Token ids whose probabilities we track when scoring the answer.
# NOTE(review): the computed list below is immediately overwritten by the
# hand-picked ids on the next line — the encode() result was apparently not
# the ids actually emitted during generation; verify against the model vocab.
mask_token_ids = [tokenizer.encode(w)[-1] for w in bool_choices]
mask_token_ids = [5574, 8824] # set it manually by inspection...
# Greedy True/False generator that also accumulates a sequence log-probability
# (exposed as .sequence_log_prob) over the masked token ids.
belief_generator = make_greedy_tracker(
    choice(model, bool_choices, max_tokens=50),
    mask_token_ids
)
# Instruction appended to every question so the model answers with one word.
postfix = ': Answer exactly one of: True or False.'
def belief_probability(prompt):
    """Ask the model the True/False question in *prompt*.

    Returns a pair ``(value, confidence)`` where *value* is the boolean the
    model chose and *confidence* is exp of the tracked sequence log-prob.
    """
    # Recover the bare question (everything before the instruction postfix)
    # purely for console logging.
    bare_question = prompt[:prompt.index(postfix)]
    print(bare_question)
    wrapped = f'<s>[INST]{prompt}[/INST] '
    # Reset the tracker so the log-prob covers only this one generation.
    belief_generator.sequence_log_prob = 0.0
    answer = belief_generator(wrapped)
    value = answer == 'True'
    confidence = exp(belief_generator.sequence_log_prob)
    print(f"{value} ({confidence})")
    return (value, confidence)
def create_prompt(story, id, neg, what=None):
    """Build a "question / story / question" sandwich prompt.

    With ``what=None`` the claim is about guilt; otherwise it is about the
    suspect possessing *what*. ``neg`` negates the claim in either form.
    """
    if what is None:
        claim = f"{id} is{' not' if neg else ''} guilty"
    else:
        claim = f"{id} has{' no' if neg else ''} {what}"
    question = claim + postfix
    # Repeat the question after the story so it is fresh in context.
    return '\n'.join([question, story, question])
def create_story_lines(story, suspect_list):
    """Query the model once per (suspect, attribute) pair.

    Returns a list of ``(suspect, what, value, confidence)`` tuples, where
    ``what`` is None for the guilt question itself.
    """
    beliefs = []
    for who in suspect_list:
        for attribute in [None] + mystery.whats:
            verdict, conf = belief_probability(
                create_prompt(story, who, False, attribute))
            beliefs.append((who, attribute, verdict, conf))
    return beliefs
def solve(story, suspect_list):
    """Collect model beliefs about the story and hand them to the TMS solver."""
    beliefs = create_story_lines(story, suspect_list)
    return tms_mystery.solve(beliefs)
def processCase(x, run=None):
    """Solve one mystery record, store correctness in ``x['eval']``, log it.

    Returns the (mutated) record so it can be used with ``dataset.map``.
    """
    print(f"## {tagline(x)}")
    guess = solve(story_text(x), suspects(x))
    actual = culprit(x)
    print(f"The culprit is {guess}.")
    print(f"In fact, it is {actual}.")
    x['eval'] = guess == actual
    common_wandb.log_eval(run, x['eval'])
    return x
def processAll(run=None):
    """Run every case in the dataset and print the aggregate score."""
    results = dataset.map(lambda x: processCase(x, run))
    # Count the cases marked solved in the train split.
    solved = sum(1 for e in results['train']['eval'] if e == 1)
    total = results.num_rows['train']
    print(f"Solved {solved} out of {total}.")
# Script entry point: start a wandb run and evaluate the whole dataset.
if __name__ == '__main__':
    run = common_wandb.init(project="belief_graph")
    processAll(run)