Skip to content

Commit

Permalink
Merge pull request #2 from IsaacJ60/master
Browse files Browse the repository at this point in the history
Added translated queries to txt file
  • Loading branch information
DelaramRajaei authored Oct 16, 2023
2 parents 286fb82 + 12711ff commit 8110de6
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions qe/expanders/backtranslation.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,42 @@
from cmn import param
from expanders.abstractqexpander import AbstractQExpander
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
import sys

sys.path.extend(['../qe'])

from expanders.abstractqexpander import AbstractQExpander
from cmn import param


class BackTranslation(AbstractQExpander):
def __init__(self, tgt):
    """Build the forward/backward translation pipelines and the embedding model.

    All settings are read from ``param.backtranslation`` (keys used here:
    ``model_card``, ``src_lng``, ``max_length``, ``device`` and
    ``transformer_model``).

    Args:
        tgt: Target-language code. It is passed as ``tgt_lang`` to the forward
            pipeline and as ``src_lang`` to the backward pipeline.
    """
    AbstractQExpander.__init__(self)

    # Initialization
    self.tgt = tgt
    settings = param.backtranslation

    # Load the seq2seq model and its tokenizer once; both pipelines below
    # share the same instances, so nothing is loaded more than necessary.
    model = AutoModelForSeq2SeqLM.from_pretrained(settings['model_card'])
    tokenizer = AutoTokenizer.from_pretrained(settings['model_card'])

    # Translation models: source -> target, and target -> source.
    self.translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=settings['src_lng'],
        tgt_lang=self.tgt,
        max_length=settings['max_length'],
        device=settings['device'],
    )
    self.back_translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=self.tgt,
        tgt_lang=settings['src_lng'],
        max_length=settings['max_length'],
        device=settings['device'],
    )

    # Model used for calculating semsim.
    self.transformer_model = SentenceTransformer(settings['transformer_model'])

# Generates the back-translation of the original query, then calculates the semsim score between the two queries
def get_expanded_query(self, q, args=None):
    """Back-translate ``q``, log the round trip, and hand the result to the parent.

    The query is translated with ``self.translator``, translated back with
    ``self.back_translator``, and one tab-separated line (qid, original query,
    translated query, back-translated query) is appended to
    ``output/robust04/translatedqueries.txt``.

    Args:
        q: The original query text.
        args: Optional sequence whose first element is written as the query
            id in the log line. When omitted or empty, the id column is left
            blank instead of raising.

    Returns:
        Whatever the parent class's ``get_expanded_query`` returns for the
        back-translated query and ``[score]``, where ``score`` is the value
        of ``self.semsim(q, back_translated_query)``.
    """
    # Local import: only needed to build the log path in a portable way.
    import os

    # Each pipeline call returns a list of dicts; the text of the first
    # result is stored under 'translation_text'.
    translated_text = self.translator(q)[0]['translation_text']
    back_translated_text = self.back_translator(translated_text)[0]['translation_text']

    qid = args[0] if args else ''
    log_path = os.path.join('output', 'robust04', 'translatedqueries.txt')
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    # UTF-8 is explicit because translations routinely contain characters
    # that the platform's default encoding cannot represent.
    with open(log_path, 'a', encoding='utf-8') as outfile:
        # output qid, original query, translated query, backtranslated query
        outfile.write(f'{qid}\t{q}\t{translated_text}\t{back_translated_text}\n')

    score = self.semsim(q, back_translated_text)
    return super().get_expanded_query(back_translated_text, [score])
Expand Down

0 comments on commit 8110de6

Please sign in to comment.