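"""main.py: entry point for training, testing, and sampling an
attention-based neural machine translation model in TensorFlow."""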
from __future__ import division
from __future__ import print_function

import os
import time
import pprint
import random

import numpy as np
import tensorflow as tf

# Local modules from this repository.
from attention import AttentionNN
from data import read_vocabulary
from bleu.length_analysis import process_files

pp = pprint.PrettyPrinter().pprint

flags = tf.app.flags
flags.DEFINE_integer("max_size", 30, "Maximum sentence length [30]")
flags.DEFINE_integer("batch_size", 128, "Number of examples in minibatch [128]")
flags.DEFINE_integer("random_seed", 123, "Value of random seed [123]")
flags.DEFINE_integer("epochs", 12, "Number of epochs to run [12]")
flags.DEFINE_integer("hidden_size", 1024, "Size of hidden units [1024]")
flags.DEFINE_integer("emb_size", 256, "Size of embedding dimension [256]")
flags.DEFINE_integer("num_layers", 4, "Depth of RNNs [4]")
flags.DEFINE_float("dropout", 0.2, "Dropout probability [0.2]")
flags.DEFINE_float("minval", -0.1, "Minimum value for initialization [-0.1]")
flags.DEFINE_float("maxval", 0.1, "Maximum value for initialization [0.1]")
flags.DEFINE_float("lr_init", 1.0, "Initial learning rate [1.0]")
flags.DEFINE_float("max_grad_norm", 5.0, "Maximum gradient cutoff [5.0]")
flags.DEFINE_string("checkpoint_dir", "checkpoints", "Checkpoint directory [checkpoints]")
flags.DEFINE_string("dataset", "small", "Dataset to use [small]")
flags.DEFINE_string("name", "default", "Model name [default]")
flags.DEFINE_string("sample", None, "Path of source file to sample translations from [None]")
flags.DEFINE_boolean("is_test", False, "True for testing, False for training [False]")

FLAGS = flags.FLAGS

tf.set_random_seed(FLAGS.random_seed)
random.seed(FLAGS.random_seed)
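
# Example invocations (a sketch; assumes the data/ files for the chosen
# dataset exist, and that --is_test / --sample runs can find a checkpoint
# saved under --checkpoint_dir by an earlier training run):
#
#   python main.py --dataset=small                     # train
#   python main.py --dataset=small --is_test=True      # loss, perplexity, BLEU
#   python main.py --dataset=small --sample=data/tst2013.en.pruned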

# Per-dataset file locations; each class is used as a plain namespace.
class debug:
    source_data_path = "data/train.debug.en"
    target_data_path = "data/train.debug.vi"
    source_vocab_path = "data/vocab.small.en"
    target_vocab_path = "data/vocab.small.vi"
    valid_source_data_path = "data/test.debug.en"
    valid_target_data_path = "data/test.debug.vi"
    test_source_data_path = "data/test.debug.en"
    test_target_data_path = "data/test.debug.vi"


class small:
    source_data_path = "data/train.small.en.pruned"
    target_data_path = "data/train.small.vi.pruned"
    source_vocab_path = "data/vocab.small.en"
    target_vocab_path = "data/vocab.small.vi"
    valid_source_data_path = "data/valid.small.en.pruned"
    valid_target_data_path = "data/valid.small.vi.pruned"
    test_source_data_path = "data/tst2013.en.pruned"
    test_target_data_path = "data/tst2013.vi.pruned"


# Note: medium defines no valid_/test_ paths; main() reads them for both
# training and testing, so those modes will raise AttributeError here.
class medium:
    source_data_path = "data/train.medium.en"
    target_data_path = "data/train.medium.de"
    source_vocab_path = "data/vocab.medium.en"
    target_vocab_path = "data/vocab.medium.de"

def get_bleu_score(samples, target_file):
    """Write sampled translations to a temporary file and score them with BLEU."""
    hyp_file = "hyp" + str(int(time.time()))
    with open(hyp_file, "w") as f:
        for sample in samples:
            for s in sample:
                if s == "</s>":
                    break
                f.write(" " + s)
            f.write("\n")
    process_files(hyp_file, target_file)
    os.remove(hyp_file)

def print_samples(samples):
    """Print each sampled sentence, stopping at the end-of-sentence token."""
    for sample in samples:
        for s in sample:
            if s == "</s>":
                break
            print(" " + s, end="")
        print()

def main(_):
    config = FLAGS
    if config.dataset == "small":
        data_config = small
    elif config.dataset == "medium":
        data_config = medium
    elif config.dataset == "debug":
        data_config = debug
    else:
        raise Exception("[!] Unknown dataset {}".format(config.dataset))

    config.source_data_path = data_config.source_data_path
    config.target_data_path = data_config.target_data_path
    config.source_vocab_path = data_config.source_vocab_path
    config.target_vocab_path = data_config.target_vocab_path

    # Vocabulary sizes fix the model's input/output dimensions.
    config.s_nwords = len(read_vocabulary(config.source_vocab_path))
    config.t_nwords = len(read_vocabulary(config.target_vocab_path))

    # Print all configured flag values.
    pp(config.__dict__["__flags"])

    with tf.Session() as sess:
        attn = AttentionNN(config, sess)
        if config.sample:
            # Translate the given source file using a saved checkpoint.
            attn.load()
            samples = attn.sample(config.sample)
            print_samples(samples)
        elif config.is_test:
            # Evaluate a saved checkpoint: loss, perplexity, and BLEU.
            attn.load()
            loss = attn.test(data_config.test_source_data_path,
                             data_config.test_target_data_path)
            print("[Test] [Loss: {}] [Perplexity: {}]".format(loss, np.exp(loss)))
            samples = attn.sample(data_config.test_source_data_path)
            get_bleu_score(samples, data_config.test_target_data_path)
        else:
            # Train, using the given files for validation.
            attn.run(data_config.valid_source_data_path,
                     data_config.valid_target_data_path)

if __name__ == "__main__":
    tf.app.run()