foxxy-hm committed on
Commit
aa8b9d8
1 Parent(s): 42f4f0b

Upload 6 files

features/graph_utils.py ADDED
@@ -0,0 +1,110 @@
+ import networkx as nx
+ import numpy as np
+ from cdlib import algorithms
+
+
+ # These functions are heavily influenced by the HF squad_metrics.py script.
+ def normalize_text(s):
+     """Lowercase, remove articles and punctuation, and standardize whitespace."""
+     import string, re
+
+     def remove_articles(text):
+         regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+         return re.sub(regex, " ", text)
+
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def compute_exact_match(prediction, truth):
+     return int(normalize_text(prediction) == normalize_text(truth))
+
+
+ def compute_f1(prediction, truth):
+     """Token-level F1 between two normalized strings."""
+     pred_tokens = normalize_text(prediction).split()
+     truth_tokens = normalize_text(truth).split()
+
+     # If either the prediction or the truth is no-answer, then f1 = 1 if they agree, 0 otherwise.
+     if len(pred_tokens) == 0 or len(truth_tokens) == 0:
+         return int(pred_tokens == truth_tokens)
+
+     common_tokens = set(pred_tokens) & set(truth_tokens)
+
+     # If there are no common tokens, then f1 = 0.
+     if len(common_tokens) == 0:
+         return 0
+
+     prec = len(common_tokens) / len(pred_tokens)
+     rec = len(common_tokens) / len(truth_tokens)
+
+     return 2 * (prec * rec) / (prec + rec)
+
+
+ def is_date_or_num(answer):
+     """Heuristic: True if any token is numeric or a Vietnamese date word (day/month/year)."""
+     answer = answer.lower().split()
+     for w in answer:
+         w = w.strip()
+         if w.isnumeric() or w in ["ngày", "tháng", "năm"]:
+             return True
+     return False
+
+
+ def find_best_cluster(answers, best_answer, thr=0.79):
+     """Cluster near-duplicate answers and return a representative of the largest cluster."""
+     if len(answers) == 0:
+         return best_answer
+     elif len(answers) == 1:
+         return answers[0]
+     # Pairwise duplicate matrix: dists[i, j] = 1 when two answers are considered the same.
+     dists = np.zeros((len(answers), len(answers)))
+     for i in range(len(answers) - 1):
+         for j in range(i + 1, len(answers)):
+             a1 = answers[i].lower().strip()
+             a2 = answers[j].lower().strip()
+             if is_date_or_num(a1) or is_date_or_num(a2):
+                 # Dates and numbers must match exactly, or be month-qualified substrings of each other.
+                 if a1 == a2 or ("tháng" in a1 and a1 in a2) or ("tháng" in a2 and a2 in a1):
+                     dists[i, j] = 1
+                     dists[j, i] = 1
+             elif a1 == a2 or (a1 in a2) or (a2 in a1) or compute_f1(a1, a2) >= thr:
+                 dists[i, j] = 1
+                 dists[j, i] = 1
+     try:
+         # Turn duplicate pairs into graph edges and detect communities with Louvain.
+         dups = np.where(dists >= 1)
+         edges = []
+         for i, j in zip(dups[0], dups[1]):
+             if i != j:
+                 edges.append((i, j))
+         G = nx.Graph()
+         for i, answer in enumerate(answers):
+             G.add_node(i, content=answer)
+         G.add_edges_from(edges)
+         partition = algorithms.louvain(G)
+         # Keep the largest community (or communities, on ties).
+         max_len_comm = np.max([len(x) for x in partition.communities])
+         best_comms = []
+         for comm in partition.communities:
+             if len(comm) == max_len_comm:
+                 best_comms.append([answers[i] for i in comm])
+         for comm in best_comms:
+             if best_answer in comm:
+                 return best_answer
+         # Otherwise return the median-length member of the first largest community.
+         mid = len(best_comms[0]) // 2
+         return sorted(best_comms[0], key=len)[mid]
+     except Exception as e:
+         print(e, "Disconnected graph")
+         return best_answer
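
For reference, a minimal sketch of how the two helpers behave on toy inputs (the strings are made up, and the import path assumes the repo root is on PYTHONPATH):

    from src.features.graph_utils import compute_f1, find_best_cluster

    # 2 shared tokens out of 2 and 4: precision 1.0, recall 0.5, F1 ≈ 0.67.
    print(compute_f1("Hà Nội", "thành phố Hà Nội"))

    # The three "Hà Nội" variants form the largest community; "Sài Gòn" stays isolated.
    answers = ["Hà Nội", "hà nội", "thành phố Hà Nội", "Sài Gòn"]
    print(find_best_cluster(answers, "Hà Nội"))  # "Hà Nội" (also the fallback if Louvain fails)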
features/text_utils.py ADDED
@@ -0,0 +1,84 @@
+ import re
+ import string
+
+ import nltk
+ nltk.download('punkt')
+ from nltk import word_tokenize as lib_tokenizer
+
+
+ def preprocess(x, max_length=-1, remove_puncts=False):
+     """Tokenize, flatten whitespace, optionally drop punctuation and truncate to max_length words."""
+     x = nltk_tokenize(x)
+     x = x.replace("\n", " ")
+     if remove_puncts:
+         x = "".join([i for i in x if i not in string.punctuation])
+     if max_length > 0:
+         x = " ".join(x.split()[:max_length])
+     return x
+
+
+ def nltk_tokenize(x):
+     return " ".join(word_tokenize(strip_context(x))).strip()
+
+
+ def post_process_answer(x, entity_dict):
+     """Normalize an answer string; short non-numeric answers are mapped through entity_dict."""
+     if type(x) is not str:
+         return x
+     try:
+         x = strip_answer_string(x)
+     except Exception:
+         return "NaN"
+     x = "".join([c for c in x if c not in string.punctuation])
+     x = " ".join(x.split())
+     y = x.lower()
+     # A leading bare number (outside a "tháng"/month expression) is returned on its own.
+     if len(y) > 1 and y.split()[0].isnumeric() and ("tháng" not in x):
+         return y.split()[0]
+     if not (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x):
+         if len(x.split()) <= 2:
+             return entity_dict.get(x.lower(), x)
+         else:
+             return x
+     else:
+         return y
+
+
+ # Cache of word -> normalized tokenization, shared across calls to word_tokenize.
+ dict_map = dict({})
+
+
+ def word_tokenize(text):
+     global dict_map
+     words = text.split()
+     words_norm = []
+     for w in words:
+         if dict_map.get(w, None) is None:
+             dict_map[w] = ' '.join(lib_tokenizer(w)).replace('``', '"').replace("''", '"')
+         words_norm.append(dict_map[w])
+     return words_norm
+
+
+ def strip_answer_string(text):
+     """Trim leading/trailing punctuation while keeping balanced quotes and parentheses."""
+     text = text.strip()
+     while text and text[-1] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
+         if text[0] != '(' and text[-1] == ')' and '(' in text:
+             break
+         if text[-1] == '"' and text[0] != '"' and text.count('"') > 1:
+             break
+         text = text[:-1].strip()
+     while text and text[0] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
+         if text[0] == '"' and text[-1] != '"' and text.count('"') > 1:
+             break
+         text = text[1:].strip()
+     return text.strip()
+
+
+ def strip_context(text):
+     text = text.replace('\n', ' ')
+     text = re.sub(r'\s+', ' ', text)
+     text = text.strip()
+     return text
+
+
+ def check_number(x):
+     """True if x is numeric or contains a Vietnamese date word (ngày/tháng/năm = day/month/year)."""
+     x = str(x).lower()
+     return (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x)
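
A quick sanity check of the text helpers (illustrative strings only; the module downloads the punkt tokenizer on first import):

    from src.features.text_utils import preprocess, strip_answer_string, check_number

    print(preprocess("Hà   Nội\nlà thủ đô", max_length=2))  # "Hà Nội"
    print(strip_answer_string('"Hà Nội",'))                 # "Hà Nội" (quotes trimmed once unbalanced)
    print(check_number("năm 1945"))                         # True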
models/bm25_utils.py ADDED
@@ -0,0 +1,40 @@
+ import numpy as np
+ from tqdm.auto import tqdm
+
+ tqdm.pandas()
+ from gensim.corpora import Dictionary
+ from gensim.models import TfidfModel
+ from gensim.similarities import SparseMatrixSimilarity
+ from src.features.text_utils import preprocess
+
+
+ class BM25Gensim:
+     """BM25-style retrieval over a prebuilt gensim dictionary, TF-IDF model, and similarity index."""
+
+     def __init__(self, checkpoint_path, entity_dict, title2idx):
+         self.dictionary = Dictionary.load(checkpoint_path + "/dict")
+         self.tfidf_model = TfidfModel.load(checkpoint_path + "/tfidf")
+         self.bm25_index = SparseMatrixSimilarity.load(checkpoint_path + "/bm25_index")
+         self.title2idx = title2idx
+         self.entity_dict = entity_dict
+
+     def get_topk_stage1(self, query, topk=100):
+         """Return the indices and scores of the topk passages for a whitespace-tokenized query."""
+         tokenized_query = query.split()
+         tfidf_query = self.tfidf_model[self.dictionary.doc2bow(tokenized_query)]
+         scores = self.bm25_index[tfidf_query]
+         top_n = np.argsort(scores)[::-1][:topk]
+         return top_n, scores[top_n]
+
+     def get_topk_stage2(self, x, raw_answer=None, topk=50):
+         """As stage 1, but force the entity page linked to raw_answer into the candidate set."""
+         x = str(x)
+         query = preprocess(x, max_length=128).lower().split()
+         tfidf_query = self.tfidf_model[self.dictionary.doc2bow(query)]
+         scores = self.bm25_index[tfidf_query]
+         top_n = list(np.argsort(scores)[::-1][:topk])
+         if raw_answer is not None:
+             raw_answer = raw_answer.strip()
+             if raw_answer in self.entity_dict:
+                 title = self.entity_dict[raw_answer].replace("wiki/", "").replace("_", " ")
+                 extra_id = self.title2idx.get(title, -1)
+                 if extra_id != -1 and extra_id not in top_n:
+                     top_n.append(extra_id)
+         scores = scores[top_n]
+         return np.array(top_n), np.array(scores)
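
A usage sketch, assuming models/bm25_stage1/ contains the dict, tfidf, and bm25_index files produced by the (not shown) indexing step, with entity_dict and title2idx built as in predict_model.py:

    from src.models.bm25_utils import BM25Gensim

    bm25 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
    top_n, scores = bm25.get_topk_stage1("thủ đô của việt nam", topk=5)
    # top_n indexes rows of the passage dataframe; scores are the matching similarities.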
models/pairwise_model.py ADDED
@@ -0,0 +1,140 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoModel, AutoConfig
+ from transformers import AutoTokenizer
+ import pandas as pd
+
+ # NOTE: hardcoded access token; prefer reading it from an environment variable.
+ AUTH_TOKEN = "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"
+
+ tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
+                                           use_auth_token=AUTH_TOKEN)
+ pad_token_id = tokenizer.pad_token_id
+
+
+ class PairwiseModel(nn.Module):
+     """Cross-encoder reranker: scores a (question, passage) pair with a linear head on the [CLS] token."""
+
+     def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
+         super(PairwiseModel, self).__init__()
+         self.max_length = max_length
+         self.batch_size = batch_size
+         self.device = device
+         self.model = AutoModel.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
+         self.model.to(self.device)
+         self.model.eval()
+         self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
+         self.fc = nn.Linear(768, 1).to(self.device)
+
+     def forward(self, ids, masks):
+         out = self.model(input_ids=ids,
+                          attention_mask=masks,
+                          output_hidden_states=False).last_hidden_state
+         out = out[:, 0]  # [CLS] representation
+         outputs = self.fc(out)
+         return outputs
+
+     def stage1_ranking(self, question, texts):
+         """Score each passage in texts against the question; returns one sigmoid score per passage."""
+         tmp = pd.DataFrame()
+         tmp["text"] = [" ".join(x.split()) for x in texts]
+         tmp["question"] = question
+         valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
+         valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
+                                   num_workers=0, shuffle=False, pin_memory=True)
+         preds = []
+         with torch.no_grad():
+             for step, data in enumerate(valid_loader):
+                 ids = data["ids"].to(self.device)
+                 masks = data["masks"].to(self.device)
+                 preds.append(torch.sigmoid(self(ids, masks)).view(-1))
+         preds = torch.cat(preds)
+         return preds.cpu().numpy()
+
+     def stage2_ranking(self, question, answer, titles, texts):
+         """Score each (title, candidate) pair against (question, answer) for entity mapping."""
+         tmp = pd.DataFrame()
+         tmp["candidate"] = texts
+         tmp["question"] = question
+         tmp["answer"] = answer
+         tmp["title"] = titles
+         valid_dataset = SiameseDatasetStage2(tmp, tokenizer, self.max_length, is_test=True)
+         valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
+                                   num_workers=0, shuffle=False, pin_memory=True)
+         preds = []
+         with torch.no_grad():
+             for step, data in enumerate(valid_loader):
+                 ids = data["ids"].to(self.device)
+                 masks = data["masks"].to(self.device)
+                 preds.append(torch.sigmoid(self(ids, masks)).view(-1))
+         preds = torch.cat(preds)
+         return preds.cpu().numpy()
+
+
+ class SiameseDatasetStage1(Dataset):
+     """Pairs a question with a passage; the two token sequences are joined in collate_fn."""
+
+     def __init__(self, df, tokenizer, max_length, is_test=False):
+         self.df = df
+         self.max_length = max_length
+         self.tokenizer = tokenizer
+         self.is_test = is_test
+         self.content1 = tokenizer.batch_encode_plus(list(df.question.values), max_length=max_length, truncation=True)[
+             "input_ids"]
+         self.content2 = tokenizer.batch_encode_plus(list(df.text.values), max_length=max_length, truncation=True)[
+             "input_ids"]
+         if not self.is_test:
+             self.targets = self.df.label
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         return {
+             'ids1': torch.tensor(self.content1[index], dtype=torch.long),
+             # Drop the duplicated BOS/CLS token from the second sequence.
+             'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
+             'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
+         }
+
+
+ class SiameseDatasetStage2(Dataset):
+     """Pairs question + answer with title + candidate passage, joined by the tokenizer's sep token."""
+
+     def __init__(self, df, tokenizer, max_length, is_test=False):
+         self.df = df
+         self.max_length = max_length
+         self.tokenizer = tokenizer
+         self.is_test = is_test
+         self.df["content1"] = self.df.apply(lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1)
+         self.df["content2"] = self.df.apply(lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1)
+         self.content1 = tokenizer.batch_encode_plus(list(df.content1.values), max_length=max_length, truncation=True)[
+             "input_ids"]
+         self.content2 = tokenizer.batch_encode_plus(list(df.content2.values), max_length=max_length, truncation=True)[
+             "input_ids"]
+         if not self.is_test:
+             self.targets = self.df.label
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         return {
+             'ids1': torch.tensor(self.content1[index], dtype=torch.long),
+             'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
+             'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
+         }
+
+
+ def collate_fn(batch):
+     """Concatenate the two sequences of each pair and right-pad the batch to its longest example."""
+     ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
+     targets = [x["target"] for x in batch]
+     max_len = np.max([len(x) for x in ids])
+     masks = []
+     for i in range(len(ids)):
+         if len(ids[i]) < max_len:
+             ids[i] = torch.cat((ids[i], torch.tensor([pad_token_id, ] * (max_len - len(ids[i])), dtype=torch.long)))
+         masks.append(ids[i] != pad_token_id)
+     outputs = {
+         "ids": torch.vstack(ids),
+         "masks": torch.vstack(masks),
+         "target": torch.vstack(targets).view(-1)
+     }
+     return outputs
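
A minimal scoring sketch, assuming the pairwise_v2.bin checkpoint referenced in predict_model.py is available (the candidate passages are made up):

    import torch
    from src.models.pairwise_model import PairwiseModel

    model = PairwiseModel("nguyenvulebinh/vi-mrc-base")
    model.load_state_dict(torch.load("models/pairwise_v2.bin", map_location="cpu"))
    model.eval()
    scores = model.stage1_ranking("Thủ đô của Việt Nam là gì?",
                                  ["Hà Nội là thủ đô của Việt Nam.",
                                   "Phở là một món ăn nổi tiếng."])
    # One sigmoid score per passage; higher means more relevant to the question.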
models/predict_model.py ADDED
@@ -0,0 +1,76 @@
+ from src.models.pairwise_model import *
+ from src.features.text_utils import *
+ from src.models.bm25_utils import BM25Gensim
+ from src.models.qa_model import *
+ from tqdm.auto import tqdm
+ tqdm.pandas()
+ from datasets import load_dataset
+
+ # Wikipedia data: sliding-window passages for QA, short per-title texts for entity mapping.
+ df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
+ df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
+ df_wiki.title = df_wiki.title.apply(str)
+
+ # Entity dictionary: surface form -> wiki title; also index preprocessed, lowercased keys.
+ entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
+ new_dict = dict()
+ for key, val in entity_dict.items():
+     val = val[0].replace("wiki/", "").replace("_", " ")
+     entity_dict[key] = val
+     key = preprocess(key)
+     new_dict[key.lower()] = val
+ entity_dict.update(new_dict)
+ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
+
+ # Models: extractive QA ensemble, two pairwise rerankers, and BM25 indexes for both stages.
+ qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["models/qa_model_robust.bin"], entity_dict)
+ pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")  # .half()
+ pairwise_model_stage1.load_state_dict(torch.load("models/pairwise_v2.bin", map_location=torch.device('cpu')))
+ pairwise_model_stage1.eval()
+
+ pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")  # .half()
+ pairwise_model_stage2.load_state_dict(torch.load("models/pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
+ pairwise_model_stage2.eval()
+
+ bm25_model_stage1 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
+ bm25_model_stage2_full = BM25Gensim("models/bm25_stage2/full_text/", entity_dict, title2idx)
+ bm25_model_stage2_title = BM25Gensim("models/bm25_stage2/title/", entity_dict, title2idx)
+
+
+ def get_answer_e2e(question):
+     # BM25 retrieval for the top-200 candidate passages.
+     query = preprocess(question).lower()
+     top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
+     titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
+     texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
+
+     # Rerank with the pairwise model; combined score = reranker probability * BM25 score.
+     question = preprocess(question)
+     ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
+     ranking_scores = ranking_preds * bm25_scores
+
+     # Extractive question answering over the 10 best passages.
+     best_idxs = np.argsort(ranking_scores)[-10:]
+     ranking_scores = np.array(ranking_scores)[best_idxs]
+     texts = np.array(texts)[best_idxs]
+     best_answer = qa_model(question, texts, ranking_scores)
+     if best_answer is None:
+         return "Chịu"  # Vietnamese for "I give up": no confident answer found.
+     bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
+
+     # Entity mapping: map non-numeric answers onto a Wikipedia title.
+     if not check_number(bm25_answer):
+         bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
+         bm25_question_answer = bm25_question + " " + bm25_answer
+         candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
+         titles = [df_wiki.title.values[i] for i in candidates]
+         texts = [df_wiki.text.values[i] for i in candidates]
+         ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
+         if ranking_preds.max() < 0.1:
+             # Low confidence: add full-text BM25 candidates and rerank the combined pool.
+             candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
+             titles = [df_wiki.title.values[i] for i in candidates] + titles
+             texts = [df_wiki.text.values[i] for i in candidates] + texts
+             ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
+         final_answer = "wiki/" + titles[ranking_preds.argmax()].replace(" ", "_")
+     else:
+         final_answer = bm25_answer.lower()
+     return final_answer
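
With all checkpoints and data files in place, the whole pipeline then reduces to a single call:

    print(get_answer_e2e("Thủ đô của Việt Nam là thành phố nào?"))
    # Numeric/date answers come back as plain strings; entity answers as "wiki/Title_Name";
    # "Chịu" signals that no confident answer was found.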
models/qa_model.py ADDED
@@ -0,0 +1,52 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModelForQuestionAnswering, pipeline
+ from src.features.text_utils import post_process_answer
+ from src.features.graph_utils import find_best_cluster
+
+
+ class QAEnsembleModel(nn.Module):
+     """Ensemble of extractive QA checkpoints; candidate answers are clustered to pick a winner."""
+
+     def __init__(self, model_name, model_checkpoints, entity_dict,
+                  thr=0.1, device="cpu"):
+         super(QAEnsembleModel, self).__init__()
+         self.nlps = []
+         for model_checkpoint in model_checkpoints:
+             model = AutoModelForQuestionAnswering.from_pretrained(model_name)  # .half()
+             model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
+             nlp = pipeline('question-answering', model=model,
+                            tokenizer=model_name, device=device)
+             self.nlps.append(nlp)
+         self.entity_dict = entity_dict
+         self.thr = thr
+
+     def forward(self, question, texts, ranking_scores=None):
+         if ranking_scores is None:
+             ranking_scores = np.ones((len(texts),))
+
+         curr_answers = []
+         curr_scores = []
+         best_score = 0
+         answer = None
+         for i, nlp in enumerate(self.nlps):
+             for text, score in zip(texts, ranking_scores):
+                 QA_input = {
+                     'question': question,
+                     'context': text
+                 }
+                 res = nlp(QA_input)
+                 if res["score"] > self.thr:
+                     # Keep every answer above the confidence threshold for clustering.
+                     curr_answers.append(res["answer"])
+                     curr_scores.append(res["score"])
+                     # Weight the QA confidence by the passage's ranking score; the first
+                     # checkpoint's best weighted answer becomes the reference answer.
+                     res["score"] = res["score"] * score
+                     if i == 0:
+                         if res["score"] > best_score:
+                             answer = res["answer"]
+                             best_score = res["score"]
+         if len(curr_answers) == 0:
+             return None
+         curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
+         answer = post_process_answer(answer, self.entity_dict)
+         new_best_answer = post_process_answer(find_best_cluster(curr_answers, answer), self.entity_dict)
+         return new_best_answer
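
Standalone, the ensemble can be exercised with a single checkpoint and toy contexts (entity_dict built as in predict_model.py):

    from src.models.qa_model import QAEnsembleModel

    qa = QAEnsembleModel("nguyenvulebinh/vi-mrc-large",
                         ["models/qa_model_robust.bin"], entity_dict)
    answer = qa("Thủ đô của Việt Nam là gì?",
                ["Hà Nội là thủ đô của Việt Nam.", "Việt Nam nằm ở Đông Nam Á."])
    # Returns None when no span clears thr=0.1; otherwise the clustered best answer.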