-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
78 lines (55 loc) · 3.37 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import re
import json
import time
import argparse
import numpy as np
import pandas as pd
from rule_based import FeatureExtractor, Recommendation
# Model identifier handed to FeatureExtractor (as model_name) in main().
MODEL_NAME = "jhgan/ko-sroberta-multitask"
def parse_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            what the original zero-argument call did.

    Returns:
        argparse.Namespace with ``data_dir`` (str, default ``/opt/ml/data``)
        and ``topk`` (int, default 4).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="/opt/ml/data", help="crawled dataset")
    parser.add_argument("--topk", type=int, default=4, help="recommend tag num")
    # Forwarding argv lets callers and tests supply arguments explicitly
    # instead of depending on sys.argv.
    return parser.parse_args(argv)
def main(args):
    """Run the hard-coded example through the recommendation pipeline and print the result.

    Reads the two CSV tables, the two ``.npy`` embedding matrices and the
    question-category JSON map from ``args.data_dir``, checks the example
    question against the question embeddings, builds a ``Recommendation`` for
    the example user and prints the output of its five recommend methods
    together with two wall-clock timings.

    Args:
        args: Namespace providing ``data_dir`` (directory holding the data
            files) and ``topk`` (passed through to ``Recommendation``).

    Raises:
        ValueError: If the example question is empty after whitespace
            normalization, or its best similarity score is not above 0.5.
            (Previously a bare ``Exception`` with the same message;
            ``ValueError`` is still caught by ``except Exception``.)
    """
    # Example
    # NOTE(review): the example literals in this function look mis-encoded
    # (presumably UTF-8 Korean decoded with the wrong codec) and were split
    # across physical lines in the reviewed text, which is not valid inside a
    # single-quoted literal. Every visible character is kept and each line
    # break is written as an explicit "\n"; restore the original text from
    # version control if it is available.
    input_question = (
        "μ\n"
        "μ¬ ν ν¬λΆ : μ\n"
        "μ¬ ν 10λ\n"
        "λμμ νμ¬μν μλ리μ€μ κ·Έκ²μ μΆκ΅¬νλ μ΄μ λ₯Ό κΈ°μ ν΄μ£ΌμΈμ."
    )

    start1 = time.time()
    data_dir = args.data_dir
    document = pd.read_csv(os.path.join(data_dir, "jk_documents_3_2.csv"), low_memory=False)
    item = pd.read_csv(os.path.join(data_dir, "jk_answers_without_samples_3_2.csv"), low_memory=False)
    answer_emb_matrix = np.load(os.path.join(data_dir, "answer_embedding_matrix.npy"))
    question_emb_matrix = np.load(os.path.join(data_dir, "question_embedding_matrix.npy"))
    # key: question_category, value (list): answer_id
    # Explicit encoding so the JSON is not decoded with a locale-dependent default.
    with open(os.path.join(data_dir, "question_cate_map_answerid.json"), "r", encoding="utf-8") as f:
        qcate_dict = json.load(f)

    embedder = FeatureExtractor(model_name=MODEL_NAME)

    if isinstance(input_question, str):
        # Collapse every whitespace run to a single space. A raw string avoids
        # the invalid-escape warning of '\s+', and \s already covers the
        # newlines the original replaced in a separate pass.
        normalized_question = re.sub(r"\s+", " ", input_question.strip())
        # Reject an empty question before spending an embedder call on it.
        if not normalized_question:
            raise ValueError("Unexpected question input")
        # NOTE(review): the matched category is not used afterwards; the
        # example user's own "question_category" is what reaches
        # Recommendation, as in the original. Confirm this is intended.
        _matched_category, sim = embedder.match_question_top1(normalized_question, question_emb_matrix)
        if sim <= 0.5:
            raise ValueError("Unexpected question input")

    example_user = {
        "question_category": 5,
        "company": "(μ£Ό)LGνν",
        "job_large": "μ°κ΅¬κ°λ°Β·μ€κ³",
        "job_small": "λ°λ체·λμ€νλ μ΄",
        "answer": (
            "'곡μ κ°μ κ²½νκ³Ό μ 곡μ§μ'μ λ 곡μ μμ§λμ΄μκ² νμν κ²μ νν곡μ μ λν μ§μκ³Ό κ·Έκ²μ λ°νμΌλ‘ μμ°λκ³Ό μλμ§ ν¨μ¨μ ν₯μμν¬ μ μλ λ₯λ ₯μ΄λΌκ³ μκ°ν©λλ€. μ λ νν곡μ₯ μ€κ³νλ‘μ νΈμμ 곡μ κ°μ μΌλ‘ μμ°λμ 20%ν₯μμν¨ κ²½νμ΄ μμ΅λλ€. μ²μμλ μνλ λ§νΌ μμ°λμ΄ μ λμμ§λ§ DMAICκΈ°λ²μ μ¬μ©νμ¬ κ³΅μ λ°μ΄ν°λ₯Ό λΆμνμ¬κ³ λ©νμ¬μ΄ λλΉλκ³ μλ€λ κ²μ νμ\n"
            "νμμ΅λλ€. "
        ),
    }
    print("data loader time : ", time.time() - start1)

    start2 = time.time()
    # Look the fields up by key instead of unpacking .values(), so the call
    # does not depend on the order in which the dict literal was written.
    recommend = Recommendation(
        document, item, qcate_dict, answer_emb_matrix, embedder,
        example_user["question_category"],
        example_user["company"],
        example_user["job_large"],
        example_user["job_small"],
        example_user["answer"],
        args.topk,
    )
    recommend.filtering()

    result = {
        "tag1": recommend.recommend_with_company_jobtype(),
        "tag2": recommend.recommend_with_jobtype_without_company(),
        "tag3": recommend.recommend_with_company_without_jobtype(),
        # Method name spelling ("recommed") kept exactly as the original call.
        "tag4": recommend.recommed_based_popularity(),
        "tag5": recommend.recommend_based_expert(),
    }
    print("=======νκ·Έλ³ μΆμ² κ²°κ³Ό(answer_id)========")
    print(result)
    print("recommend time : ", time.time() - start2, '\n')
if __name__ == "__main__":
    # Script entry point: parse the command-line options and hand them to main().
    main(parse_args())