SummerTime / model /query_based /tf_idf_model.py
aliabd
full demo working with old gradio
7e3e85d
from .base_query_based_model import QueryBasedSummModel
from model.base_model import SummModel
from model.single_doc import TextRankModel
from typing import List
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
class TFIDFSummModel(QueryBasedSummModel):
    """Query-based summarization model whose retrieval step uses TF-IDF.

    Candidate sentences are ranked by the cosine similarity between their
    TF-IDF vectors and the TF-IDF vector(s) of the query; the best-ranked
    sentences are returned by ``_retrieve``. Everything else (how the
    retrieved sentences are summarized, how ``retrieval_ratio`` and
    ``preprocess`` are applied) is handled by ``QueryBasedSummModel``, which
    is not defined in this file.
    """

    # static variables
    model_name = "TF-IDF"
    is_extractive = True
    is_neural = False
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        """Forward the configuration to the base class and build the vectorizer.

        All named arguments and ``**kwargs`` are passed through unchanged to
        ``QueryBasedSummModel.__init__``; their exact semantics are defined
        there. Note that the default ``model_backend`` is the
        ``TextRankModel`` name itself (not an instance of it).
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )
        self.vectorizer = TfidfVectorizer()

    def _retrieve(self, instance: List[str], query: List[str], n_best):
        """Return the ``n_best`` entries of ``instance`` most similar to ``query``.

        Args:
            instance: Candidate sentences to rank.
            query: One or more query strings. When several are given, a
                sentence is scored by its highest similarity to any of them.
            n_best: Maximum number of sentences to return.

        Returns:
            Up to ``n_best`` strings taken from ``instance``, ordered from
            most to least similar. An empty ``instance`` yields ``[]``.
        """
        if not instance:
            # Nothing to rank; fitting the vectorizer on no documents would
            # raise instead of producing a useful result.
            return []

        # The vectorizer is refit on every call, so the vocabulary and IDF
        # weights come from this instance's sentences only.
        instance_vectors = self.vectorizer.fit_transform(instance)
        query_vectors = self.vectorizer.transform(query)

        # Shape: (len(query), len(instance)).
        similarity_matrix = cosine_similarity(query_vectors, instance_vectors)

        # Reduce over the query axis rather than calling .squeeze(): squeeze
        # would collapse a (1, 1) result to a 0-d array and leave a
        # multi-query result 2-D, neither of which is a flat per-sentence
        # score vector.
        scores = similarity_matrix.max(axis=0)

        ranked_indices = scores.argsort()[::-1][:n_best]
        return [instance[idx] for idx in ranked_indices]  # List[str]