forked from ucfnlp/multidoc_summarization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbeam_search.py
213 lines (175 loc) · 10.4 KB
/
beam_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Modifications Copyright 2017 Abigail See
# Modifications made 2018 by Logan Lebanoff
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains code to run beam search decoding"""
import numpy as np
import data
from absl import flags
import pg_mmr_functions
# Module-level handle to the absl flag values. run_beam_search reads its
# settings from here (beam_size, max_dec_steps, min_dec_steps, pg_mmr, mute_k,
# plot_distributions).
FLAGS = flags.FLAGS
class Hypothesis(object):
    """A single partial summary tracked during beam search.

    Bundles everything the search needs to keep growing one candidate:
    the emitted token ids, their per-token log probabilities, the decoder
    state, per-step attention distributions and generation probabilities,
    the coverage vector, and the PG-MMR scores. ``extend`` never mutates the
    receiver; it returns a brand-new Hypothesis.
    """

    def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage, mmr):
        """Create a hypothesis from its per-step histories and current state.

        Args:
          tokens: List of ints; ids of the tokens emitted so far.
          log_probs: List of floats, parallel to ``tokens``; the log
            probability of each token.
          state: Decoder state after the latest token (presumably an
            LSTMStateTuple -- it is stored as-is and never inspected here).
          attn_dists: List with one attention distribution per decoding step
            taken so far (presumably numpy arrays of shape (attn_length,)).
          p_gens: List with one generation probability per decoding step.
          coverage: Current coverage vector, or None.
          mmr: Current MMR scores over the source sentences; stored as-is.
        """
        # Per-step histories.
        self.tokens = tokens
        self.log_probs = log_probs
        self.attn_dists = attn_dists
        self.p_gens = p_gens
        # Latest-step state.
        self.state = state
        self.coverage = coverage
        # Starts at 0. for every new object (extend does not carry it over);
        # presumably written by the external PG-MMR update code.
        self.similarity = 0.
        self.mmr = mmr

    def extend(self, token, log_prob, state, attn_dist, p_gen, coverage, mmr):
        """Return a NEW hypothesis that is this one plus one more decoding step.

        The per-step lists are extended by concatenation, so this object's
        lists are left untouched.

        Args:
          token: Int id of the newly emitted token.
          log_prob: Float log probability of ``token``.
          state: Decoder state after emitting ``token``.
          attn_dist: Attention distribution from this step.
          p_gen: Generation probability from this step.
          coverage: Coverage vector after this step, or None.
          mmr: MMR scores to attach to the new hypothesis.

        Returns:
          A fresh Hypothesis for the next step.
        """
        return Hypothesis(self.tokens + [token],
                          self.log_probs + [log_prob],
                          state,
                          self.attn_dists + [attn_dist],
                          self.p_gens + [p_gen],
                          coverage,
                          mmr)

    @property
    def latest_token(self):
        """Id of the most recently emitted token."""
        return self.tokens[-1]

    @property
    def log_prob(self):
        """Total log probability of the hypothesis (sum over its tokens)."""
        return sum(self.log_probs)

    @property
    def avg_log_prob(self):
        """Log probability divided by the number of tokens.

        Length normalisation keeps longer summaries from losing out merely
        because they accumulate more negative log-probability terms.
        """
        total = self.log_prob
        return total / len(self.tokens)
def run_beam_search(sess, model, vocab, batch, ex_index, hps):
    """Performs beam search decoding on the given example, optionally with PG-MMR.

    Args:
      sess: a tf.Session
      model: a seq2seq model
      vocab: Vocabulary object
      batch: Batch object that is the same example repeated across the batch
      ex_index: Index of the example; only forwarded to
        pg_mmr_functions.save_distribution_plots when plotting is enabled.
      hps: Hyperparameters; only forwarded to pg_mmr_functions.get_importances.

    Returns:
      best_hyp: Hypothesis object; the hypothesis with the highest average log
        probability found by beam search.
    """
    max_dec_steps = FLAGS.max_dec_steps
    beam_size = FLAGS.beam_size

    # Loop-invariant vocabulary lookups, done once instead of on every step.
    vocab_size = vocab.size()
    unk_id = vocab.word2id(data.UNKNOWN_TOKEN)
    stop_id = vocab.word2id(data.STOP_DECODING)
    period_id = vocab.word2id(data.PERIOD) if FLAGS.pg_mmr else None

    # Run the encoder to get the encoder hidden states and decoder initial state.
    # dec_in_state is a LSTMStateTuple
    # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim].
    enc_states, dec_in_state = model.run_encoder(sess, batch)

    # Sentence importance; these scores seed the MMR scores of every hypothesis.
    enc_sentences, enc_tokens = batch.tokenized_sents[0], batch.word_ids_sents[0]
    importances = pg_mmr_functions.get_importances(model, batch, enc_states, vocab, sess, hps)
    mmr_init = importances

    # Initialize beam_size-many (identical) hypotheses.
    hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
                       log_probs=[0.0],
                       state=dec_in_state,
                       attn_dists=[],
                       p_gens=[],
                       coverage=np.zeros([batch.enc_batch.shape[1]]),  # zero vector of length attention_length
                       mmr=mmr_init)
            for _ in xrange(beam_size)]
    results = []  # finished hypotheses (those that have emitted the [STOP] token)
    steps = 0
    while steps < max_dec_steps and len(results) < beam_size:
        # Change any in-article temporary OOV ids (>= vocab_size) to the [UNK] id
        # so the decoder can look up word embeddings. A range comparison is O(1);
        # `t in xrange(vocab_size)` scans the whole range (Python 2 xrange has no
        # fast containment test, and topk ids are numpy ints).
        latest_tokens = [t if 0 <= t < vocab_size else unk_id
                         for t in (h.latest_token for h in hyps)]
        states = [h.state for h in hyps]  # current decoder states of the hypotheses
        prev_coverage = [h.coverage for h in hyps]  # coverage vectors (or None)

        # PG-MMR: optionally mute all source sentences except the top k, then
        # convert the sentence-level scores to word level for the decoder.
        prev_mmr = [h.mmr for h in hyps]
        if FLAGS.pg_mmr:
            if FLAGS.mute_k != -1:
                prev_mmr = [pg_mmr_functions.mute_all_except_top_k(mmr, FLAGS.mute_k) for mmr in prev_mmr]
            prev_mmr_for_words = [pg_mmr_functions.convert_to_word_level(mmr, batch, enc_tokens)
                                  for mmr in prev_mmr]
        else:
            prev_mmr_for_words = [None for _ in prev_mmr]

        # Run one step of the decoder to get the new info. The last returned
        # value (pre-attention distributions) is not used here.
        (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage,
         _) = model.decode_onestep(sess=sess,
                                   batch=batch,
                                   latest_tokens=latest_tokens,
                                   enc_states=enc_states,
                                   dec_init_states=states,
                                   prev_coverage=prev_coverage,
                                   mmr_score=prev_mmr_for_words)

        # Extend each hypothesis with each of its top 2*beam_size options.
        # On the first step all hypotheses are the identical initial one, so
        # only expand hyps[0]; afterwards all original hypotheses are distinct.
        all_hyps = []
        num_orig_hyps = 1 if steps == 0 else len(hyps)
        for i in xrange(num_orig_hyps):
            h = hyps[i]
            new_state, attn_dist, p_gen, new_coverage_i = new_states[i], attn_dists[i], p_gens[i], new_coverage[i]
            for j in xrange(beam_size * 2):
                all_hyps.append(h.extend(token=topk_ids[i, j],
                                         log_prob=topk_log_probs[i, j],
                                         state=new_state,
                                         attn_dist=attn_dist,
                                         p_gen=p_gen,
                                         coverage=new_coverage_i,
                                         mmr=h.mmr))

        # Walk candidates from most to least likely. Finished ones go to
        # results (if long enough, otherwise discarded); the rest form the
        # next beam.
        hyps = []
        for h in sort_hyps(all_hyps):
            if h.latest_token == stop_id:
                if steps >= FLAGS.min_dec_steps:
                    results.append(h)
            else:
                hyps.append(h)
            # Stop once the next beam or the result set is full.
            if len(hyps) == beam_size or len(results) == beam_size:
                break

        # PG-MMR: update similarity/MMR scores of hypotheses that just ended
        # a sentence (latest token is a period).
        if FLAGS.pg_mmr:
            for hyp in hyps:
                if hyp.latest_token == period_id:
                    pg_mmr_functions.update_similarity_and_mmr(hyp, importances, batch, enc_tokens, vocab)
        steps += 1

    # Either we have beam_size results or we reached max_dec_steps. With no
    # complete results, fall back to the current (incomplete) hypotheses.
    if not results:
        results = hyps

    # Sort hypotheses by average log probability and keep the best.
    best_hyp = sort_hyps(results)[0]

    # Save plots of the distributions (importance, similarity, mmr)
    if FLAGS.plot_distributions and FLAGS.pg_mmr:
        pg_mmr_functions.save_distribution_plots(importances, enc_sentences,
                                                 enc_tokens, best_hyp, batch, vocab, ex_index)
    return best_hyp
def sort_hyps(hyps):
    """Order hypotheses from best to worst.

    Args:
      hyps: Iterable of Hypothesis objects (anything exposing avg_log_prob).

    Returns:
      A new list sorted by descending average log probability. Hypotheses
      with equal scores keep their input order, because sorted() stays stable
      with reverse=True.
    """
    def _avg_score(hyp):
        return hyp.avg_log_prob

    return sorted(hyps, key=_avg_score, reverse=True)