Spaces:
Runtime error
Runtime error
letrunglinh
commited on
Commit
•
fa01b79
1
Parent(s):
8d62285
Upload 15 files
Browse files- .gitattributes +3 -0
- app.py +59 -0
- bm25_utils.py +40 -0
- graph_utils.py +110 -0
- openvino_stage1/stage1.bin +3 -0
- openvino_stage1/stage1.mapping +0 -0
- openvino_stage1/stage1.xml +0 -0
- outputs/bm25_stage1/bm25_index +3 -0
- outputs/bm25_stage1/dict +0 -0
- outputs/bm25_stage1/tfidf +3 -0
- pairwise_model.py +139 -0
- processed/entities.json +0 -0
- processed/wikipedia_chungta_cleaned.csv +3 -0
- processed/wikipedia_chungta_short.csv +0 -0
- qa_model.py +65 -0
- text_utils.py +82 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
outputs/bm25_stage1/bm25_index filter=lfs diff=lfs merge=lfs -text
|
36 |
+
outputs/bm25_stage1/tfidf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
processed/wikipedia_chungta_cleaned.csv filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from text_utils import *
|
3 |
+
import pandas as pd
|
4 |
+
from qa_model import *
|
5 |
+
from bm25_utils import *
|
6 |
+
from pairwise_model import *
|
7 |
+
|
8 |
+
df_wiki_windows = pd.read_csv("./processed/wikipedia_chungta_cleaned.csv")
|
9 |
+
df_wiki = pd.read_csv("./processed/wikipedia_chungta_short.csv")
|
10 |
+
df_wiki.title = df_wiki.title.apply(str)
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
entity_dict = json.load(open("./processed/entities.json"))
|
15 |
+
new_dict = dict()
|
16 |
+
for key, val in entity_dict.items():
|
17 |
+
val = val.replace("wiki/", "").replace("_", " ")
|
18 |
+
entity_dict[key] = val
|
19 |
+
key = preprocess(key)
|
20 |
+
new_dict[key.lower()] = val
|
21 |
+
entity_dict.update(new_dict)
|
22 |
+
title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
|
23 |
+
|
24 |
+
|
25 |
+
qa_model = QAEnsembleModel_modify("letrunglinh/qa_pnc", entity_dict)
|
26 |
+
pairwise_model_stage1 = PairwiseModel_modify("nguyenvulebinh/vi-mrc-base")
|
27 |
+
|
28 |
+
bm25_model_stage1 = BM25Gensim("./outputs/bm25_stage1/", entity_dict, title2idx)
|
29 |
+
|
30 |
+
|
31 |
+
def get_answer_e2e(question):
|
32 |
+
#Bm25 retrieval for top200 candidates
|
33 |
+
query = preprocess(question).lower()
|
34 |
+
top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
|
35 |
+
titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
|
36 |
+
pre_texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
|
37 |
+
|
38 |
+
#Reranking with pairwise model for top10
|
39 |
+
question = preprocess(question)
|
40 |
+
ranking_preds = pairwise_model_stage1.stage1_ranking(question, pre_texts)
|
41 |
+
|
42 |
+
ranking_scores = ranking_preds * bm25_scores
|
43 |
+
|
44 |
+
#Question answering
|
45 |
+
best_idxs = np.argsort(ranking_scores)[-10:]
|
46 |
+
ranking_scores = np.array(ranking_scores)[best_idxs]
|
47 |
+
texts = np.array(pre_texts)[best_idxs]
|
48 |
+
|
49 |
+
best_answer = qa_model(question, texts, ranking_scores)
|
50 |
+
|
51 |
+
if best_answer is None:
|
52 |
+
return pre_texts[0]
|
53 |
+
|
54 |
+
return best_answer
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
# result = get_answer_e2e("OKR là gì?")
|
58 |
+
# print(result)
|
59 |
+
gr.Interface(fn=get_answer_e2e, inputs=["text"], outputs=["textbox"]).launch()
|
bm25_utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from tqdm.auto import tqdm
|
3 |
+
|
4 |
+
tqdm.pandas()
|
5 |
+
from gensim.corpora import Dictionary
|
6 |
+
from gensim.models import TfidfModel
|
7 |
+
from gensim.similarities import SparseMatrixSimilarity
|
8 |
+
from text_utils import preprocess
|
9 |
+
|
10 |
+
|
11 |
+
class BM25Gensim:
|
12 |
+
def __init__(self, checkpoint_path, entity_dict, title2idx):
|
13 |
+
self.dictionary = Dictionary.load(checkpoint_path + "/dict")
|
14 |
+
self.tfidf_model = SparseMatrixSimilarity.load(checkpoint_path + "/tfidf")
|
15 |
+
self.bm25_index = TfidfModel.load(checkpoint_path + "/bm25_index")
|
16 |
+
self.title2idx = title2idx
|
17 |
+
self.entity_dict = entity_dict
|
18 |
+
|
19 |
+
def get_topk_stage1(self, query, topk=100):
|
20 |
+
tokenized_query = query.split()
|
21 |
+
tfidf_query = self.tfidf_model[self.dictionary.doc2bow(tokenized_query)]
|
22 |
+
scores = self.bm25_index[tfidf_query]
|
23 |
+
top_n = np.argsort(scores)[::-1][:topk]
|
24 |
+
return top_n, scores[top_n]
|
25 |
+
|
26 |
+
def get_topk_stage2(self, x, raw_answer=None, topk=50):
|
27 |
+
x = str(x)
|
28 |
+
query = preprocess(x, max_length=128).lower().split()
|
29 |
+
tfidf_query = self.tfidf_model[self.dictionary.doc2bow(query)]
|
30 |
+
scores = self.bm25_index[tfidf_query]
|
31 |
+
top_n = list(np.argsort(scores)[::-1][:topk])
|
32 |
+
if raw_answer is not None:
|
33 |
+
raw_answer = raw_answer.strip()
|
34 |
+
if raw_answer in self.entity_dict:
|
35 |
+
title = self.entity_dict[raw_answer].replace("wiki/", "").replace("_", " ")
|
36 |
+
extra_id = self.title2idx.get(title, -1)
|
37 |
+
if extra_id != -1 and extra_id not in top_n:
|
38 |
+
top_n.append(extra_id)
|
39 |
+
scores = scores[top_n]
|
40 |
+
return np.array(top_n), np.array(scores)
|
graph_utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import networkx as nx
|
2 |
+
import numpy as np
|
3 |
+
from cdlib import algorithms
|
4 |
+
|
5 |
+
|
6 |
+
# these functions are heavily influenced by the HF squad_metrics.py script
|
7 |
+
def normalize_text(s):
|
8 |
+
"""Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
|
9 |
+
import string, re
|
10 |
+
|
11 |
+
def remove_articles(text):
|
12 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
13 |
+
return re.sub(regex, " ", text)
|
14 |
+
|
15 |
+
def white_space_fix(text):
|
16 |
+
return " ".join(text.split())
|
17 |
+
|
18 |
+
def remove_punc(text):
|
19 |
+
exclude = set(string.punctuation)
|
20 |
+
return "".join(ch for ch in text if ch not in exclude)
|
21 |
+
|
22 |
+
def lower(text):
|
23 |
+
return text.lower()
|
24 |
+
|
25 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
26 |
+
|
27 |
+
|
28 |
+
def compute_exact_match(prediction, truth):
|
29 |
+
return int(normalize_text(prediction) == normalize_text(truth))
|
30 |
+
|
31 |
+
|
32 |
+
def compute_f1(prediction, truth):
|
33 |
+
pred_tokens = normalize_text(prediction).split()
|
34 |
+
truth_tokens = normalize_text(truth).split()
|
35 |
+
|
36 |
+
# if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
|
37 |
+
if len(pred_tokens) == 0 or len(truth_tokens) == 0:
|
38 |
+
return int(pred_tokens == truth_tokens)
|
39 |
+
|
40 |
+
common_tokens = set(pred_tokens) & set(truth_tokens)
|
41 |
+
|
42 |
+
# if there are no common tokens then f1 = 0
|
43 |
+
if len(common_tokens) == 0:
|
44 |
+
return 0
|
45 |
+
|
46 |
+
prec = len(common_tokens) / len(pred_tokens)
|
47 |
+
rec = len(common_tokens) / len(truth_tokens)
|
48 |
+
|
49 |
+
return 2 * (prec * rec) / (prec + rec)
|
50 |
+
|
51 |
+
|
52 |
+
def is_date_or_num(answer):
|
53 |
+
answer = answer.lower().split()
|
54 |
+
for w in answer:
|
55 |
+
w = w.strip()
|
56 |
+
if w.isnumeric() or w in ["ngày", "tháng", "năm"]:
|
57 |
+
return True
|
58 |
+
return False
|
59 |
+
|
60 |
+
|
61 |
+
def find_best_cluster(answers, best_answer, thr=0.79):
|
62 |
+
if len(answers) == 0: # or best_answer not in answers:
|
63 |
+
return best_answer
|
64 |
+
elif len(answers) == 1:
|
65 |
+
return answers[0]
|
66 |
+
dists = np.zeros((len(answers), len(answers)))
|
67 |
+
for i in range(len(answers) - 1):
|
68 |
+
for j in range(i + 1, len(answers)):
|
69 |
+
a1 = answers[i].lower().strip()
|
70 |
+
a2 = answers[j].lower().strip()
|
71 |
+
if is_date_or_num(a1) or is_date_or_num(a2):
|
72 |
+
# print(a1, a2)
|
73 |
+
if a1 == a2 or ("tháng" in a1 and a1 in a2) or ("tháng" in a2 and a2 in a1):
|
74 |
+
dists[i, j] = 1
|
75 |
+
dists[j, i] = 1
|
76 |
+
# continue
|
77 |
+
elif a1 == a2 or (a1 in a2) or (a2 in a1) or compute_f1(a1.lower(), a2.lower()) >= thr:
|
78 |
+
dists[i, j] = 1
|
79 |
+
dists[j, i] = 1
|
80 |
+
# print(dists)
|
81 |
+
try:
|
82 |
+
thr = 1
|
83 |
+
dups = np.where(dists >= thr)
|
84 |
+
dup_strs = []
|
85 |
+
edges = []
|
86 |
+
for i, j in zip(dups[0], dups[1]):
|
87 |
+
if i != j:
|
88 |
+
edges.append((i, j))
|
89 |
+
G = nx.Graph()
|
90 |
+
for i, answer in enumerate(answers):
|
91 |
+
G.add_node(i, content=answer)
|
92 |
+
G.add_edges_from(edges)
|
93 |
+
partition = algorithms.louvain(G)
|
94 |
+
max_len_comm = np.max([len(x) for x in partition.communities])
|
95 |
+
best_comms = []
|
96 |
+
for comm in partition.communities:
|
97 |
+
# print([answers[i] for i in comm])
|
98 |
+
if len(comm) == max_len_comm:
|
99 |
+
best_comms.append([answers[i] for i in comm])
|
100 |
+
# if len(best_comms) > 1:
|
101 |
+
# return best_answer
|
102 |
+
for comm in best_comms:
|
103 |
+
if best_answer in comm:
|
104 |
+
return best_answer
|
105 |
+
mid = len(best_comms[0]) // 2
|
106 |
+
# print(mid, sorted(best_comms[0], key = len))
|
107 |
+
return sorted(best_comms[0], key=len)[mid]
|
108 |
+
except Exception as e:
|
109 |
+
print(e, "Disconnected graph")
|
110 |
+
return best_answer
|
openvino_stage1/stage1.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac91eb96b7cb06441c544d46f1a64b9f44c7bec4f629e91c1e28cee61570262a
|
3 |
+
size 1109819632
|
openvino_stage1/stage1.mapping
ADDED
The diff for this file is too large to render.
See raw diff
|
|
openvino_stage1/stage1.xml
ADDED
The diff for this file is too large to render.
See raw diff
|
|
outputs/bm25_stage1/bm25_index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:35c10a1b81b1a458ca41cd3fcc73ab1dcedf0606024aa57b94a99e21e1820fbd
|
3 |
+
size 28408446
|
outputs/bm25_stage1/dict
ADDED
Binary file (802 kB). View file
|
|
outputs/bm25_stage1/tfidf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed2df8ff211b311b6f99b1f9f7d5f34ad24139fbedb608aa04465a37aae56b58
|
3 |
+
size 1525507
|
pairwise_model.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
from transformers import AutoModel, AutoConfig, AutoTokenizer
|
6 |
+
import pandas as pd
|
7 |
+
from optimum.intel import OVModelForQuestionAnswering
|
8 |
+
import openvino.inference_engine as ie
|
9 |
+
import os
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
AUTH_TOKEN = "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"
|
13 |
+
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
15 |
+
use_auth_token=AUTH_TOKEN)
|
16 |
+
pad_token_id = tokenizer.pad_token_id
|
17 |
+
|
18 |
+
# Load the model
|
19 |
+
model_xml = "openvino_stage1/stage1.xml"
|
20 |
+
model_bin = "openvino_stage1/stage1.bin"
|
21 |
+
# Create an Inference Engine object
|
22 |
+
ie_core = ie.IECore()
|
23 |
+
# Read the IR files"
|
24 |
+
net = ie_core.read_network(model=model_xml, weights=model_bin)
|
25 |
+
|
26 |
+
class PairwiseModel_modify(nn.Module):
|
27 |
+
def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
|
28 |
+
super(PairwiseModel_modify, self).__init__()
|
29 |
+
self.max_length = max_length
|
30 |
+
self.batch_size = batch_size
|
31 |
+
self.device = device
|
32 |
+
# self.model = AutoModel.from_pretrained(model_name , use_auth_token=AUTH_TOKEN)
|
33 |
+
self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
|
34 |
+
self.fc = nn.Linear(768, 1).to(self.device)
|
35 |
+
|
36 |
+
def forward(self, ids, masks):
|
37 |
+
# Export the model to ONNX format
|
38 |
+
input_feed = {"input_ids": ids.cpu().numpy().astype(np.int64), "attention_mask": masks.cpu().numpy().astype(np.int64)}
|
39 |
+
# Specify the input shapes (batch_size, max_sequence_length)
|
40 |
+
input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
|
41 |
+
|
42 |
+
# Set the input shapes in the network
|
43 |
+
net.reshape(input_shapes)
|
44 |
+
|
45 |
+
# Load the network with the specified input shapes
|
46 |
+
exec_net = ie_core.load_network(network=net, device_name="CPU")
|
47 |
+
outputs = exec_net.infer(input_feed)
|
48 |
+
|
49 |
+
# Get the output tensor and apply the linear layer
|
50 |
+
out = torch.from_numpy(outputs["output"]).to(self.device)
|
51 |
+
out = out[:, 0]
|
52 |
+
return out
|
53 |
+
|
54 |
+
def stage1_ranking(self, question, texts):
|
55 |
+
tmp = pd.DataFrame()
|
56 |
+
tmp["text"] = [" ".join(x.split()) for x in texts]
|
57 |
+
tmp["question"] = question
|
58 |
+
valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
|
59 |
+
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
60 |
+
num_workers=0, shuffle=False, pin_memory=True)
|
61 |
+
preds = []
|
62 |
+
with torch.no_grad():
|
63 |
+
bar = enumerate(valid_loader)
|
64 |
+
for step, data in bar:
|
65 |
+
ids = data["ids"].to(self.device)
|
66 |
+
masks = data["masks"].to(self.device)
|
67 |
+
preds.append(torch.sigmoid(self(ids, masks)).view(-1))
|
68 |
+
preds = torch.concat(preds)
|
69 |
+
return preds.cpu().numpy()
|
70 |
+
|
71 |
+
|
72 |
+
class SiameseDatasetStage1(Dataset):
|
73 |
+
|
74 |
+
def __init__(self, df, tokenizer, max_length, is_test=False):
|
75 |
+
self.df = df
|
76 |
+
self.max_length = max_length
|
77 |
+
self.tokenizer = tokenizer
|
78 |
+
self.is_test = is_test
|
79 |
+
self.content1 = tokenizer.batch_encode_plus(list(df.question.values), max_length=max_length, truncation=True)[
|
80 |
+
"input_ids"]
|
81 |
+
self.content2 = tokenizer.batch_encode_plus(list(df.text.values), max_length=max_length, truncation=True)[
|
82 |
+
"input_ids"]
|
83 |
+
if not self.is_test:
|
84 |
+
self.targets = self.df.label
|
85 |
+
|
86 |
+
def __len__(self):
|
87 |
+
return len(self.df)
|
88 |
+
|
89 |
+
def __getitem__(self, index):
|
90 |
+
return {
|
91 |
+
'ids1': torch.tensor(self.content1[index], dtype=torch.long),
|
92 |
+
'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
|
93 |
+
'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
class SiameseDatasetStage2(Dataset):
|
98 |
+
|
99 |
+
def __init__(self, df, tokenizer, max_length, is_test=False):
|
100 |
+
self.df = df
|
101 |
+
self.max_length = max_length
|
102 |
+
self.tokenizer = tokenizer
|
103 |
+
self.is_test = is_test
|
104 |
+
self.df["content1"] = self.df.apply(lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1)
|
105 |
+
self.df["content2"] = self.df.apply(lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1)
|
106 |
+
self.content1 = tokenizer.batch_encode_plus(list(df.content1.values), max_length=max_length, truncation=True)[
|
107 |
+
"input_ids"]
|
108 |
+
self.content2 = tokenizer.batch_encode_plus(list(df.content2.values), max_length=max_length, truncation=True)[
|
109 |
+
"input_ids"]
|
110 |
+
if not self.is_test:
|
111 |
+
self.targets = self.df.label
|
112 |
+
|
113 |
+
def __len__(self):
|
114 |
+
return len(self.df)
|
115 |
+
|
116 |
+
def __getitem__(self, index):
|
117 |
+
return {
|
118 |
+
'ids1': torch.tensor(self.content1[index], dtype=torch.long),
|
119 |
+
'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
|
120 |
+
'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
|
121 |
+
}
|
122 |
+
|
123 |
+
|
124 |
+
def collate_fn(batch):
|
125 |
+
ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
|
126 |
+
targets = [x["target"] for x in batch]
|
127 |
+
max_len = np.max([len(x) for x in ids])
|
128 |
+
masks = []
|
129 |
+
for i in range(len(ids)):
|
130 |
+
if len(ids[i]) < max_len:
|
131 |
+
ids[i] = torch.cat((ids[i], torch.tensor([pad_token_id, ] * (max_len - len(ids[i])), dtype=torch.long)))
|
132 |
+
masks.append(ids[i] != pad_token_id)
|
133 |
+
# print(tokenizer.decode(ids[0]))
|
134 |
+
outputs = {
|
135 |
+
"ids": torch.vstack(ids),
|
136 |
+
"masks": torch.vstack(masks),
|
137 |
+
"target": torch.vstack(targets).view(-1)
|
138 |
+
}
|
139 |
+
return outputs
|
processed/entities.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
processed/wikipedia_chungta_cleaned.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8968d3a8c99c9844b1aaf6b2c174830305473e41a85e08058ceceb95dc8ae921
|
3 |
+
size 66457234
|
processed/wikipedia_chungta_short.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
qa_model.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from transformers import AutoTokenizer, pipeline
|
3 |
+
from optimum.onnxruntime import ORTModelForQuestionAnswering
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from transformers import AutoModelForQuestionAnswering, pipeline
|
8 |
+
from text_utils import post_process_answer
|
9 |
+
from graph_utils import find_best_cluster
|
10 |
+
from optimum.intel import OVModelForQuestionAnswering
|
11 |
+
import os
|
12 |
+
import json
|
13 |
+
from text_utils import *
|
14 |
+
|
15 |
+
|
16 |
+
# os.environ['HTTP_PROXY'] = 'http://proxy.hcm.fpt.vn:80'
|
17 |
+
class QAEnsembleModel_modify(nn.Module):
|
18 |
+
|
19 |
+
# def __init__(self, model_name, model_checkpoints, entity_dict,
|
20 |
+
# thr=0.1, device="cuda:0"):
|
21 |
+
def __init__(self, model_name, entity_dict,
|
22 |
+
thr=0.1, device="cpu"):
|
23 |
+
super(QAEnsembleModel_modify, self).__init__()
|
24 |
+
self.nlps = []
|
25 |
+
# model_checkpoint = "./data/qa_model_robust.bin"
|
26 |
+
AUTH_TOKEN = "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"
|
27 |
+
# model_checkpoint = "letrunglinh/qa_pnc"
|
28 |
+
model_convert = OVModelForQuestionAnswering.from_pretrained(model_name, export= True, use_auth_token= AUTH_TOKEN)
|
29 |
+
# model_convert.half()
|
30 |
+
# model_convert.compile()
|
31 |
+
nlp = pipeline('question-answering', model=model_convert,
|
32 |
+
tokenizer=model_name)
|
33 |
+
self.nlps.append(nlp)
|
34 |
+
self.entity_dict = entity_dict
|
35 |
+
self.thr = thr
|
36 |
+
|
37 |
+
def forward(self, question, texts, ranking_scores=None):
|
38 |
+
if ranking_scores is None:
|
39 |
+
ranking_scores = np.ones((len(texts),))
|
40 |
+
|
41 |
+
curr_answers = []
|
42 |
+
curr_scores = []
|
43 |
+
best_score = 0
|
44 |
+
for i, nlp in enumerate(self.nlps):
|
45 |
+
for text, score in zip(texts, ranking_scores):
|
46 |
+
QA_input = {
|
47 |
+
'question': question,
|
48 |
+
'context': text
|
49 |
+
}
|
50 |
+
res = nlp(QA_input)
|
51 |
+
print(res)
|
52 |
+
if res["score"] > self.thr:
|
53 |
+
curr_answers.append(res["answer"])
|
54 |
+
curr_scores.append(res["score"])
|
55 |
+
res["score"] = res["score"] * score
|
56 |
+
if i == 0:
|
57 |
+
if res["score"] > best_score:
|
58 |
+
answer = res["answer"]
|
59 |
+
best_score = res["score"]
|
60 |
+
if len(curr_answers) == 0:
|
61 |
+
return None
|
62 |
+
curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
|
63 |
+
answer = post_process_answer(answer, self.entity_dict)
|
64 |
+
new_best_answer = post_process_answer(find_best_cluster(curr_answers, answer), self.entity_dict)
|
65 |
+
return new_best_answer
|
text_utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from glob import glob
|
3 |
+
import re
|
4 |
+
from nltk import word_tokenize as lib_tokenizer
|
5 |
+
import string
|
6 |
+
|
7 |
+
|
8 |
+
def preprocess(x, max_length=-1, remove_puncts=False):
|
9 |
+
x = nltk_tokenize(x)
|
10 |
+
x = x.replace("\n", " ")
|
11 |
+
if remove_puncts:
|
12 |
+
x = "".join([i for i in x if i not in string.punctuation])
|
13 |
+
if max_length > 0:
|
14 |
+
x = " ".join(x.split()[:max_length])
|
15 |
+
return x
|
16 |
+
|
17 |
+
|
18 |
+
def nltk_tokenize(x):
|
19 |
+
return " ".join(word_tokenize(strip_context(x))).strip()
|
20 |
+
|
21 |
+
|
22 |
+
def post_process_answer(x, entity_dict):
|
23 |
+
if type(x) is not str:
|
24 |
+
return x
|
25 |
+
try:
|
26 |
+
x = strip_answer_string(x)
|
27 |
+
except:
|
28 |
+
return "NaN"
|
29 |
+
x = "".join([c for c in x if c not in string.punctuation])
|
30 |
+
x = " ".join(x.split())
|
31 |
+
y = x.lower()
|
32 |
+
if len(y) > 1 and y.split()[0].isnumeric() and ("tháng" not in x):
|
33 |
+
return y.split()[0]
|
34 |
+
if not (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x):
|
35 |
+
if len(x.split()) <= 2:
|
36 |
+
return entity_dict.get(x.lower(), x)
|
37 |
+
else:
|
38 |
+
return x
|
39 |
+
else:
|
40 |
+
return y
|
41 |
+
|
42 |
+
|
43 |
+
dict_map = dict({})
|
44 |
+
|
45 |
+
|
46 |
+
def word_tokenize(text):
|
47 |
+
global dict_map
|
48 |
+
words = text.split()
|
49 |
+
words_norm = []
|
50 |
+
for w in words:
|
51 |
+
if dict_map.get(w, None) is None:
|
52 |
+
dict_map[w] = ' '.join(lib_tokenizer(w)).replace('``', '"').replace("''", '"')
|
53 |
+
words_norm.append(dict_map[w])
|
54 |
+
return words_norm
|
55 |
+
|
56 |
+
|
57 |
+
def strip_answer_string(text):
|
58 |
+
text = text.strip()
|
59 |
+
while text[-1] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
|
60 |
+
if text[0] != '(' and text[-1] == ')' and '(' in text:
|
61 |
+
break
|
62 |
+
if text[-1] == '"' and text[0] != '"' and text.count('"') > 1:
|
63 |
+
break
|
64 |
+
text = text[:-1].strip()
|
65 |
+
while text[0] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
|
66 |
+
if text[0] == '"' and text[-1] != '"' and text.count('"') > 1:
|
67 |
+
break
|
68 |
+
text = text[1:].strip()
|
69 |
+
text = text.strip()
|
70 |
+
return text
|
71 |
+
|
72 |
+
|
73 |
+
def strip_context(text):
|
74 |
+
text = text.replace('\n', ' ')
|
75 |
+
text = re.sub(r'\s+', ' ', text)
|
76 |
+
text = text.strip()
|
77 |
+
return text
|
78 |
+
|
79 |
+
|
80 |
+
def check_number(x):
|
81 |
+
x = str(x).lower()
|
82 |
+
return (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x)
|