import math
import re
import string

import numpy as np
import pandas as pd
import torch
from pyvi.ViTokenizer import tokenize
class BM25:
def __init__(self, k1=1.5, b=0.75):
self.b = b
self.k1 = k1
def fit(self, corpus):
"""
Fit the various statistics that are required to calculate BM25 ranking
score using the corpus given.
Parameters
----------
corpus : list[list[str]]
Each element in the list represents a document, and each document
is a list of the terms.
Returns
-------
self
"""
tf = []
df = {}
idf = {}
doc_len = []
corpus_size = 0
for document in corpus:
corpus_size += 1
doc_len.append(len(document))
# compute tf (term frequency) per document
frequencies = {}
for term in document:
term_count = frequencies.get(term, 0) + 1
frequencies[term] = term_count
tf.append(frequencies)
# compute df (document frequency) per term
            for term in frequencies:
                df[term] = df.get(term, 0) + 1
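        # Smoothed idf (Lucene-style): the +1 keeps every value non-negative.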
for term, freq in df.items():
idf[term] = math.log(1 + (corpus_size - freq + 0.5) / (freq + 0.5))
self.tf_ = tf
self.df_ = df
self.idf_ = idf
self.doc_len_ = doc_len
self.corpus_ = corpus
self.corpus_size_ = corpus_size
self.avg_doc_len_ = sum(doc_len) / corpus_size
return self
def search(self, query):
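        # One BM25 score per document in the fitted corpus.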
scores = [self._score(query, index) for index in range(self.corpus_size_)]
return scores
def _score(self, query, index):
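        # Sum per-term contributions:
        #   idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))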
score = 0.0
doc_len = self.doc_len_[index]
frequencies = self.tf_[index]
for term in query:
if term not in frequencies:
continue
freq = frequencies[term]
numerator = self.idf_[term] * freq * (self.k1 + 1)
denominator = freq + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_len_)
score += (numerator / denominator)
return score
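
# Minimal usage sketch for BM25 (the tokens are illustrative):
#   bm25 = BM25().fit([['chảy_máu', 'chân_răng'], ['đau', 'đầu']])
#   bm25.search(['chảy_máu'])  # -> one score per fitted document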
class Retrieval:
def __init__(
self, k=8,
model='retrieval/bm25.pt',
contexts='retrieval/context.pt',
stop_words='retrieval/stopwords.csv',
        max_len=400,
        docs=None
) -> None:
self.k = k
self.max_len = max_len
        data = pd.read_csv(stop_words, sep="\t", encoding='utf-8')
        # Store stopwords in a set: `in` on a pandas Series checks the
        # index labels, not the values.
        self.list_stopwords = set(data['stopwords'])
if docs:
self.tuning(docs)
else:
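            # Load a previously fitted BM25 object and its context list
            # (plain pickles saved with torch.save).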
self.bm25 = torch.load(model)
self.contexts = torch.load(contexts)
    def _preprocess(self, text):
        """Clean markup, word-segment with pyvi, strip punctuation
        (keeping the '_' the segmenter inserts) and stopwords."""
        text = re.sub(r'<.*?>', '', text).strip()
        text = re.sub(r'(\s)+', r'\1', text)
        text = tokenize(text)
        for punct in string.punctuation.replace('_', ''):
            text = text.replace(punct, ' ')
        return ' '.join(
            word for word in text.lower().split()
            if word not in self.list_stopwords
        )

    def get_context(self, query='Chảy máu chân răng là bệnh gì?'):
        query = self._preprocess(query).split()
scores = self.bm25.search(query)
scores_index = np.argsort(scores)
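        # argsort is ascending, so the k best documents sit at the end.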
results = []
        for k in range(1, self.k + 1):
            index = scores_index[-k]
            results.append({
                'score': scores[index],
                'index': index,
                'context': self.contexts[index],
            })
return results
    def split(self, document):
        """Split a document into chunks of at most `max_len` words,
        overlapping each chunk with the previous sentence so answers
        spanning a chunk boundary are not lost."""
        document = document.replace('\n', ' ')
        document = re.sub(' +', ' ', document)
        sentences = document.split('. ')
        context_list = []
        context = ""        # chunk currently being built
        length = 0          # word count of the current chunk
        prev_sentence = ""  # last sentence added, carried into the next chunk
        prev_len = 0        # its word count
        for sentence in sentences:
            sentence += '. '
            sent_len = len(sentence.split())
            if length + sent_len > self.max_len:
                # Chunk is full: emit it and seed the next one with the
                # previous sentence as overlap.
                context_list.append(context)
                context = prev_sentence
                length = prev_len
            length += sent_len
            context += sentence
            prev_sentence = sentence
            prev_len = sent_len
        context_list.append(context)
        self.contexts = context_list
        # Never return more contexts than exist.
        if len(context_list) < self.k:
            self.k = len(context_list)
    def tuning(self, document):
        """Split a raw document into contexts and fit BM25 over them."""
        self.split(document)
        docs = [self._preprocess(content) for content in self.contexts]
        print('There are', len(docs), 'contexts')
texts = [
[word for word in document.lower().split() if word not in self.list_stopwords]
for document in docs
]
self.bm25 = BM25()
self.bm25.fit(texts)
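
# A minimal smoke test; a sketch that assumes `retrieval/stopwords.csv`
# exists, with an illustrative sample document and query.
if __name__ == '__main__':
    sample_doc = (
        'Chảy máu chân răng thường do viêm lợi. '
        'Vệ sinh răng miệng kém là nguyên nhân phổ biến. '
    )
    retriever = Retrieval(k=1, docs=sample_doc)
    for hit in retriever.get_context('Chảy máu chân răng là bệnh gì?'):
        print(round(hit['score'], 3), hit['context'][:60])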