forked from collin-burns/discovering_latent_knowledge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
53 lines (45 loc) · 2.67 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from sklearn.linear_model import LogisticRegression
from utils import get_parser, load_all_generations, CCS
def main(args, generation_args):
    """Train a logistic-regression baseline and a CCS probe, then print their test accuracies.

    Hidden states and labels come from ``load_all_generations``. The last entry
    along the final axis is kept (the "last layer"). A size-1 axis 1 is squeezed
    away (the T5 case). The examples are then split in half into train/test
    without reshuffling; this relies on the data already being shuffled.

    Args:
        args: Parsed evaluation options. This function reads ``nepochs``,
            ``ntries``, ``lr``, ``ccs_batch_size``, ``verbose``, ``ccs_device``,
            ``linear``, ``weight_decay`` and ``var_normalize``.
        generation_args: Options forwarded unchanged to ``load_all_generations``
            to locate the saved hidden states and labels.

    Raises:
        ValueError: If the negative and positive hidden states differ in shape,
            or if the number of labels differs from the number of examples.
    """
    # load hidden states and labels
    neg_hs, pos_hs, y = load_all_generations(generation_args)

    # Validate explicitly rather than with `assert`, which is stripped under
    # `python -O` and would let mismatched inputs flow silently into training.
    if neg_hs.shape != pos_hs.shape:
        raise ValueError(
            "neg_hs and pos_hs must have the same shape, got {} and {}".format(
                neg_hs.shape, pos_hs.shape
            )
        )

    neg_hs, pos_hs = neg_hs[..., -1], pos_hs[..., -1]  # take the last layer
    if neg_hs.shape[1] == 1:  # T5 may have an extra dimension; if so, get rid of it
        neg_hs = neg_hs.squeeze(1)
        pos_hs = pos_hs.squeeze(1)

    # Labels must line up with examples; otherwise the split below would pair
    # hidden states with the wrong labels.
    if len(y) != len(neg_hs):
        raise ValueError(
            "expected one label per example, got {} labels for {} examples".format(
                len(y), len(neg_hs)
            )
        )

    # Very simple train/test split (using the fact that the data is already shuffled).
    # One shared split index keeps hidden states and labels aligned.
    half = len(neg_hs) // 2
    neg_hs_train, neg_hs_test = neg_hs[:half], neg_hs[half:]
    pos_hs_train, pos_hs_test = pos_hs[:half], pos_hs[half:]
    y_train, y_test = y[:half], y[half:]

    # Make sure logistic regression accuracy is reasonable; otherwise our method won't have much of a chance of working
    # you can also concatenate, but this works fine and is more comparable to CCS inputs
    x_train = neg_hs_train - pos_hs_train
    x_test = neg_hs_test - pos_hs_test
    lr = LogisticRegression(class_weight="balanced")
    lr.fit(x_train, y_train)
    print("Logistic regression accuracy: {}".format(lr.score(x_test, y_test)))

    # Set up CCS. Note that you can usually just use the default args by simply doing ccs = CCS(neg_hs, pos_hs, y)
    # NOTE(review): the call below passes no labels to CCS (labels are only used
    # in get_acc); confirm the `y` in the example above matches CCS's signature in utils.
    ccs = CCS(
        neg_hs_train,
        pos_hs_train,
        nepochs=args.nepochs,
        ntries=args.ntries,
        lr=args.lr,
        batch_size=args.ccs_batch_size,
        verbose=args.verbose,
        device=args.ccs_device,
        linear=args.linear,
        weight_decay=args.weight_decay,
        var_normalize=args.var_normalize,
    )

    # train and evaluate CCS
    ccs.repeated_train()
    ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
    print("CCS accuracy: {}".format(ccs_acc))
def parse_eval_args(parser, argv=None):
    """Parse the generation options and the evaluation options from one command line.

    Two namespaces are produced from the same argument list:

    * ``generation_args`` holds only the options the incoming ``parser`` already
      defines. It is used to load the correct hidden states and labels.
    * ``args`` additionally holds the evaluation options registered here:
      ``--nepochs`` (default 1000), ``--ntries`` (10), ``--lr`` (1e-3),
      ``--ccs_batch_size`` (-1), ``--verbose``, ``--ccs_device`` ("cuda"),
      ``--linear``, ``--weight_decay`` (0.01) and ``--var_normalize``.

    The first pass uses ``parse_known_args``. Plain ``parse_args`` would exit
    with "unrecognized arguments" as soon as any evaluation flag appeared on the
    command line, because those flags are not registered yet at that point.

    Note: this mutates ``parser`` by adding the evaluation options, so call it
    at most once per parser.

    Args:
        parser: An ``argparse.ArgumentParser`` carrying the generation options.
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        Tuple ``(args, generation_args)`` of ``argparse.Namespace`` objects.
    """
    # we'll use this to load the correct hidden states + labels; evaluation-only
    # flags are tolerated here and ignored
    generation_args, _ = parser.parse_known_args(argv)

    # We'll also add some additional args for evaluation
    parser.add_argument("--nepochs", type=int, default=1000)
    parser.add_argument("--ntries", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--ccs_batch_size", type=int, default=-1)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--ccs_device", type=str, default="cuda")
    parser.add_argument("--linear", action="store_true")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--var_normalize", action="store_true")
    args = parser.parse_args(argv)
    return args, generation_args


if __name__ == "__main__":
    parser = get_parser()
    args, generation_args = parse_eval_args(parser)
    main(args, generation_args)