# -*- coding: utf-8 -*-
"""skipGram.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/15yk698annzTgrJmCIhlyLo0XjU6gFJrc
"""
#imports
from __future__ import division
import argparse
import pandas as pd
import pickle
from scipy.special import expit
import numpy as np
import matplotlib.pyplot as plt
################################################################################
#preprocessing
def vocab_dict(sentences):
    """Returns a dictionary of occurrences of the different words in the corpus.
    Input:
        sentences : list of sentences, each sentence is a list of its words
    Output:
        occurences : dictionary that maps each word to its number of occurrences
    """
    occurences = {}
    for sentence in sentences:
        for word in sentence:
            if word in occurences:
                occurences[word] += 1
            else:
                occurences[word] = 1
    return occurences
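# Illustrative example (not executed), assuming sentences are already tokenised into word lists:
#   vocab_dict([["the", "cat", "sat"], ["the", "dog"]])
#   -> {"the": 2, "cat": 1, "sat": 1, "dog": 1}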
def rare_removal(sentences, occurences, minCount=5, rare_print=False):
    """
    Removes words that occur fewer than minCount times (rare words), using the dictionary of occurrences.
    Inputs:
        sentences : the list of sentences
        occurences : dictionary of occurrences
        minCount : minimum number of occurrences needed to keep a word in the dataset
        rare_print : boolean, False by default; if True, prints the list of rare words
    Outputs:
        clean_sentences : new list of sentences, without rare words
        rare_words : list of the removed rare words
    """
    # Keeping only words that appear at least minCount times
    clean_sentences = []
    rare_words = []
    for sentence in sentences:
        clean_sentence = []
        for word in sentence:
            if occurences[word] >= minCount:
                clean_sentence.append(word)
            else:
                occurences[word] = 0
                rare_words.append(word)
        clean_sentences.append(clean_sentence)
    if rare_print:
        print(rare_words)
    return clean_sentences, rare_words
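# Illustrative example (not executed): with minCount=2, a word seen only once is dropped and its
# count is reset to 0 in `occurences`, e.g.
#   rare_removal([["the", "cat"], ["the", "dog"]], {"the": 2, "cat": 1, "dog": 1}, minCount=2)
#   -> ([["the"], ["the"]], ["cat", "dog"])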
def subsampling(sentences, occurences, plot_graph=False, frequent_print=False):
    '''
    Undersamples very frequent words in our corpus of sentences.
    Inputs:
        sentences : the list of sentences
        occurences : the dictionary of occurrences
        plot_graph : boolean, False by default; if True, plots the distribution of the subsampling probabilities
        frequent_print : boolean, False by default; if True, prints the list of very frequent words to remove
    Outputs:
        clean_sentences : sentences after removing very frequent words
        frequent_words : list of the removed very frequent words
    '''
    sub_proba = {}
    proba_threshold = 0.93  # This threshold was tuned after plotting the distribution of the subsampling probabilities
    frequent_words = []
    epsilon = 1e-5
    total_number_words = sum(list(occurences.values()))
    for word in occurences.keys():
        if occurences[word] != 0:
            freq = occurences[word] / total_number_words
            p = 1 - np.sqrt(epsilon / freq)
            sub_proba[word] = p
            if p > proba_threshold:  # very common word
                frequent_words.append(word)
                occurences[word] = 0
    # Cleaning
    clean_sentences = []
    for sentence in sentences:
        clean_sentence = []
        for word in sentence:
            if not (word in frequent_words):
                clean_sentence.append(word)
        clean_sentences.append(clean_sentence)
    # Print the very frequent words that were removed
    if frequent_print:
        print(frequent_words)
    # Plot the distribution of the subsampling probabilities
    if plot_graph:
        proba_gram = {}
        for p in sub_proba.values():  # histogram over the probabilities themselves, not over the words
            if p in proba_gram.keys():
                proba_gram[p] += 1
            else:
                proba_gram[p] = 1
        plt.plot(list(proba_gram.keys()), list(proba_gram.values()), 'ro')
        plt.xlabel('Probability', fontsize=16)
        plt.ylabel('Number of words', fontsize=16)
        plt.show()
    return clean_sentences, frequent_words
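# The subsampling probability follows the word2vec heuristic p(w) = 1 - sqrt(epsilon / freq(w)) with
# epsilon = 1e-5. Illustrative arithmetic (not executed):
#   a word with relative frequency 0.001 gets p = 1 - sqrt(1e-5 / 1e-3) = 1 - 0.1   = 0.90  -> kept   (0.90 <= 0.93)
#   a word with relative frequency 0.01  gets p = 1 - sqrt(1e-5 / 1e-2) ~ 1 - 0.032 = 0.968 -> removed (0.968 > 0.93)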
def concatenation(path):
    '''
    Concatenates all the text from the file located at "path" into one string and removes the line breaks.
    Input:
        path : the path of the data file
    Output:
        text_concat : string, contains the text after concatenation
    '''
    texts_list = []
    with open(path, encoding='utf8') as f:
        for l in f:
            l = l.rstrip('\n')  # get rid of the trailing newline (safer than l[:-1] when the last line has none)
            texts_list.append(l)
    text_concat = ' '.join(texts_list)
    return text_concat
def alphabet(word):
    """Returns True if the word is exclusively composed of letters, False otherwise.
    Input:
        word : string
    Output:
        boolean
    """
    for letter in word:
        if not (ord(letter.lower()) in range(ord('a'), ord('z') + 1)):
            # reject words containing anything other than letters, e.g. digits or the '.' of abbreviations
            return False
    return True
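# Illustrative examples (not executed):
#   alphabet("cat")  -> True
#   alphabet("U.S.") -> False   (contains '.')
#   alphabet("42nd") -> False   (contains digits)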
def text2sentences(path, undersample=True):
    '''
    Converts the raw text located at "path" into tokenized sentences.
    Input:
        path : the path of the data file
    Output:
        tokenized_sentences : a list containing all the sentences, each sentence is a list of its (lowercased) words
    '''
    text = concatenation(path)
    tokenized_sentences = []
    sentence = []
    for word in text.split():
        if word in ['!', '?', '.']:  # end of the sentence
            tokenized_sentences.append(sentence)
            sentence = []
        else:
            if alphabet(word):
                sentence.append(word.lower())
    if sentence:  # keep a trailing sentence that does not end with punctuation
        tokenized_sentences.append(sentence)
    return tokenized_sentences
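# Illustrative example (not executed), assuming punctuation appears as separate tokens in the raw text:
#   "The cat sat . The dog barked ."  ->  [["the", "cat", "sat"], ["the", "dog", "barked"]]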
def loadPairs(path):
    '''
    Loads (word1, word2, similarity) triplets from a tab-separated file with columns word1, word2 and similarity.
    '''
    data = pd.read_csv(path, delimiter='\t')
    pairs = zip(data['word1'], data['word2'], data['similarity'])
    return pairs
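# Expected file layout for loadPairs (tab-separated, with a header row; the values below are only illustrative):
#   word1   word2   similarity
#   cat     dog     7.35
#   cup     coffee  3.92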
################################################################################
class SkipGram:
    def __init__(self, sentences, nEmbed=100, negativeRate=10, winSize=5, minCount=5):
        """
        Initialisation step: generates the triplets (target word, context word, +1/0).
        """
        self.sentences = sentences
        self.occurences = vocab_dict(self.sentences)
        self.nEmbed = nEmbed
        self.negativeRate = negativeRate
        self.winSize = winSize
        self.minCount = minCount
        self.clean_sentences, self.rare_words = rare_removal(self.sentences, self.occurences, self.minCount)  # removing rare words
        self.clean_sentences, self.stopwords = subsampling(self.clean_sentences, self.occurences)  # removing very frequent words from the sentences already stripped of rare words
        print(" Preprocessing : done ! ")
        # Generate positive pairs: list of (target word, context word, 1) triplets
        self.positive_pairs = self.generate_positive_pairs(self.clean_sentences, self.winSize)
        # Generate negative pairs: list of (target word, context word, 0) triplets for each positive pair
        self.negative_pairs = self.generate_negative_pairs(self.occurences, self.positive_pairs, self.negativeRate)
        self.pairs = self.positive_pairs.copy()
        for list_pair in self.negative_pairs:
            for neg in list_pair:
                self.pairs.append(neg)
        print(" Generating positive and negative pairs : done! ")
        # Map every target word and every context word to a row index in the embedding matrices
        self.total_target_words = {}
        self.total_context_words = {}
        i, j = 0, 0
        for pair in self.pairs:
            if pair[0] not in self.total_target_words:
                self.total_target_words[pair[0]] = i
                i += 1
            if pair[1] not in self.total_context_words:
                self.total_context_words[pair[1]] = j
                j += 1
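    # The training set self.pairs mixes both kinds of triplets; with illustrative words (not actual data):
    #   ("cat", "sat", 1)   -> "sat" was observed in the context window of "cat"      (positive pair)
    #   ("cat", "zebra", 0) -> "zebra" was drawn from the unigram noise distribution  (negative pair)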
    def generate_positive_pairs(self, sentences, winSize):
        """Generates (target word, context word, +1) triplets.
        Inputs:
            sentences : list of sentences
            winSize : size of the sliding window for context
        Output:
            positive_pairs : list of (target word, context word, +1) triplets
        """
        positive_pairs = []
        for sentence in sentences:
            for counter, target in enumerate(sentence):
                for context_index in range(max(0, counter - winSize), min(counter + winSize + 1, len(sentence))):
                    context = sentence[context_index]
                    if context != target:  # it is useless to associate a target with itself
                        positive_pairs.append((target, context, +1))  # +1 because it is a positive pair
        return positive_pairs
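    # Illustrative example (not executed): for the sentence ["the", "cat", "sat"] with winSize=5,
    # the window around "cat" covers the whole sentence, giving ("cat", "the", 1) and ("cat", "sat", 1);
    # "the" and "sat" produce their own triplets in the same way.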
    def generate_negative_pairs(self, occurences, positive_pairs, negativeRate):
        """Generates (target word, context word, 0) triplets.
        Inputs:
            occurences : dictionary of occurrences
            positive_pairs : the list of generated positive pairs
            negativeRate : int, number of negative pairs drawn for each positive pair
        Output:
            negative_pairs : list that associates, to each positive pair, a list of negativeRate (target word, context word, 0) triplets
        """
        # Build the unigram noise distribution: each kept word is sampled with probability proportional to count**0.75
        unigram = {}
        for word in occurences.keys():
            if occurences[word] != 0:
                unigram[word] = occurences[word] ** 0.75
        norm_const = sum(list(unigram.values()))
        unigram_table = [unigram[word] / norm_const for word in unigram.keys()]
        l = list(unigram.keys())
        negative_candidates = np.random.choice(list(range(len(l))), size=negativeRate * len(positive_pairs), p=unigram_table)
        np.random.shuffle(negative_candidates)
        # Now let's build the negative pairs
        negative_pairs = []
        for i in range(len(positive_pairs)):
            positive_pair = positive_pairs[i]
            k_negative = []
            target = positive_pair[0]
            for k in range(negativeRate):
                context_index = negative_candidates[negativeRate * i + k]
                k_negative.append((target, l[context_index], 0))
            negative_pairs.append(k_negative)
        return negative_pairs
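    # Raising counts to the power 0.75 flattens the noise distribution, so rare words are drawn a bit more
    # often as negatives than their raw frequency would suggest. Illustrative arithmetic (not executed):
    #   counts {a: 16, b: 1} -> weights 16**0.75 = 8 and 1**0.75 = 1 -> P(a) = 8/9 ~ 0.89, P(b) = 1/9 ~ 0.11,
    #   instead of 16/17 ~ 0.94 and 1/17 ~ 0.06 with the raw counts.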
    def train(self, stepsize, epochs, batch_size, plot_loss=True):
        # Random initialization of W and C (cast to float so that the gradient updates are not truncated to integers)
        self.W = np.array([np.random.randint(low=0, high=10, size=self.nEmbed) for _ in range(len(self.total_target_words))], dtype=float)  # target embedding matrix
        self.C = np.array([np.random.randint(low=0, high=10, size=self.nEmbed) for _ in range(len(self.total_context_words))], dtype=float)  # context embedding matrix
        print("Initialization of embedding matrices : done !")
        print("Start training...")
        loss = []
        # Running through the epochs
        for epoch in range(epochs):
            loss_epoch = 0
            print("Epoch {}/{}".format(epoch + 1, epochs))
            batch_indices = np.random.randint(low=0, high=len(self.positive_pairs), size=batch_size)
            for index in batch_indices:  # implement the equations of section 4
                positive_pair = self.positive_pairs[index]
                target_index = self.total_target_words[positive_pair[0]]
                all_pairs = [positive_pair]  # all pairs (positive and negative) for the target word
                all_pairs.extend(self.negative_pairs[index])
                grad_t = np.zeros(self.nEmbed)
                for _, context, gamma in all_pairs:
                    context_index = self.total_context_words[context]
                    x = expit(self.W[target_index, :].T @ self.C[context_index, :]) - gamma
                    grad_c = x * self.W[target_index, :]
                    grad_t += x * self.C[context_index, :]
                    self.C[context_index, :] = self.C[context_index, :] - stepsize * grad_c
                self.W[target_index] = self.W[target_index] - stepsize * grad_t
                loss_batch = 0
                for _, context, gamma in all_pairs:
                    context_index = self.total_context_words[context]
                    s = expit((2 * gamma - 1) * self.W[target_index, :].T @ self.C[context_index, :])
                    if s != 0:
                        loss_batch += -np.log(s)
                    else:
                        loss_batch += 999  # placeholder for +inf
                loss_epoch += loss_batch / batch_size
            loss.append(loss_epoch)
        if plot_loss:
            plt.figure()
            plt.plot(range(1, epochs + 1), loss)
            plt.title("Evolution of the loss along epochs", fontsize=16)
            plt.show()
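    # The update rule above is the usual skip-gram with negative sampling gradient step. For a target
    # embedding w and a context embedding c with label gamma (1 for observed pairs, 0 for sampled negatives),
    # the per-pair loss is -log(sigmoid(w.c)) if gamma = 1 and -log(sigmoid(-w.c)) if gamma = 0, whose
    # gradient is (sigmoid(w.c) - gamma) * w with respect to c, and (sigmoid(w.c) - gamma) * c with
    # respect to w; this is exactly the quantity x computed with expit above.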
    def similarity(self, word1, word2):
        """
        Computes the similarity of word1 and word2, using the output of the training phase.
        :return: cosine similarity between the embeddings of word1 and word2 if they are both in the vocabulary;
        the other cases are explained in section 4.3 of the report
        """
        # Checking whether word1 and word2 are in our target words' list
        if word1 == word2:  # trivial but can occur!
            return 1
        if not (word1 in self.occurences.keys()):
            # if word1 is unknown, we include it in the vocabulary with frequency 0
            self.occurences[word1] = 0
        if not (word2 in self.occurences.keys()):
            # if word2 is unknown, we include it in the vocabulary with frequency 0
            self.occurences[word2] = 0
        if (word1 in self.total_target_words.keys()) and (word2 in self.total_target_words.keys()):
            index1 = self.total_target_words[word1]
            vec_1 = self.W[index1]  # embedding representation of word1
            index2 = self.total_target_words[word2]
            vec_2 = self.W[index2]  # embedding representation of word2
            cosine_distance = vec_1.dot(vec_2) / (np.linalg.norm(vec_1) * np.linalg.norm(vec_2))
            return abs(cosine_distance)
        # If word1 or word2 is not a target word, it is either in the vocabulary but with no context,
        # a removed "rare" word, a stopword (very frequent), or simply unknown
        # (not even in the initial text before preprocessing)
        elif word1 in self.total_target_words.keys():
            if word2 in self.stopwords:
                return 0
            else:
                if self.occurences[word2] != 0:  # word2 is in the vocabulary but is neither a target word, a rare word nor a stopword
                    totalWords = sum(list(self.occurences.values()))
                    word2prob = self.occurences[word2] / totalWords
                    return word2prob * 0.5
                else:  # word2 is a removed rare word or unknown
                    totalWords = sum(list(self.occurences.values()))
                    return 0.1 * (min(self.occurences.values()) * self.occurences[word2]) / (totalWords ** 2)
        elif word2 in self.total_target_words.keys():
            if word1 in self.stopwords:
                return 0
            else:
                if self.occurences[word1] != 0:  # word1 is in the vocabulary but is neither a target word, a rare word nor a stopword
                    totalWords = sum(list(self.occurences.values()))
                    word1prob = self.occurences[word1] / totalWords
                    return word1prob * 0.5
                else:  # word1 is a removed rare word or unknown
                    totalWords = sum(list(self.occurences.values()))
                    return 0.1 * (min(self.occurences.values()) * self.occurences[word1]) / (totalWords ** 2)
        else:  # neither word is a target word
            if (word1 in self.stopwords) or (word2 in self.stopwords):
                return 0
            else:  # both are rare or unknown
                totalWords = sum(list(self.occurences.values()))
                return min(self.occurences.values()) / totalWords
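    # Illustrative usage (not executed): after training, sg.similarity("cat", "dog") returns the absolute
    # cosine similarity of the two target embeddings, a value in [0, 1]; word pairs that fall back on the
    # frequency-based heuristics above return small values close to 0.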
    def save(self, path):
        '''
        Saves W (the matrix of target embeddings) together with the (word: index) dictionary of target words,
        the removed rare words, the removed stopwords and the dictionary of occurrences.
        '''
        print("We are saving the results")
        with open(path, 'wb') as f:
            pickle.dump([self.W, self.total_target_words, self.rare_words, self.stopwords, self.occurences], f)
    @classmethod
    def load(cls, path):
        '''
        Loads a saved model and rebuilds a SkipGram instance (skipping __init__) with the attributes that similarity() needs.
        '''
        with open(path, 'rb') as f:
            W, total_target_words, rare_words, stopwords, occurences = pickle.load(f)
        model = cls.__new__(cls)
        model.W, model.total_target_words, model.rare_words, model.stopwords, model.occurences = W, total_target_words, rare_words, stopwords, occurences
        return model
################################################################################
# main
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--text', help='path containing training data', required=True)
    parser.add_argument('--model', help='path to store/read model (when training/testing)', required=True)
    parser.add_argument('--test', help='enters test mode', action='store_true')
    opts = parser.parse_args()
    if not opts.test:  # training step
        sentences = text2sentences(opts.text)
        sg = SkipGram(sentences, nEmbed=100, negativeRate=5, winSize=5, minCount=5)
        sg.train(stepsize=0.01, epochs=500, batch_size=50, plot_loss=True)
        sg.save(opts.model)
    else:  # testing step
        pairs = loadPairs(opts.text)
        sg = SkipGram.load(opts.model)
        for a, b, _ in pairs:
            # make sure this does not raise any exception, even if a or b are not in sg.vocab
            print(sg.similarity(a, b))
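# Example invocations (the file names are only illustrative):
#   python skipGram.py --text train.txt --model mymodel.model
#   python skipGram.py --text pairs.tsv --model mymodel.model --test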