-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_sequence_tagger.py
executable file
·344 lines (277 loc) · 17.9 KB
/
run_sequence_tagger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import os
import pickle
import subprocess
import data_seq_helper
import tensorflow as tf
from bert import modeling
from bert import optimization
from bert import tokenization
from data_seq_helper import file_writer
from data_seq_helper import file_based_input_fn_builder
from data_seq_helper import filed_based_convert_examples_to_features
flags = tf.flags
FLAGS = flags.FLAGS

# --- Task and data -----------------------------------------------------------
flags.DEFINE_string("task_name", "ner", "The name of the task to train.")
flags.DEFINE_string("data_dir", "datasets/CoNLL2003_en", "The input data dir.")

# --- Pre-trained BERT artefacts (defaults point at the cased base model) ------
flags.DEFINE_string("bert_config_file", "bert_ckpt/base_cased/bert_config.json", "The config json file")
flags.DEFINE_string("init_checkpoint", "bert_ckpt/base_cased/bert_model.ckpt", "Initial checkpoint")
flags.DEFINE_string("vocab_file", "bert_ckpt/base_cased/vocab.txt", "vocab file that the BERT model was trained on.")
flags.DEFINE_string("output_dir", "checkpoint/conll2003_en", "output dir where the model checkpoints will be written.")
# Must agree with the checkpoint's casing; main() validates this against init_checkpoint.
flags.DEFINE_bool("do_lower_case", False, "Whether to lower case the input text.")
flags.DEFINE_integer("max_seq_length", 128, "The maximum total input sequence length after WordPiece tokenization.")

# --- Which phases to run -----------------------------------------------------
flags.DEFINE_bool("do_train", True, "Whether to run training.")
flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_predict", True, "Whether to run the model in inference mode on the dev and test sets.")

# --- Optimisation --------------------------------------------------------------
# main() uses this one value as the train, eval and predict batch size.
flags.DEFINE_integer("batch_size", 32, "Total batch size for training, evaluation and prediction.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_float("num_train_epochs", 6.0, "Total number of training epochs to perform.")
flags.DEFINE_float("warmup_proportion", 0.1, "Proportion of training to perform linear learning rate warmup")

# --- Checkpointing / estimator loop (passed to the RunConfig in main) --------
flags.DEFINE_integer("save_checkpoints_steps", 1000, "How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.")

# --- Model head ------------------------------------------------------------------
flags.DEFINE_bool("use_crf", True, "Whether to use a CRF layer for decoding.")
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels,
                 use_one_hot_embeddings, use_crf=False):
    """Build a BERT token-tagging head and its training loss.

    Every position of the BERT sequence output is projected linearly to
    ``num_labels`` scores. With ``use_crf`` the loss is the mean negative
    linear-chain CRF log-likelihood and predictions come from Viterbi decoding.
    Otherwise it is a mask-weighted token cross entropy, and predictions are
    the per-token argmax.

    Args:
        bert_config: ``modeling.BertConfig`` for the encoder.
        is_training: if True, dropout (keep_prob=0.9) is applied to the encoder output.
        input_ids: token ids fed to BERT. Assumed [batch_size, max_seq_length] with
            0 as the padding id (TODO confirm against data_seq_helper).
        input_mask: same shape as ``input_ids``; non-zero entries are counted in the
            non-CRF loss.
        segment_ids: token type ids passed to BERT.
        labels: gold tag ids, same shape as ``input_ids``.
        num_labels: number of tag classes.
        use_one_hot_embeddings: forwarded to ``modeling.BertModel``.
        use_crf: use a CRF layer instead of independent per-token softmax.

    Returns:
        ``(loss, logits, predicts)``:
        - ``loss`` is a scalar.
        - ``logits`` has shape [batch_size, max_seq_length, num_labels].
        - ``predicts`` holds int32 tag ids of shape [batch_size, max_seq_length].
    """
    model = modeling.BertModel(config=bert_config,
                               is_training=is_training,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=segment_ids,
                               use_one_hot_embeddings=use_one_hot_embeddings)
    embedding = model.get_sequence_output()  # [batch_size, max_seq_length, hidden_size]
    max_seq_length = embedding.shape[-2].value
    hidden_size = embedding.shape[-1].value
    # Per-example length = number of non-zero token ids; used as the CRF sequence length.
    seq_len = tf.reduce_sum(tf.sign(tf.abs(input_ids)), axis=1)  # [batch_size]
    if is_training:
        embedding = tf.nn.dropout(embedding, keep_prob=0.9)
    # If you want, you can add BiLSTM layers here; do not forget to adjust hidden_size.
    embedding = tf.reshape(embedding, shape=[-1, hidden_size])  # [batch_size * max_seq_length, hidden_size]
    output_weights = tf.get_variable(name="output_weights",
                                     shape=[num_labels, hidden_size],
                                     initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(name="output_bias",
                                  shape=[num_labels],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(embedding, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    logits = tf.reshape(logits, shape=[-1, max_seq_length, num_labels])

    if use_crf:
        with tf.variable_scope("crf_loss"):
            trans = tf.get_variable(name="transition", shape=[num_labels, num_labels],
                                    initializer=tf.truncated_normal_initializer(stddev=0.02))
            log_likelihood, transition = tf.contrib.crf.crf_log_likelihood(inputs=logits,
                                                                           tag_indices=labels,
                                                                           transition_params=trans,
                                                                           sequence_lengths=seq_len)
            loss = tf.reduce_mean(-log_likelihood)
            predicts, _ = tf.contrib.crf.crf_decode(potentials=logits,
                                                    transition_params=transition,
                                                    sequence_length=seq_len)
        return loss, logits, predicts

    with tf.variable_scope("loss"):
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)  # [batch, seq, num_labels]
        mask = tf.cast(input_mask, dtype=tf.float32)  # [batch, seq]
        log_probabilities = tf.nn.log_softmax(logits, axis=-1)  # [batch, seq, num_labels]
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probabilities, axis=-1)  # [batch, seq]
        # Normalise by the batch size actually fed, not FLAGS.batch_size. The final
        # eval/predict batch can be smaller (drop_remainder=False), and a fixed
        # divisor would under-report its loss. Full batches are unchanged.
        actual_batch_size = tf.cast(tf.shape(logits)[0], dtype=tf.float32)
        loss = tf.reduce_sum(per_token_loss * mask) / actual_batch_size
        predicts = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return loss, logits, predicts
def model_fn_builder(bert_config, num_labels, init_ckpt, learning_rate, num_train_steps, num_warmup_steps,
                     use_one_hot_embeddings, use_crf):
    """Return a ``model_fn`` closure for ``tf.contrib.tpu.TPUEstimator``.

    Args:
        bert_config: ``modeling.BertConfig`` for the encoder.
        num_labels: number of tag classes.
        init_ckpt: checkpoint to warm-start matching variables from; falsy to skip.
        learning_rate: initial learning rate for ``optimization.create_optimizer``.
        num_train_steps: total number of training steps (main passes ``None`` when not training).
        num_warmup_steps: number of linear warm-up steps (``None`` when not training).
        use_one_hot_embeddings: forwarded to ``create_model``.
        use_crf: forwarded to ``create_model``.

    Returns:
        ``model_fn(features, labels, mode, params)`` returning a ``TPUEstimatorSpec``.
    """

    def model_fn(features, labels, mode, params):
        """Build the graph for ``mode``.

        Gold tags are read from ``features["label_ids"]``. The estimator's
        ``labels`` and ``params`` arguments are unused.
        """
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        total_loss, _, predicts = create_model(bert_config=bert_config,
                                               is_training=is_training,
                                               input_ids=input_ids,
                                               input_mask=input_mask,
                                               segment_ids=segment_ids,
                                               labels=label_ids,
                                               num_labels=num_labels,
                                               use_one_hot_embeddings=use_one_hot_embeddings,
                                               use_crf=use_crf)

        tvars = tf.trainable_variables()
        if init_ckpt:
            # Warm-start every trainable variable whose name exists in the checkpoint.
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(tvars, init_ckpt)
            tf.train.init_from_checkpoint(init_ckpt, assignment_map)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(loss=total_loss,
                                                     init_lr=learning_rate,
                                                     num_train_steps=num_train_steps,
                                                     num_warmup_steps=num_warmup_steps,
                                                     use_tpu=False)
            hook_dict = {
                "loss": total_loss,
                "global_steps": tf.train.get_or_create_global_step(),
            }
            logging_hook = tf.train.LoggingTensorHook(hook_dict, every_n_iter=100)
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          loss=total_loss,
                                                          train_op=train_op,
                                                          training_hooks=[logging_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:
            def metric_fn(loss_, label_ids_, predicts_, mask_):
                """Report the mean model loss and the mask-weighted token accuracy.

                The previous "eval_loss" was the MSE between integer tag ids. That
                is not a loss and has no meaning for categorical tags.
                """
                return {
                    "eval_loss": tf.metrics.mean(values=loss_),
                    "eval_accuracy": tf.metrics.accuracy(labels=label_ids_,
                                                         predictions=predicts_,
                                                         weights=mask_),
                }

            # total_loss is a scalar. That is fine on the CPU/GPU path (main uses use_tpu=False).
            # NOTE(review): TPU evaluation would need tensors with a batch dimension.
            eval_metrics = (metric_fn, [total_loss, label_ids, predicts, input_mask])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          loss=total_loss,
                                                          eval_metrics=eval_metrics)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          predictions=predicts)
        return output_spec

    return model_fn
def _convert_examples(examples, label_list, tokenizer, output_file):
    """Serialize ``examples`` to ``output_file`` as features.

    Thin wrapper around ``filed_based_convert_examples_to_features`` that fills
    ``max_seq_length`` and ``output_dir`` from FLAGS. It returns whatever that
    helper returns; the prediction path unpacks it as
    ``(batch_tokens, batch_labels)``.
    """
    return filed_based_convert_examples_to_features(examples=examples,
                                                    label_list=label_list,
                                                    max_seq_length=FLAGS.max_seq_length,
                                                    tokenizer=tokenizer,
                                                    output_file=output_file,
                                                    output_dir=FLAGS.output_dir)


def _build_input_fn(input_file, is_training):
    """Return an input_fn over ``input_file``; the remainder batch is dropped only for training."""
    return file_based_input_fn_builder(input_file=input_file,
                                       seq_length=FLAGS.max_seq_length,
                                       is_training=is_training,
                                       drop_remainder=is_training)


def _run_conlleval(label_file):
    """Score ``label_file`` with ``perl conlleval.pl``, using a tab as the column delimiter.

    The file is passed on stdin directly instead of through a shell redirect.
    No shell is involved, so absolute or unusual ``output_dir`` paths work and
    cannot inject shell commands. The old code prefixed the path with ``./``,
    which broke absolute paths. If perl cannot be launched, this is logged
    instead of raised, as the shell call it replaces never raised.
    """
    with open(label_file, "rb") as label_stream:
        try:
            subprocess.call(["perl", "conlleval.pl", "-d", "\t"], stdin=label_stream)
        except OSError as err:
            tf.logging.warning("Could not run conlleval.pl on %s: %s", label_file, err)


def _predict_split(estimator, examples, split_name, record_name, output_name, label_list, tokenizer):
    """Predict tags for one data split, write them to ``output_name`` and score them with conlleval."""
    record_file = os.path.join(FLAGS.output_dir, record_name)
    # Always re-serialize: the token/label lists returned here are needed by
    # file_writer to align predictions with the original tokens.
    batch_tokens, batch_labels = _convert_examples(examples, label_list, tokenizer, record_file)

    # Read the label map only after conversion. NOTE(review): this assumes the
    # conversion helper (which receives output_dir) writes label2id.pkl, so that
    # do_predict also works on a fresh output_dir. Confirm in data_seq_helper.
    with open(os.path.join(FLAGS.output_dir, "label2id.pkl"), "rb") as rf:
        label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}

    tf.logging.info("***** Running prediction on %s dataset *****", split_name)
    tf.logging.info(" Num of Predicting examples = %d", len(examples))
    tf.logging.info(" Batch size = %d", FLAGS.batch_size)
    result = estimator.predict(input_fn=_build_input_fn(record_file, is_training=False))
    output_predict_file = os.path.join(FLAGS.output_dir, output_name)
    file_writer(output_predict_file=output_predict_file,
                result=result,
                batch_tokens=batch_tokens,
                batch_labels=batch_labels,
                id2label=id2label)
    _run_conlleval(output_predict_file)


def main(_):
    """Train, evaluate and/or predict with a BERT sequence tagger, as selected by FLAGS."""
    tf.logging.set_verbosity(tf.logging.INFO)
    processors = {
        "ner": data_seq_helper.NerProcessor,
        "chunk": data_seq_helper.ChunkProcessor,
        "pos": data_seq_helper.PosProcessor
    }
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, FLAGS.init_checkpoint)
    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict):
        raise ValueError("At least one of `do_train`, `do_eval` or `do_predict' must be True.")
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError("Cannot use sequence length %d because the BERT model was only trained up to sequence "
                         "length %d" % (FLAGS.max_seq_length, bert_config.max_position_embeddings))
    if not os.path.exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)
    processor = processors[task_name]()
    label_list = processor.get_labels(data_dir=FLAGS.data_dir)
    print(label_list, flush=True)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    run_config = tf.contrib.tpu.RunConfig(
        cluster=None,
        master=None,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=8,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(data_dir=FLAGS.data_dir)
        num_train_steps = int(len(train_examples) / FLAGS.batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=len(label_list),
                                init_ckpt=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_one_hot_embeddings=False,
                                use_crf=FLAGS.use_crf)
    # If TPU is not available, this will fall back to normal Estimator on CPU or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(use_tpu=False,
                                            model_fn=model_fn,
                                            config=run_config,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size,
                                            predict_batch_size=FLAGS.batch_size)

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info(" Num of Training examples = %d", len(train_examples))
        tf.logging.info(" Batch size = %d", FLAGS.batch_size)
        tf.logging.info(" Num steps = %d", num_train_steps)
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        # Reuse an existing record file. NOTE: it is not regenerated when the data or
        # max_seq_length change, so delete it manually in that case.
        if not os.path.exists(train_file):
            _convert_examples(train_examples, label_list, tokenizer, train_file)
        estimator.train(input_fn=_build_input_fn(train_file, is_training=True), max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(data_dir=FLAGS.data_dir)
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info(" Num of Evaluate examples = %d", len(eval_examples))
        tf.logging.info(" Batch size = %d", FLAGS.batch_size)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        if not os.path.exists(eval_file):
            _convert_examples(eval_examples, label_list, tokenizer, eval_file)
        result = estimator.evaluate(input_fn=_build_input_fn(eval_file, is_training=False))
        tf.logging.info("***** Evaluation results *****")
        tf.logging.info("Eval loss: {}".format(result["eval_loss"]))

    if FLAGS.do_predict:
        # The test examples are loaded only after the dev split has been predicted.
        _predict_split(estimator, processor.get_dev_examples(data_dir=FLAGS.data_dir),
                       "dev", "eval.tf_record", "label_dev.txt", label_list, tokenizer)
        _predict_split(estimator, processor.get_test_examples(data_dir=FLAGS.data_dir),
                       "test", "predict.tf_record", "label_test.txt", label_list, tokenizer)
if __name__ == "__main__":
    # Hand control to TensorFlow's app runner, which calls main(argv).
    tf.app.run(main=main)