# VDT / retrieval / retrieval.py
# Last commit: c775592 "Update retrieval/retrieval.py" (CreatorPhan), 6.45 kB
import torch, math
from pyvi.ViTokenizer import tokenize
import re, os, string
import pandas as pd
import math
import numpy as np
class BM25:
    """Okapi BM25 ranking over a corpus of pre-tokenized documents.

    Parameters
    ----------
    k1 : float
        Term-frequency saturation: larger values let repeated terms keep
        adding to the score for longer.
    b : float
        Strength of document-length normalization (0 disables it, 1 applies
        it fully).
    """

    def __init__(self, k1=1.5, b=0.75):
        self.b = b
        self.k1 = k1

    def fit(self, corpus):
        """
        Fit the various statistics that are required to calculate BM25 ranking
        score using the corpus given.

        Parameters
        ----------
        corpus : list[list[str]]
            Each element in the list represents a document, and each document
            is a list of the terms.

        Returns
        -------
        self
        """
        tf = []       # per-document {term: count}
        df = {}       # {term: number of documents containing it}
        doc_len = []  # number of terms in each document
        for document in corpus:
            doc_len.append(len(document))
            frequencies = {}
            for term in document:
                frequencies[term] = frequencies.get(term, 0) + 1
            tf.append(frequencies)
            # Each distinct term counts once per document for df.
            for term in frequencies:
                df[term] = df.get(term, 0) + 1

        corpus_size = len(doc_len)
        # The "1 +" inside the log keeps IDF strictly positive, even for terms
        # that occur in more than half of the documents.
        idf = {
            term: math.log(1 + (corpus_size - freq + 0.5) / (freq + 0.5))
            for term, freq in df.items()
        }

        self.tf_ = tf
        self.df_ = df
        self.idf_ = idf
        self.doc_len_ = doc_len
        self.corpus_ = corpus
        self.corpus_size_ = corpus_size
        # Guard the empty corpus: previously this raised ZeroDivisionError.
        self.avg_doc_len_ = sum(doc_len) / corpus_size if corpus_size else 0.0
        return self

    def search(self, query):
        """Return the BM25 score of ``query`` (a list of terms) against every
        fitted document, in corpus order."""
        return [self._score(query, index) for index in range(self.corpus_size_)]

    def _score(self, query, index):
        """Score ``query`` against the document at position ``index``.

        Terms missing from the document contribute nothing; a term repeated in
        the query contributes once per occurrence.
        """
        score = 0.0
        doc_len = self.doc_len_[index]
        frequencies = self.tf_[index]
        # Length normalization depends only on the document, so compute it once.
        length_norm = self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_len_) if frequencies else 0.0
        for term in query:
            freq = frequencies.get(term)
            if not freq:
                continue
            numerator = self.idf_[term] * freq * (self.k1 + 1)
            denominator = freq + length_norm
            score += numerator / denominator
        return score
class Retrieval:
    """BM25 passage retriever over (Vietnamese) text segmented with pyvi.

    Either loads a fitted ``BM25`` model plus its passage list from disk, or,
    when ``docs`` is given, chunks that document into passages and fits a new
    ``BM25`` on them.
    """

    # Map every ASCII punctuation character except '_' to a space. The
    # underscore is kept, presumably because pyvi's tokenize joins the
    # syllables of one word with '_'.
    _PUNCT_TABLE = str.maketrans({ch: ' ' for ch in string.punctuation if ch != '_'})

    def __init__(
        self, k=8,
        model='retrieval/bm25.pt',
        contexts='retrieval/context.pt',
        stop_words='retrieval/stopwords.csv',
        max_len = 400,
        docs = None
    ) -> None:
        """
        Parameters
        ----------
        k : int
            Number of passages returned by ``get_context``.
        model, contexts : path or file-like
            ``torch.load`` sources for the fitted BM25 object and the passage
            list; only used when ``docs`` is falsy.
        stop_words : path or file-like
            Tab-separated file with a ``stopwords`` column.
        max_len : int
            Approximate passage length, in whitespace-separated words, used by
            ``split``.
        docs : str, optional
            Raw document text to chunk and index instead of loading from disk.
        """
        self.k = k
        self.max_len = max_len
        data = pd.read_csv(stop_words, sep="\t", encoding='utf-8')
        # Stored as a set of values: `word in series` on a pandas Series tests
        # the *index*, so the raw Series never matched any word.
        self.list_stopwords = set(data['stopwords'])
        if docs:
            self.tuning(docs)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects -- only load
            # trusted files. torch>=2.6 defaults to weights_only=True, which may
            # reject a pickled BM25 instance; confirm against the deployed torch.
            self.bm25 = torch.load(model)
            self.contexts = torch.load(contexts)

    def _preprocess(self, text):
        """Turn raw text into the list of index terms used by BM25.

        Steps: strip tag-like ``<...>`` spans, collapse each whitespace run to a
        single character, word-segment with pyvi, replace punctuation (except
        '_') with spaces, lowercase, split on whitespace, and drop stopwords.
        """
        text = re.sub(r'<.*?>', '', text).strip()
        # Each whitespace run is replaced by its last character.
        text = re.sub(r'(\s)+', r'\1', text)
        text = tokenize(text)
        text = text.translate(self._PUNCT_TABLE).lower()
        return [word for word in text.split() if word not in self.list_stopwords]

    def get_context(self, query='Chảy máu chân răng là bệnh gì?'):
        """Return the top passages for ``query``, best BM25 score first.

        Each result is a dict with keys ``'score_bm'`` (the BM25 score),
        ``'index'`` (position in ``self.contexts``) and ``'context'`` (the
        passage text). At most ``self.k`` results are returned, never more than
        the number of indexed passages. The returned scores are also printed.
        """
        terms = self._preprocess(query)
        scores = self.bm25.search(terms)
        # Clamp k so a loaded model with fewer passages than k cannot wrap
        # around or raise IndexError.
        top_k = min(self.k, len(scores))
        ranked = np.argsort(scores)[::-1][:top_k]
        results = [
            {'score_bm': scores[index], 'index': index, 'context': self.contexts[index]}
            for index in ranked
        ]
        print("BM25:", [hit['score_bm'] for hit in results])
        return results

    def split(self, document):
        """Chunk ``document`` into passages of about ``self.max_len`` words.

        Newlines become spaces and space runs collapse. The text is split on
        '. ' and each sentence is re-terminated with '. '. A new passage starts
        whenever the next sentence would push the word count past ``max_len``.
        Consecutive passages overlap by one sentence. The result is stored in
        ``self.contexts``, and ``self.k`` is lowered to the passage count if it
        is larger.
        """
        document = re.sub(' +', ' ', document.replace('\n', ' '))
        context_list = []
        context = ""
        length = 0
        prev_sentence = ""
        prev_len = 0
        for sentence in document.split('. '):
            sentence += '. '
            n_words = len(sentence.split())
            # The `context` check prevents emitting an empty passage when the
            # very first sentence alone already exceeds max_len.
            if context and length + n_words > self.max_len:
                context_list.append(context)
                # Seed the next passage with the previous sentence (overlap).
                context = prev_sentence
                length = prev_len
            length += n_words
            context += sentence
            prev_sentence = sentence
            prev_len = n_words
        context_list.append(context)
        self.contexts = context_list
        if len(context_list) < self.k:
            self.k = len(context_list)

    def tuning(self, document):
        """Chunk ``document`` with ``split`` and fit a fresh BM25 on the passages."""
        self.split(document)
        texts = [self._preprocess(content) for content in self.contexts]
        print('There is', len(texts), 'contexts')
        self.bm25 = BM25()
        self.bm25.fit(texts)