-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNoisy_retrieval.py
48 lines (40 loc) · 1.38 KB
/
Noisy_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import faiss
from sentence_transformers import SentenceTransformer, LoggingHandler
import logging
import os
from shutil import copyfile
import numpy as np
# Emit timestamped INFO-level log records through the LoggingHandler
# imported from sentence_transformers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
# --- Configuration -----------------------------------------------------------
model_name = 'msmarco-MiniLM-L6-cos-v5'
data_folder = 'msmarco-data'
q_file_name = 'msmarco-queries.dev.small.tsv'
top_k = 1000  # number of neighbours requested from the index per query

# --- Query encoder -----------------------------------------------------------
model = SentenceTransformer(model_name)
# Cap on the encoder's input sequence length (units as defined by
# sentence-transformers' `max_seq_length`).
model.max_seq_length = 300

# --- On-disk resources -------------------------------------------------------
# Pre-built FAISS index, read from <data_folder>/marco_corpus_faiss.
index_path = os.path.join(data_folder, 'marco_corpus_faiss')
index = faiss.read_index(index_path)
queries_filepath = os.path.join(data_folder, q_file_name)
# Load the query file: one "<qid>\t<query text>" record per line.
# `qids[i]` and `queries[i]` describe the same query.
qids = []
queries = []
with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        line = line.strip()
        if not line:
            # A blank line (e.g. a trailing newline at end of file) would make
            # the two-name unpack below raise ValueError.
            continue
        # maxsplit=1: any further tab stays part of the query text instead of
        # producing a third field that breaks the unpack.
        qid, query = line.split("\t", 1)
        qids.append(qid)
        queries.append(query)

# Encode every query in a single call; the result is indexed below as a 2-D
# array (rows = queries, columns = embedding dimensions).
xq = model.encode(queries)
print(xq.shape[1])
print(xq.shape, 'q done')
# Noise sweep: for each of nine noise levels (0.01 .. 0.09) perturb the query
# embeddings, retrieve the top_k nearest entries from the FAISS index and write
# a run file with one "qid<TAB>index-id<TAB>rank" line per retrieved entry.

# open(..., 'w') does not create missing directories.
os.makedirs('run', exist_ok=True)

# Take the embedding width from the data rather than hard-coding 384 / 6980.
embedding_dim = xq.shape[1]

for level in range(1, 10):
    # Passed to np.random.normal as `scale`, i.e. the standard deviation of
    # the Gaussian. Kept as `level * 0.01` so str(sigma), and therefore the
    # output file name, is unchanged.
    sigma = level * 0.01

    # One noise vector per level, added to every query row (broadcast).
    # NOTE(review): all queries share the same draw; sample with
    # size=xq.shape instead if independent per-query noise is intended.
    noise_tensor = np.random.normal(0, sigma, embedding_dim)

    # Perturb a fresh copy of the clean embeddings. Writing the result back
    # into `xq` would make the noise accumulate from one level to the next,
    # so the file labelled 0.05 would contain the sum of five noise draws.
    # np.random.normal yields float64; cast back to xq's dtype, which is what
    # the in-place row assignment did implicitly.
    noisy_xq = np.ascontiguousarray(xq + noise_tensor, dtype=xq.dtype)

    # Only the neighbour ids are needed; the distances are discarded.
    _, neighbours = index.search(noisy_xq, top_k)

    run_path = 'run/Noisyq_' + str(sigma) + '.' + q_file_name + '.tsv'
    # Context manager closes the file even if a write fails part-way.
    with open(run_path, 'w', encoding='utf8') as out:
        for qid, hits in zip(qids, neighbours):
            out.writelines(
                qid + '\t' + str(hit) + '\t' + str(rank) + '\n'
                for rank, hit in enumerate(hits, start=1)
            )