# Script-Similarity / ScriptMatcher.py
# Uploaded by Stanford-TH ("Upload folder using huggingface_hub"), commit d862c41 (verified).
# File size: 4.07 kB.
import pandas as pd
import numpy as np
from ast import literal_eval
import yake
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import os
class ScriptMatcher:
    """Rank reference series by similarity to a new synopsis.

    A query sentence is built from genre keywords plus YAKE keywords taken from
    the synopsis (after spaCy named entities are masked). It is embedded with a
    SentenceTransformer model and compared by cosine similarity against two
    precomputed embedding matrices loaded from ``resources_dir``.
    """

    def __init__(self, data_path=None, model_name='paraphrase-mpnet-base-v2', dataframe=None,
                 resources_dir='models/Similarity_K_Dataset'):
        """
        Initialize the ScriptMatcher object.

        Parameters:
        data_path (str, optional): Path to a CSV file loaded into ``self.dataset``.
            Ignored when ``dataframe`` is given.
        model_name (str): Name of the sentence transformer model. Default is 'paraphrase-mpnet-base-v2'.
        dataframe (pd.DataFrame, optional): Dataset stored directly as ``self.dataset``;
            takes precedence over ``data_path``.
        resources_dir (str): Directory containing 'K_Dataset.csv', 'plot_embeddings.npy'
            and 'synopsis_embeddings.npy'. Default is 'models/Similarity_K_Dataset'.
        Raises:
        OSError: If the spaCy model cannot be loaded even after attempting to install it.
        subprocess.CalledProcessError: If installing the spaCy model with pip fails.
        """
        # NOTE(review): self.dataset is not read by any method in this class; kept for callers.
        self.dataset = None
        if dataframe is not None:
            self.dataset = dataframe
        elif data_path is not None:
            self.dataset = pd.read_csv(data_path)

        self.model = SentenceTransformer(model_name)
        self.kw_extractor = yake.KeywordExtractor("en", n=1, dedupLim=0.9)
        self.k_dataset = pd.read_csv(os.path.join(resources_dir, 'K_Dataset.csv'))

        # Entity labels whose "<LABEL>" placeholder text must not be kept as a keyword.
        # spaCy's label is "WORK_OF_ART"; "WORK" and "ART" are kept for compatibility.
        self._ent_type = ["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "WORK", "ART",
                          "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY",
                          "ORDINAL", "CARDINAL"]

        # NOTE(review): attribute names look swapped relative to the files they load
        # (embeddings_synopsis_list <- plot_embeddings.npy, plot_embedding_list <- synopsis_embeddings.npy).
        # find_similar_series weights embeddings_synopsis_list by 0.75, so confirm which file
        # should carry the larger weight before renaming anything.
        self.embeddings_synopsis_list = np.load(os.path.join(resources_dir, "plot_embeddings.npy"))
        self.plot_embedding_list = np.load(os.path.join(resources_dir, "synopsis_embeddings.npy"))

        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # spacy.load raises OSError when the model package is not installed.
            import subprocess
            import sys
            print("Downloading spaCy NLP model...")
            # Use the running interpreter's pip so the wheel lands in this environment.
            subprocess.check_call([
                sys.executable, "-m", "pip", "install",
                "https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl",
            ])
            self.nlp = spacy.load("en_core_web_sm")

    def extract_keywords(self, text):
        """
        Extract keywords from a given text using the YAKE keyword extraction algorithm.

        Keywords equal to one of the labels in ``self._ent_type`` are dropped. The intent
        is to discard the label text left behind by the "<LABEL>" placeholders that
        ``preprocess_text`` inserts.

        Parameters:
        text (str): Text from which to extract keywords.
        Returns:
        str: The kept keywords, in the order the extractor returns them, joined by spaces.
        """
        excluded = set(self._ent_type)
        extracted = self.kw_extractor.extract_keywords(text)
        # Each extractor result is indexed with [0] as the keyword string.
        return " ".join(item[0] for item in extracted if item[0] not in excluded)

    def preprocess_text(self, text):
        """
        Mask named entities with "<LABEL>" placeholders, then extract keywords.

        Each token that spaCy tags with an entity type other than "MISC" is replaced by
        its own placeholder, at its own position. A multi-token entity therefore
        yields one placeholder per token. All other tokens and the original
        whitespace are kept unchanged.

        Parameters:
        text (str): The text to process.
        Returns:
        str: Keywords extracted (via ``extract_keywords``) from the masked text.
        """
        doc = self.nlp(text)
        pieces = []
        for token in doc:
            label = token.ent_type_
            if label and label != "MISC":
                # Rebuild positionally instead of using str.replace, which would also
                # rewrite unrelated substrings (e.g. "Ann" inside "Annapolis").
                pieces.append(f"<{label}>{token.whitespace_}")
            else:
                pieces.append(token.text_with_ws)
        return self.extract_keywords("".join(pieces))

    def find_similar_series(self, new_synopsis, genres_keywords, k=5):
        """
        Find series similar to a new synopsis.

        Parameters:
        new_synopsis (str): The synopsis to compare.
        genres_keywords (iterable of str | str): Genre words prepended to the query;
            a single string is treated as one keyword.
        k (int): The number of similar series to return.
        Returns:
        list[dict]: Up to ``k`` records with keys "Series", "Genre" and "Score",
            ordered from highest to lowest score. Empty when ``k <= 0``.
        """
        if k <= 0:
            # Without this guard, k == 0 produces the slice [-0:], which selects every row.
            return []
        if isinstance(genres_keywords, str):
            genres_keywords = [genres_keywords]

        processed_synopsis = self.preprocess_text(new_synopsis)
        # processed_synopsis is already keyword output; it is passed through the extractor
        # once more, as the original pipeline did, so scores stay comparable.
        synopsis_keywords = self.extract_keywords(processed_synopsis)
        genre_keywords = " ".join(genres_keywords)
        # Separate with a space so the last genre word does not fuse with the first keyword.
        synopsis_sentence = " ".join(part for part in (genre_keywords, synopsis_keywords) if part)

        synopsis_embedding = self.model.encode([synopsis_sentence])
        scores = (0.75 * cosine_similarity(synopsis_embedding, self.embeddings_synopsis_list)
                  + 0.25 * cosine_similarity(synopsis_embedding, self.plot_embedding_list))[0]
        top_k_indices = np.argsort(scores)[::-1][:k]

        # Copy so adding the Score column does not write into a slice of k_dataset.
        closest_series = self.k_dataset.iloc[top_k_indices].copy()
        closest_series["Score"] = scores[top_k_indices]
        return closest_series[["Series", "Genre", "Score"]].to_dict(orient='records')