letrunglinh committed on
Commit
fa01b79
1 Parent(s): 8d62285

Upload 15 files

.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ outputs/bm25_stage1/bm25_index filter=lfs diff=lfs merge=lfs -text
+ outputs/bm25_stage1/tfidf filter=lfs diff=lfs merge=lfs -text
+ processed/wikipedia_chungta_cleaned.csv filter=lfs diff=lfs merge=lfs -text
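These three rules route the new BM25 index files and the cleaned Wikipedia CSV through Git LFS. They are the kind of lines that `git lfs track "<path>"` appends to .gitattributes, so the repository stores small pointer files (the `version`/`oid`/`size` stanzas visible in the file diffs below) instead of multi-megabyte blobs.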
app.py ADDED
@@ -0,0 +1,59 @@
+ import json
+
+ # Explicit imports: numpy and gradio were previously only available
+ # implicitly through the star imports below.
+ import numpy as np
+ import gradio as gr
+ import pandas as pd
+ from text_utils import *
+ from qa_model import *
+ from bm25_utils import *
+ from pairwise_model import *
+
+ df_wiki_windows = pd.read_csv("./processed/wikipedia_chungta_cleaned.csv")
+ df_wiki = pd.read_csv("./processed/wikipedia_chungta_short.csv")
+ df_wiki.title = df_wiki.title.apply(str)
+
+ # Map both raw and preprocessed entity surface forms to their Wikipedia titles.
+ entity_dict = json.load(open("./processed/entities.json"))
+ new_dict = dict()
+ for key, val in entity_dict.items():
+     val = val.replace("wiki/", "").replace("_", " ")
+     entity_dict[key] = val
+     key = preprocess(key)
+     new_dict[key.lower()] = val
+ entity_dict.update(new_dict)
+ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
+
+ qa_model = QAEnsembleModel_modify("letrunglinh/qa_pnc", entity_dict)
+ pairwise_model_stage1 = PairwiseModel_modify("nguyenvulebinh/vi-mrc-base")
+ bm25_model_stage1 = BM25Gensim("./outputs/bm25_stage1/", entity_dict, title2idx)
+
+
+ def get_answer_e2e(question):
+     # BM25 retrieval for the top-200 candidate passages
+     query = preprocess(question).lower()
+     top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
+     titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
+     pre_texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
+
+     # Rerank with the pairwise model; combine with the BM25 scores
+     question = preprocess(question)
+     ranking_preds = pairwise_model_stage1.stage1_ranking(question, pre_texts)
+     ranking_scores = ranking_preds * bm25_scores
+
+     # Extractive question answering over the top-10 reranked passages
+     best_idxs = np.argsort(ranking_scores)[-10:]
+     ranking_scores = np.array(ranking_scores)[best_idxs]
+     texts = np.array(pre_texts)[best_idxs]
+     best_answer = qa_model(question, texts, ranking_scores)
+
+     if best_answer is None:
+         return pre_texts[0]
+     return best_answer
+
+
+ if __name__ == "__main__":
+     # result = get_answer_e2e("OKR là gì?")
+     # print(result)
+     gr.Interface(fn=get_answer_e2e, inputs=["text"], outputs=["textbox"]).launch()
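app.py wires the three stages (BM25 retrieval, pairwise reranking, extractive QA) into a single function. A minimal sketch of exercising it outside Gradio, assuming the LFS artifacts and models above are in place (the question is illustrative):

# Hypothetical smoke test; not part of this commit.
from app import get_answer_e2e

answer = get_answer_e2e("OKR là gì?")  # "What is OKR?"
print(answer)  # best extracted span, or the top-ranked passage as a fallback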
bm25_utils.py ADDED
@@ -0,0 +1,40 @@
+ import numpy as np
+ from tqdm.auto import tqdm
+
+ tqdm.pandas()
+ from gensim.corpora import Dictionary
+ from gensim.models import TfidfModel
+ from gensim.similarities import SparseMatrixSimilarity
+ from text_utils import preprocess
+
+
+ class BM25Gensim:
+     def __init__(self, checkpoint_path, entity_dict, title2idx):
+         self.dictionary = Dictionary.load(checkpoint_path + "/dict")
+         # Load each artifact with the class that saved it (the TF-IDF model and
+         # the sparse similarity index were previously swapped between these calls).
+         self.tfidf_model = TfidfModel.load(checkpoint_path + "/tfidf")
+         self.bm25_index = SparseMatrixSimilarity.load(checkpoint_path + "/bm25_index")
+         self.title2idx = title2idx
+         self.entity_dict = entity_dict
+
+     def get_topk_stage1(self, query, topk=100):
+         tokenized_query = query.split()
+         tfidf_query = self.tfidf_model[self.dictionary.doc2bow(tokenized_query)]
+         scores = self.bm25_index[tfidf_query]
+         top_n = np.argsort(scores)[::-1][:topk]
+         return top_n, scores[top_n]
+
+     def get_topk_stage2(self, x, raw_answer=None, topk=50):
+         x = str(x)
+         query = preprocess(x, max_length=128).lower().split()
+         tfidf_query = self.tfidf_model[self.dictionary.doc2bow(query)]
+         scores = self.bm25_index[tfidf_query]
+         top_n = list(np.argsort(scores)[::-1][:topk])
+         # If the raw answer maps to a known entity page, force that page into the pool.
+         if raw_answer is not None:
+             raw_answer = raw_answer.strip()
+             if raw_answer in self.entity_dict:
+                 title = self.entity_dict[raw_answer].replace("wiki/", "").replace("_", " ")
+                 extra_id = self.title2idx.get(title, -1)
+                 if extra_id != -1 and extra_id not in top_n:
+                     top_n.append(extra_id)
+         scores = scores[top_n]
+         return np.array(top_n), np.array(scores)
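BM25Gensim only loads prebuilt artifacts; the build script is not part of this commit. A minimal sketch of how `dict`, `tfidf`, and `bm25_index` could be produced with gensim, consistent with the load calls above (the toy corpus and the `smartirs` weighting scheme are assumptions):

# Assumed build script for ./outputs/bm25_stage1/ (not part of this commit).
from gensim.corpora import Dictionary
from gensim.models import TfidfModel
from gensim.similarities import SparseMatrixSimilarity

docs = [doc.split() for doc in ["toy document one", "toy document two"]]
dictionary = Dictionary(docs)
bow = [dictionary.doc2bow(d) for d in docs]
tfidf = TfidfModel(bow, smartirs="bnn")  # weighting scheme is an assumption
index = SparseMatrixSimilarity(tfidf[bow], num_features=len(dictionary))

dictionary.save("./outputs/bm25_stage1/dict")
tfidf.save("./outputs/bm25_stage1/tfidf")
index.save("./outputs/bm25_stage1/bm25_index")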
graph_utils.py ADDED
@@ -0,0 +1,110 @@
+ import networkx as nx
+ import numpy as np
+ from cdlib import algorithms
+
+
+ # These functions are heavily influenced by the HF squad_metrics.py script.
+ def normalize_text(s):
+     """Remove articles and punctuation, and standardize whitespace."""
+     import string, re
+
+     def remove_articles(text):
+         regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+         return re.sub(regex, " ", text)
+
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def compute_exact_match(prediction, truth):
+     return int(normalize_text(prediction) == normalize_text(truth))
+
+
+ def compute_f1(prediction, truth):
+     pred_tokens = normalize_text(prediction).split()
+     truth_tokens = normalize_text(truth).split()
+
+     # If either the prediction or the truth is no-answer, then F1 = 1 if they agree, 0 otherwise.
+     if len(pred_tokens) == 0 or len(truth_tokens) == 0:
+         return int(pred_tokens == truth_tokens)
+
+     common_tokens = set(pred_tokens) & set(truth_tokens)
+
+     # If there are no common tokens, then F1 = 0.
+     if len(common_tokens) == 0:
+         return 0
+
+     prec = len(common_tokens) / len(pred_tokens)
+     rec = len(common_tokens) / len(truth_tokens)
+     return 2 * (prec * rec) / (prec + rec)
+
+
+ def is_date_or_num(answer):
+     answer = answer.lower().split()
+     for w in answer:
+         w = w.strip()
+         if w.isnumeric() or w in ["ngày", "tháng", "năm"]:  # day / month / year
+             return True
+     return False
+
+
+ def find_best_cluster(answers, best_answer, thr=0.79):
+     if len(answers) == 0:
+         return best_answer
+     elif len(answers) == 1:
+         return answers[0]
+     # Build a pairwise agreement matrix: 1 if two answers match (exactly, by
+     # containment, or by token F1 >= thr), else 0.
+     dists = np.zeros((len(answers), len(answers)))
+     for i in range(len(answers) - 1):
+         for j in range(i + 1, len(answers)):
+             a1 = answers[i].lower().strip()
+             a2 = answers[j].lower().strip()
+             if is_date_or_num(a1) or is_date_or_num(a2):
+                 # Dates and numbers must match exactly or by month-aware containment.
+                 if a1 == a2 or ("tháng" in a1 and a1 in a2) or ("tháng" in a2 and a2 in a1):
+                     dists[i, j] = 1
+                     dists[j, i] = 1
+             elif a1 == a2 or (a1 in a2) or (a2 in a1) or compute_f1(a1, a2) >= thr:
+                 dists[i, j] = 1
+                 dists[j, i] = 1
+     try:
+         # Cluster the agreement graph with Louvain and keep the largest community.
+         thr = 1
+         dups = np.where(dists >= thr)
+         edges = []
+         for i, j in zip(dups[0], dups[1]):
+             if i != j:
+                 edges.append((i, j))
+         G = nx.Graph()
+         for i, answer in enumerate(answers):
+             G.add_node(i, content=answer)
+         G.add_edges_from(edges)
+         partition = algorithms.louvain(G)
+         max_len_comm = np.max([len(x) for x in partition.communities])
+         best_comms = []
+         for comm in partition.communities:
+             if len(comm) == max_len_comm:
+                 best_comms.append([answers[i] for i in comm])
+         for comm in best_comms:
+             if best_answer in comm:
+                 return best_answer
+         # Otherwise fall back to the median-length answer in the largest community.
+         mid = len(best_comms[0]) // 2
+         return sorted(best_comms[0], key=len)[mid]
+     except Exception as e:
+         print(e, "Disconnected graph")
+         return best_answer
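A toy illustration of the clustering fallback, assuming cdlib and networkx are installed: near-duplicate answers land in one Louvain community, and the candidate best answer is kept when it sits in the largest one.

# Toy check (not part of the commit): two paraphrases vs. one outlier.
answers = ["hà nội", "Hà Nội", "thành phố Hồ Chí Minh"]
print(find_best_cluster(answers, best_answer="hà nội"))  # -> "hà nội"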
openvino_stage1/stage1.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac91eb96b7cb06441c544d46f1a64b9f44c7bec4f629e91c1e28cee61570262a
+ size 1109819632
openvino_stage1/stage1.mapping ADDED
The diff for this file is too large to render. See raw diff
 
openvino_stage1/stage1.xml ADDED
The diff for this file is too large to render. See raw diff
 
outputs/bm25_stage1/bm25_index ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:35c10a1b81b1a458ca41cd3fcc73ab1dcedf0606024aa57b94a99e21e1820fbd
+ size 28408446
outputs/bm25_stage1/dict ADDED
Binary file (802 kB).
 
outputs/bm25_stage1/tfidf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ed2df8ff211b311b6f99b1f9f7d5f34ad24139fbedb608aa04465a37aae56b58
+ size 1525507
pairwise_model.py ADDED
@@ -0,0 +1,139 @@
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
+ from optimum.intel import OVModelForQuestionAnswering
+ import openvino.inference_engine as ie
+ import gradio as gr
+
+ # NOTE: hard-coded access tokens are a security risk; prefer reading them
+ # from an environment variable.
+ AUTH_TOKEN = "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"
+
+ tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
+                                           use_auth_token=AUTH_TOKEN)
+ pad_token_id = tokenizer.pad_token_id
+
+ # Paths to the OpenVINO IR model (graph + weights)
+ model_xml = "openvino_stage1/stage1.xml"
+ model_bin = "openvino_stage1/stage1.bin"
+ # Create an Inference Engine object
+ ie_core = ie.IECore()
+ # Read the IR files
+ net = ie_core.read_network(model=model_xml, weights=model_bin)
+
+
+ class PairwiseModel_modify(nn.Module):
+     def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
+         super(PairwiseModel_modify, self).__init__()
+         self.max_length = max_length
+         self.batch_size = batch_size
+         self.device = device
+         self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
+         self.fc = nn.Linear(768, 1).to(self.device)
+
+     def forward(self, ids, masks):
+         # Run the reranker through the OpenVINO runtime instead of PyTorch.
+         input_feed = {"input_ids": ids.cpu().numpy().astype(np.int64),
+                       "attention_mask": masks.cpu().numpy().astype(np.int64)}
+         # Reshape the network to the current (batch_size, sequence_length)
+         input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
+         net.reshape(input_shapes)
+
+         # Load the reshaped network onto the CPU and run inference
+         exec_net = ie_core.load_network(network=net, device_name="CPU")
+         outputs = exec_net.infer(input_feed)
+
+         # Keep only the first logit per example
+         out = torch.from_numpy(outputs["output"]).to(self.device)
+         out = out[:, 0]
+         return out
+
+     def stage1_ranking(self, question, texts):
+         tmp = pd.DataFrame()
+         tmp["text"] = [" ".join(x.split()) for x in texts]
+         tmp["question"] = question
+         valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
+         valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
+                                   num_workers=0, shuffle=False, pin_memory=True)
+         preds = []
+         with torch.no_grad():
+             for step, data in enumerate(valid_loader):
+                 ids = data["ids"].to(self.device)
+                 masks = data["masks"].to(self.device)
+                 preds.append(torch.sigmoid(self(ids, masks)).view(-1))
+         preds = torch.concat(preds)
+         return preds.cpu().numpy()
+
+
+ class SiameseDatasetStage1(Dataset):
+
+     def __init__(self, df, tokenizer, max_length, is_test=False):
+         self.df = df
+         self.max_length = max_length
+         self.tokenizer = tokenizer
+         self.is_test = is_test
+         self.content1 = tokenizer.batch_encode_plus(
+             list(df.question.values), max_length=max_length, truncation=True)["input_ids"]
+         self.content2 = tokenizer.batch_encode_plus(
+             list(df.text.values), max_length=max_length, truncation=True)["input_ids"]
+         if not self.is_test:
+             self.targets = self.df.label
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         return {
+             'ids1': torch.tensor(self.content1[index], dtype=torch.long),
+             # Drop the leading special token so the pair concatenates cleanly.
+             'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
+             'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
+         }
+
+
+ class SiameseDatasetStage2(Dataset):
+
+     def __init__(self, df, tokenizer, max_length, is_test=False):
+         self.df = df
+         self.max_length = max_length
+         self.tokenizer = tokenizer
+         self.is_test = is_test
+         self.df["content1"] = self.df.apply(lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1)
+         self.df["content2"] = self.df.apply(lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1)
+         self.content1 = tokenizer.batch_encode_plus(
+             list(df.content1.values), max_length=max_length, truncation=True)["input_ids"]
+         self.content2 = tokenizer.batch_encode_plus(
+             list(df.content2.values), max_length=max_length, truncation=True)["input_ids"]
+         if not self.is_test:
+             self.targets = self.df.label
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         return {
+             'ids1': torch.tensor(self.content1[index], dtype=torch.long),
+             'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
+             'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
+         }
+
+
+ def collate_fn(batch):
+     # Concatenate each question/text pair, then right-pad to the batch maximum.
+     ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
+     targets = [x["target"] for x in batch]
+     max_len = np.max([len(x) for x in ids])
+     masks = []
+     for i in range(len(ids)):
+         if len(ids[i]) < max_len:
+             ids[i] = torch.cat((ids[i], torch.tensor([pad_token_id] * (max_len - len(ids[i])), dtype=torch.long)))
+         masks.append(ids[i] != pad_token_id)
+     outputs = {
+         "ids": torch.vstack(ids),
+         "masks": torch.vstack(masks),
+         "target": torch.vstack(targets).view(-1)
+     }
+     return outputs
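A minimal sketch of the reranker in isolation, assuming the OpenVINO IR files in openvino_stage1/ are present (the passages are illustrative):

# Hypothetical usage; each score is a sigmoid output, one per candidate passage.
model = PairwiseModel_modify("nguyenvulebinh/vi-mrc-base")
scores = model.stage1_ranking("OKR là gì?",
                              ["OKR là một phương pháp quản trị theo mục tiêu.",
                               "Hà Nội là thủ đô của Việt Nam."])
print(scores)  # higher means more relevant to the question (values in (0, 1))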
processed/entities.json ADDED
The diff for this file is too large to render. See raw diff
 
processed/wikipedia_chungta_cleaned.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8968d3a8c99c9844b1aaf6b2c174830305473e41a85e08058ceceb95dc8ae921
+ size 66457234
processed/wikipedia_chungta_short.csv ADDED
The diff for this file is too large to render. See raw diff
 
qa_model.py ADDED
@@ -0,0 +1,65 @@
+ import json
+ import os
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
+ from optimum.onnxruntime import ORTModelForQuestionAnswering
+ from optimum.intel import OVModelForQuestionAnswering
+ from graph_utils import find_best_cluster
+ from text_utils import *
+
+
+ class QAEnsembleModel_modify(nn.Module):
+
+     def __init__(self, model_name, entity_dict, thr=0.1, device="cpu"):
+         super(QAEnsembleModel_modify, self).__init__()
+         self.nlps = []
+         # NOTE: hard-coded access tokens are a security risk; prefer reading
+         # them from an environment variable.
+         AUTH_TOKEN = "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"
+         # Export the Hugging Face checkpoint to an OpenVINO model on the fly.
+         model_convert = OVModelForQuestionAnswering.from_pretrained(model_name, export=True,
+                                                                     use_auth_token=AUTH_TOKEN)
+         nlp = pipeline('question-answering', model=model_convert, tokenizer=model_name)
+         self.nlps.append(nlp)
+         self.entity_dict = entity_dict
+         self.thr = thr
+
+     def forward(self, question, texts, ranking_scores=None):
+         if ranking_scores is None:
+             ranking_scores = np.ones((len(texts),))
+
+         curr_answers = []
+         curr_scores = []
+         best_score = 0
+         answer = None  # guards against the case where no span outscores 0
+         for i, nlp in enumerate(self.nlps):
+             for text, score in zip(texts, ranking_scores):
+                 QA_input = {
+                     'question': question,
+                     'context': text
+                 }
+                 res = nlp(QA_input)
+                 print(res)
+                 if res["score"] > self.thr:
+                     curr_answers.append(res["answer"])
+                     curr_scores.append(res["score"])
+                 # Weight the span score by the passage's ranking score.
+                 res["score"] = res["score"] * score
+                 if i == 0:
+                     if res["score"] > best_score:
+                         answer = res["answer"]
+                         best_score = res["score"]
+         if len(curr_answers) == 0:
+             return None
+         curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
+         answer = post_process_answer(answer, self.entity_dict)
+         new_best_answer = post_process_answer(find_best_cluster(curr_answers, answer), self.entity_dict)
+         return new_best_answer
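A minimal sketch of the reader on its own, assuming the letrunglinh/qa_pnc checkpoint is reachable; an empty entity_dict is enough for a smoke test:

# Hypothetical usage (not part of the commit); ranking_scores weight each passage.
import numpy as np

qa = QAEnsembleModel_modify("letrunglinh/qa_pnc", entity_dict={})
best = qa("OKR là gì?",
          ["OKR là một phương pháp quản trị theo mục tiêu."],
          ranking_scores=np.array([1.0]))
print(best)  # extracted answer span, or None if no span clears thr=0.1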
text_utils.py ADDED
@@ -0,0 +1,82 @@
+ import json
+ import re
+ import string
+ from glob import glob
+
+ from nltk import word_tokenize as lib_tokenizer
+
+
+ def preprocess(x, max_length=-1, remove_puncts=False):
+     x = nltk_tokenize(x)
+     x = x.replace("\n", " ")
+     if remove_puncts:
+         x = "".join([i for i in x if i not in string.punctuation])
+     if max_length > 0:
+         x = " ".join(x.split()[:max_length])
+     return x
+
+
+ def nltk_tokenize(x):
+     return " ".join(word_tokenize(strip_context(x))).strip()
+
+
+ def post_process_answer(x, entity_dict):
+     if type(x) is not str:
+         return x
+     try:
+         x = strip_answer_string(x)
+     except:
+         return "NaN"
+     x = "".join([c for c in x if c not in string.punctuation])
+     x = " ".join(x.split())
+     y = x.lower()
+     # Bare leading number with no month marker: keep just the number.
+     if len(y) > 1 and y.split()[0].isnumeric() and ("tháng" not in x):
+         return y.split()[0]
+     if not (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x):
+         # Short non-date answers may be entity names: map them via entity_dict.
+         if len(x.split()) <= 2:
+             return entity_dict.get(x.lower(), x)
+         else:
+             return x
+     else:
+         return y
+
+
+ dict_map = dict({})
+
+
+ def word_tokenize(text):
+     # Memoized per-word NLTK tokenization; normalizes `` and '' to plain quotes.
+     global dict_map
+     words = text.split()
+     words_norm = []
+     for w in words:
+         if dict_map.get(w, None) is None:
+             dict_map[w] = ' '.join(lib_tokenizer(w)).replace('``', '"').replace("''", '"')
+         words_norm.append(dict_map[w])
+     return words_norm
+
+
+ def strip_answer_string(text):
+     text = text.strip()
+     # Guard on emptiness so an all-punctuation answer cannot raise IndexError.
+     while text and text[-1] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
+         if text[0] != '(' and text[-1] == ')' and '(' in text:
+             break
+         if text[-1] == '"' and text[0] != '"' and text.count('"') > 1:
+             break
+         text = text[:-1].strip()
+     while text and text[0] in '.,/><;:\'"[]{}+=-_)(*&^!~`':
+         if text[0] == '"' and text[-1] != '"' and text.count('"') > 1:
+             break
+         text = text[1:].strip()
+     return text
+
+
+ def strip_context(text):
+     text = text.replace('\n', ' ')
+     text = re.sub(r'\s+', ' ', text)
+     text = text.strip()
+     return text
+
+
+ def check_number(x):
+     x = str(x).lower()
+     return (x.isnumeric() or "ngày" in x or "tháng" in x or "năm" in x)
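For reference, the expected behavior of the helpers above on toy inputs (assumes NLTK's punkt tokenizer data is available):

# Toy checks (not part of the commit).
print(preprocess("Hà   Nội\nlà thủ đô", max_length=3))   # -> "Hà Nội là"
print(post_process_answer("năm 1945.", entity_dict={}))  # -> "năm 1945"
print(check_number("tháng 8"))                           # -> True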