Upload 6 files
Browse files- features/graph_utils.py +110 -0
- features/text_utils.py +84 -0
- models/bm25_utils.py +40 -0
- models/pairwise_model.py +140 -0
- models/predict_model.py +76 -0
- models/qa_model.py +52 -0
features/graph_utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import networkx as nx
|
2 |
+
import numpy as np
|
3 |
+
from cdlib import algorithms
|
4 |
+
|
5 |
+
|
6 |
+
# these functions are heavily influenced by the HF squad_metrics.py script
|
7 |
+
def normalize_text(s):
|
8 |
+
"""Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
|
9 |
+
import string, re
|
10 |
+
|
11 |
+
def remove_articles(text):
|
12 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
13 |
+
return re.sub(regex, " ", text)
|
14 |
+
|
15 |
+
def white_space_fix(text):
|
16 |
+
return " ".join(text.split())
|
17 |
+
|
18 |
+
def remove_punc(text):
|
19 |
+
exclude = set(string.punctuation)
|
20 |
+
return "".join(ch for ch in text if ch not in exclude)
|
21 |
+
|
22 |
+
def lower(text):
|
23 |
+
return text.lower()
|
24 |
+
|
25 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
26 |
+
|
27 |
+
|
28 |
+
def compute_exact_match(prediction, truth):
|
29 |
+
return int(normalize_text(prediction) == normalize_text(truth))
|
30 |
+
|
31 |
+
|
32 |
+
def compute_f1(prediction, truth):
|
33 |
+
pred_tokens = normalize_text(prediction).split()
|
34 |
+
truth_tokens = normalize_text(truth).split()
|
35 |
+
|
36 |
+
# if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
|
37 |
+
if len(pred_tokens) == 0 or len(truth_tokens) == 0:
|
38 |
+
return int(pred_tokens == truth_tokens)
|
39 |
+
|
40 |
+
common_tokens = set(pred_tokens) & set(truth_tokens)
|
41 |
+
|
42 |
+
# if there are no common tokens then f1 = 0
|
43 |
+
if len(common_tokens) == 0:
|
44 |
+
return 0
|
45 |
+
|
46 |
+
prec = len(common_tokens) / len(pred_tokens)
|
47 |
+
rec = len(common_tokens) / len(truth_tokens)
|
48 |
+
|
49 |
+
return 2 * (prec * rec) / (prec + rec)
|
50 |
+
|
51 |
+
|
52 |
+
def is_date_or_num(answer):
|
53 |
+
answer = answer.lower().split()
|
54 |
+
for w in answer:
|
55 |
+
w = w.strip()
|
56 |
+
if w.isnumeric() or w in ["ngày", "tháng", "năm"]:
|
57 |
+
return True
|
58 |
+
return False
|
59 |
+
|
60 |
+
|
61 |
+
def find_best_cluster(answers, best_answer, thr=0.79):
|
62 |
+
if len(answers) == 0: # or best_answer not in answers:
|
63 |
+
return best_answer
|
64 |
+
elif len(answers) == 1:
|
65 |
+
return answers[0]
|
66 |
+
dists = np.zeros((len(answers), len(answers)))
|
67 |
+
for i in range(len(answers) - 1):
|
68 |
+
for j in range(i + 1, len(answers)):
|
69 |
+
a1 = answers[i].lower().strip()
|
70 |
+
a2 = answers[j].lower().strip()
|
71 |
+
if is_date_or_num(a1) or is_date_or_num(a2):
|
72 |
+
# print(a1, a2)
|
73 |
+
if a1 == a2 or ("tháng" in a1 and a1 in a2) or ("tháng" in a2 and a2 in a1):
|
74 |
+
dists[i, j] = 1
|
75 |
+
dists[j, i] = 1
|
76 |
+
# continue
|
77 |
+
elif a1 == a2 or (a1 in a2) or (a2 in a1) or compute_f1(a1.lower(), a2.lower()) >= thr:
|
78 |
+
dists[i, j] = 1
|
79 |
+
dists[j, i] = 1
|
80 |
+
# print(dists)
|
81 |
+
try:
|
82 |
+
thr = 1
|
83 |
+
dups = np.where(dists >= thr)
|
84 |
+
dup_strs = []
|
85 |
+
edges = []
|
86 |
+
for i, j in zip(dups[0], dups[1]):
|
87 |
+
if i != j:
|
88 |
+
edges.append((i, j))
|
89 |
+
G = nx.Graph()
|
90 |
+
for i, answer in enumerate(answers):
|
91 |
+
G.add_node(i, content=answer)
|
92 |
+
G.add_edges_from(edges)
|
93 |
+
partition = algorithms.louvain(G)
|
94 |
+
max_len_comm = np.max([len(x) for x in partition.communities])
|
95 |
+
best_comms = []
|
96 |
+
for comm in partition.communities:
|
97 |
+
# print([answers[i] for i in comm])
|
98 |
+
if len(comm) == max_len_comm:
|
99 |
+
best_comms.append([answers[i] for i in comm])
|
100 |
+
# if len(best_comms) > 1:
|
101 |
+
# return best_answer
|
102 |
+
for comm in best_comms:
|
103 |
+
if best_answer in comm:
|
104 |
+
return best_answer
|
105 |
+
mid = len(best_comms[0]) // 2
|
106 |
+
# print(mid, sorted(best_comms[0], key = len))
|
107 |
+
return sorted(best_comms[0], key=len)[mid]
|
108 |
+
except Exception as e:
|
109 |
+
print(e, "Disconnected graph")
|
110 |
+
return best_answer
|
features/text_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from glob import glob
|
3 |
+
import re
|
4 |
+
import nltk
|
5 |
+
nltk.download('punkt')
|
6 |
+
from nltk import word_tokenize as lib_tokenizer
|
7 |
+
import string
|
8 |
+
|
9 |
+
|
10 |
+
def preprocess(x, max_length=-1, remove_puncts=False):
|
11 |
+
x = nltk_tokenize(x)
|
12 |
+
x = x.replace("\n", " ")
|
13 |
+
if remove_puncts:
|
14 |
+
x = "".join([i for i in x if i not in string.punctuation])
|
15 |
+
if max_length > 0:
|
16 |
+
x = " ".join(x.split()[:max_length])
|
17 |
+
return x
|
18 |
+
|
19 |
+
|
20 |
+
def nltk_tokenize(x):
|
21 |
+
return " ".join(word_tokenize(strip_context(x))).strip()
|
22 |
+
|
23 |
+
|
24 |
+
def post_process_answer(x, entity_dict):
|
25 |
+
if type(x) is not str:
|
26 |
+
return x
|
27 |
+
try:
|
28 |
+
x = strip_answer_string(x)
|
29 |
+
except:
|
30 |
+
return "NaN"
|
31 |
+
x = "".join([c for c in x if c not in string.punctuation])
|
32 |
+
x = " ".join(x.split())
|
33 |
+
y = x.lower()
|
34 |
+
if len(y) > 1 and y.split()[0].isnumeric() and ("tháng" not in x):
|
35 |
+
return y.split()[0]
|
36 |
+
if not (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x):
|
37 |
+
if len(x.split()) <= 2:
|
38 |
+
return entity_dict.get(x.lower(), x)
|
39 |
+
else:
|
40 |
+
return x
|
41 |
+
else:
|
42 |
+
return y
|
43 |
+
|
44 |
+
|
45 |
+
dict_map = dict({})
|
46 |
+
|
47 |
+
|
48 |
+
def word_tokenize(text):
|
49 |
+
global dict_map
|
50 |
+
words = text.split()
|
51 |
+
words_norm = []
|
52 |
+
for w in words:
|
53 |
+
if dict_map.get(w, None) is None:
|
54 |
+
dict_map[w] = ' '.join(lib_tokenizer(w)).replace('``', '"').replace("''", '"')
|
55 |
+
words_norm.append(dict_map[w])
|
56 |
+
return words_norm
|
57 |
+
|
58 |
+
|
59 |
+
def strip_answer_string(text):
|
60 |
+
text = text.strip()
|
61 |
+
while text[-1] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
|
62 |
+
if text[0] != '(' and text[-1] == ')' and '(' in text:
|
63 |
+
break
|
64 |
+
if text[-1] == '"' and text[0] != '"' and text.count('"') > 1:
|
65 |
+
break
|
66 |
+
text = text[:-1].strip()
|
67 |
+
while text[0] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
|
68 |
+
if text[0] == '"' and text[-1] != '"' and text.count('"') > 1:
|
69 |
+
break
|
70 |
+
text = text[1:].strip()
|
71 |
+
text = text.strip()
|
72 |
+
return text
|
73 |
+
|
74 |
+
|
75 |
+
def strip_context(text):
|
76 |
+
text = text.replace('\n', ' ')
|
77 |
+
text = re.sub(r'\s+', ' ', text)
|
78 |
+
text = text.strip()
|
79 |
+
return text
|
80 |
+
|
81 |
+
|
82 |
+
def check_number(x):
|
83 |
+
x = str(x).lower()
|
84 |
+
return (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x)
|
models/bm25_utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from tqdm.auto import tqdm
|
3 |
+
|
4 |
+
tqdm.pandas()
|
5 |
+
from gensim.corpora import Dictionary
|
6 |
+
from gensim.models import TfidfModel
|
7 |
+
from gensim.similarities import SparseMatrixSimilarity
|
8 |
+
from src.features.text_utils import preprocess
|
9 |
+
|
10 |
+
|
11 |
+
class BM25Gensim:
|
12 |
+
def __init__(self, checkpoint_path, entity_dict, title2idx):
|
13 |
+
self.dictionary = Dictionary.load(checkpoint_path + "/dict")
|
14 |
+
self.tfidf_model = SparseMatrixSimilarity.load(checkpoint_path + "/tfidf")
|
15 |
+
self.bm25_index = TfidfModel.load(checkpoint_path + "/bm25_index")
|
16 |
+
self.title2idx = title2idx
|
17 |
+
self.entity_dict = entity_dict
|
18 |
+
|
19 |
+
def get_topk_stage1(self, query, topk=100):
|
20 |
+
tokenized_query = query.split()
|
21 |
+
tfidf_query = self.tfidf_model[self.dictionary.doc2bow(tokenized_query)]
|
22 |
+
scores = self.bm25_index[tfidf_query]
|
23 |
+
top_n = np.argsort(scores)[::-1][:topk]
|
24 |
+
return top_n, scores[top_n]
|
25 |
+
|
26 |
+
def get_topk_stage2(self, x, raw_answer=None, topk=50):
|
27 |
+
x = str(x)
|
28 |
+
query = preprocess(x, max_length=128).lower().split()
|
29 |
+
tfidf_query = self.tfidf_model[self.dictionary.doc2bow(query)]
|
30 |
+
scores = self.bm25_index[tfidf_query]
|
31 |
+
top_n = list(np.argsort(scores)[::-1][:topk])
|
32 |
+
if raw_answer is not None:
|
33 |
+
raw_answer = raw_answer.strip()
|
34 |
+
if raw_answer in self.entity_dict:
|
35 |
+
title = self.entity_dict[raw_answer].replace("wiki/", "").replace("_", " ")
|
36 |
+
extra_id = self.title2idx.get(title, -1)
|
37 |
+
if extra_id != -1 and extra_id not in top_n:
|
38 |
+
top_n.append(extra_id)
|
39 |
+
scores = scores[top_n]
|
40 |
+
return np.array(top_n), np.array(scores)
|
models/pairwise_model.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
from transformers import AutoModel, AutoConfig
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
AUTH_TOKEN = "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"
|
10 |
+
|
11 |
+
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
12 |
+
use_auth_token=AUTH_TOKEN)
|
13 |
+
pad_token_id = tokenizer.pad_token_id
|
14 |
+
|
15 |
+
|
16 |
+
class PairwiseModel(nn.Module):
|
17 |
+
def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
|
18 |
+
super(PairwiseModel, self).__init__()
|
19 |
+
self.max_length = max_length
|
20 |
+
self.batch_size = batch_size
|
21 |
+
self.device = device
|
22 |
+
self.model = AutoModel.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
|
23 |
+
self.model.to(self.device)
|
24 |
+
self.model.eval()
|
25 |
+
self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
|
26 |
+
self.fc = nn.Linear(768, 1).to(self.device)
|
27 |
+
|
28 |
+
def forward(self, ids, masks):
|
29 |
+
out = self.model(input_ids=ids,
|
30 |
+
attention_mask=masks,
|
31 |
+
output_hidden_states=False).last_hidden_state
|
32 |
+
out = out[:, 0]
|
33 |
+
outputs = self.fc(out)
|
34 |
+
return outputs
|
35 |
+
|
36 |
+
def stage1_ranking(self, question, texts):
|
37 |
+
tmp = pd.DataFrame()
|
38 |
+
tmp["text"] = [" ".join(x.split()) for x in texts]
|
39 |
+
tmp["question"] = question
|
40 |
+
valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
|
41 |
+
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
42 |
+
num_workers=0, shuffle=False, pin_memory=True)
|
43 |
+
preds = []
|
44 |
+
with torch.no_grad():
|
45 |
+
bar = enumerate(valid_loader)
|
46 |
+
for step, data in bar:
|
47 |
+
ids = data["ids"].to(self.device)
|
48 |
+
masks = data["masks"].to(self.device)
|
49 |
+
preds.append(torch.sigmoid(self(ids, masks)).view(-1))
|
50 |
+
preds = torch.concat(preds)
|
51 |
+
return preds.cpu().numpy()
|
52 |
+
|
53 |
+
def stage2_ranking(self, question, answer, titles, texts):
|
54 |
+
tmp = pd.DataFrame()
|
55 |
+
tmp["candidate"] = texts
|
56 |
+
tmp["question"] = question
|
57 |
+
tmp["answer"] = answer
|
58 |
+
tmp["title"] = titles
|
59 |
+
valid_dataset = SiameseDatasetStage2(tmp, tokenizer, self.max_length, is_test=True)
|
60 |
+
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
61 |
+
num_workers=0, shuffle=False, pin_memory=True)
|
62 |
+
preds = []
|
63 |
+
with torch.no_grad():
|
64 |
+
bar = enumerate(valid_loader)
|
65 |
+
for step, data in bar:
|
66 |
+
ids = data["ids"].to(self.device)
|
67 |
+
masks = data["masks"].to(self.device)
|
68 |
+
preds.append(torch.sigmoid(self(ids, masks)).view(-1))
|
69 |
+
preds = torch.concat(preds)
|
70 |
+
return preds.cpu().numpy()
|
71 |
+
|
72 |
+
|
73 |
+
class SiameseDatasetStage1(Dataset):
|
74 |
+
|
75 |
+
def __init__(self, df, tokenizer, max_length, is_test=False):
|
76 |
+
self.df = df
|
77 |
+
self.max_length = max_length
|
78 |
+
self.tokenizer = tokenizer
|
79 |
+
self.is_test = is_test
|
80 |
+
self.content1 = tokenizer.batch_encode_plus(list(df.question.values), max_length=max_length, truncation=True)[
|
81 |
+
"input_ids"]
|
82 |
+
self.content2 = tokenizer.batch_encode_plus(list(df.text.values), max_length=max_length, truncation=True)[
|
83 |
+
"input_ids"]
|
84 |
+
if not self.is_test:
|
85 |
+
self.targets = self.df.label
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return len(self.df)
|
89 |
+
|
90 |
+
def __getitem__(self, index):
|
91 |
+
return {
|
92 |
+
'ids1': torch.tensor(self.content1[index], dtype=torch.long),
|
93 |
+
'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
|
94 |
+
'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
class SiameseDatasetStage2(Dataset):
|
99 |
+
|
100 |
+
def __init__(self, df, tokenizer, max_length, is_test=False):
|
101 |
+
self.df = df
|
102 |
+
self.max_length = max_length
|
103 |
+
self.tokenizer = tokenizer
|
104 |
+
self.is_test = is_test
|
105 |
+
self.df["content1"] = self.df.apply(lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1)
|
106 |
+
self.df["content2"] = self.df.apply(lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1)
|
107 |
+
self.content1 = tokenizer.batch_encode_plus(list(df.content1.values), max_length=max_length, truncation=True)[
|
108 |
+
"input_ids"]
|
109 |
+
self.content2 = tokenizer.batch_encode_plus(list(df.content2.values), max_length=max_length, truncation=True)[
|
110 |
+
"input_ids"]
|
111 |
+
if not self.is_test:
|
112 |
+
self.targets = self.df.label
|
113 |
+
|
114 |
+
def __len__(self):
|
115 |
+
return len(self.df)
|
116 |
+
|
117 |
+
def __getitem__(self, index):
|
118 |
+
return {
|
119 |
+
'ids1': torch.tensor(self.content1[index], dtype=torch.long),
|
120 |
+
'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
|
121 |
+
'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
|
122 |
+
}
|
123 |
+
|
124 |
+
|
125 |
+
def collate_fn(batch):
|
126 |
+
ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
|
127 |
+
targets = [x["target"] for x in batch]
|
128 |
+
max_len = np.max([len(x) for x in ids])
|
129 |
+
masks = []
|
130 |
+
for i in range(len(ids)):
|
131 |
+
if len(ids[i]) < max_len:
|
132 |
+
ids[i] = torch.cat((ids[i], torch.tensor([pad_token_id, ] * (max_len - len(ids[i])), dtype=torch.long)))
|
133 |
+
masks.append(ids[i] != pad_token_id)
|
134 |
+
# print(tokenizer.decode(ids[0]))
|
135 |
+
outputs = {
|
136 |
+
"ids": torch.vstack(ids),
|
137 |
+
"masks": torch.vstack(masks),
|
138 |
+
"target": torch.vstack(targets).view(-1)
|
139 |
+
}
|
140 |
+
return outputs
|
models/predict_model.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.models.pairwise_model import *
|
2 |
+
from src.features.text_utils import *
|
3 |
+
import regex as re
|
4 |
+
from src.models.bm25_utils import BM25Gensim
|
5 |
+
from src.models.qa_model import *
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
tqdm.pandas()
|
8 |
+
from datasets import load_dataset
|
9 |
+
|
10 |
+
df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
|
11 |
+
df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
|
12 |
+
df_wiki.title = df_wiki.title.apply(str)
|
13 |
+
|
14 |
+
entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
|
15 |
+
new_dict = dict()
|
16 |
+
for key, val in entity_dict.items():
|
17 |
+
val = val[0].replace("wiki/", "").replace("_", " ")
|
18 |
+
entity_dict[key] = val
|
19 |
+
key = preprocess(key)
|
20 |
+
new_dict[key.lower()] = val
|
21 |
+
entity_dict.update(new_dict)
|
22 |
+
title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
|
23 |
+
|
24 |
+
qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["models/qa_model_robust.bin"], entity_dict)
|
25 |
+
pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
|
26 |
+
pairwise_model_stage1.load_state_dict(torch.load("models/pairwise_v2.bin", map_location=torch.device('cpu')))
|
27 |
+
pairwise_model_stage1.eval()
|
28 |
+
|
29 |
+
pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
|
30 |
+
pairwise_model_stage2.load_state_dict(torch.load("models/pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
|
31 |
+
|
32 |
+
bm25_model_stage1 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
|
33 |
+
bm25_model_stage2_full = BM25Gensim("models/bm25_stage2/full_text/", entity_dict, title2idx)
|
34 |
+
bm25_model_stage2_title = BM25Gensim("models/bm25_stage2/title/", entity_dict, title2idx)
|
35 |
+
|
36 |
+
def get_answer_e2e(question):
|
37 |
+
#Bm25 retrieval for top200 candidates
|
38 |
+
query = preprocess(question).lower()
|
39 |
+
top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
|
40 |
+
titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
|
41 |
+
texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
|
42 |
+
|
43 |
+
#Reranking with pairwise model for top10
|
44 |
+
question = preprocess(question)
|
45 |
+
ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
|
46 |
+
ranking_scores = ranking_preds * bm25_scores
|
47 |
+
|
48 |
+
#Question answering
|
49 |
+
best_idxs = np.argsort(ranking_scores)[-10:]
|
50 |
+
ranking_scores = np.array(ranking_scores)[best_idxs]
|
51 |
+
texts = np.array(texts)[best_idxs]
|
52 |
+
best_answer = qa_model(question, texts, ranking_scores)
|
53 |
+
if best_answer is None:
|
54 |
+
return "Chịu"
|
55 |
+
bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
|
56 |
+
|
57 |
+
#Entity mapping
|
58 |
+
if not check_number(bm25_answer):
|
59 |
+
bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
|
60 |
+
bm25_question_answer = bm25_question + " " + bm25_answer
|
61 |
+
candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
|
62 |
+
titles = [df_wiki.title.values[i] for i in candidates]
|
63 |
+
texts = [df_wiki.text.values[i] for i in candidates]
|
64 |
+
ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
|
65 |
+
if ranking_preds.max() >= 0.1:
|
66 |
+
final_answer = titles[ranking_preds.argmax()]
|
67 |
+
else:
|
68 |
+
candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
|
69 |
+
titles = [df_wiki.title.values[i] for i in candidates] + titles
|
70 |
+
texts = [df_wiki.text.values[i] for i in candidates] + texts
|
71 |
+
ranking_preds = np.concatenate(
|
72 |
+
[pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
|
73 |
+
final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
|
74 |
+
else:
|
75 |
+
final_answer = bm25_answer.lower()
|
76 |
+
return final_answer
|
models/qa_model.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from transformers import AutoModelForQuestionAnswering, pipeline
|
5 |
+
from src.features.text_utils import post_process_answer
|
6 |
+
from src.features.graph_utils import find_best_cluster
|
7 |
+
|
8 |
+
|
9 |
+
class QAEnsembleModel(nn.Module):
|
10 |
+
|
11 |
+
def __init__(self, model_name, model_checkpoints, entity_dict,
|
12 |
+
thr=0.1, device="cpu"):
|
13 |
+
super(QAEnsembleModel, self).__init__()
|
14 |
+
self.nlps = []
|
15 |
+
for model_checkpoint in model_checkpoints:
|
16 |
+
model = AutoModelForQuestionAnswering.from_pretrained(model_name)#.half()
|
17 |
+
model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
|
18 |
+
nlp = pipeline('question-answering', model=model,
|
19 |
+
tokenizer=model_name, device=device)
|
20 |
+
self.nlps.append(nlp)
|
21 |
+
self.entity_dict = entity_dict
|
22 |
+
self.thr = thr
|
23 |
+
|
24 |
+
def forward(self, question, texts, ranking_scores=None):
|
25 |
+
if ranking_scores is None:
|
26 |
+
ranking_scores = np.ones((len(texts),))
|
27 |
+
|
28 |
+
curr_answers = []
|
29 |
+
curr_scores = []
|
30 |
+
best_score = 0
|
31 |
+
for i, nlp in enumerate(self.nlps):
|
32 |
+
for text, score in zip(texts, ranking_scores):
|
33 |
+
QA_input = {
|
34 |
+
'question': question,
|
35 |
+
'context': text
|
36 |
+
}
|
37 |
+
res = nlp(QA_input)
|
38 |
+
# print(res)
|
39 |
+
if res["score"] > self.thr:
|
40 |
+
curr_answers.append(res["answer"])
|
41 |
+
curr_scores.append(res["score"])
|
42 |
+
res["score"] = res["score"] * score
|
43 |
+
if i == 0:
|
44 |
+
if res["score"] > best_score:
|
45 |
+
answer = res["answer"]
|
46 |
+
best_score = res["score"]
|
47 |
+
if len(curr_answers) == 0:
|
48 |
+
return None
|
49 |
+
curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
|
50 |
+
answer = post_process_answer(answer, self.entity_dict)
|
51 |
+
new_best_answer = post_process_answer(find_best_cluster(curr_answers, answer), self.entity_dict)
|
52 |
+
return new_best_answer
|