|
import datasets |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
import gradio as gr |
|
from gradio.components import Label |
|
|
|
|
|
|
|
|
|
# Load the movie corpus; only its "train" split is used below.
dataset = datasets.load_dataset("SandipPalit/Movie_Dataset")
train_split = dataset["train"]
title = train_split["Title"]        # row i -> title of movie i
overview = train_split["Overview"]  # row i -> text that gets embedded

# One embedding model shared by indexing (here) and querying (at request time).
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Embed every overview and scale each row to unit length (in place). For
# unit vectors, ranking by L2 distance is the same as ranking by cosine
# similarity, which is why a plain IndexFlatL2 is sufficient here.
vectors = model.encode(overview)
faiss.normalize_L2(vectors)

vector_dimension = vectors.shape[1]
index = faiss.IndexFlatL2(vector_dimension)
index.add(vectors)  # FAISS id i corresponds to title[i] / overview[i]
|
|
|
def get_model_generated_vector(text):
    """Embed ``text`` as a one-row, unit-length query batch for ``index``.

    The single embedding from ``model.encode`` is wrapped in a one-element
    list, so the array is 2-D (a batch of one). That is the shape
    ``faiss.normalize_L2`` and ``index.search`` operate on. It is then
    normalised in place, the same way the indexed overview vectors were.

    Args:
        text: The user's free-text query.

    Returns:
        A NumPy array holding one normalised embedding row.
    """
    query_batch = np.array([model.encode(text)])
    faiss.normalize_L2(query_batch)  # mutates query_batch in place
    return query_batch
|
|
|
def find_top_k_matched(vector, k=5):
    """Return the titles of the ``k`` indexed overviews nearest to ``vector``.

    Args:
        vector: Query batch as produced by ``get_model_generated_vector``.
            Only the results for its first row are read.
        k: Number of neighbours to retrieve. Defaults to 5, matching the
            five outputs that ``movie_recommandation`` and the Gradio
            interface expect.

    Returns:
        A list of exactly ``k`` entries, nearest first. Each entry is a movie
        title, or ``None`` for a slot FAISS could not fill.
    """
    _distances, neighbour_ids = index.search(vector, k=k)
    # FAISS pads unfilled result slots with id -1 (e.g. when the index holds
    # fewer than k vectors). Indexing `title` with -1 would silently return
    # the last movie, so those slots become None instead. This keeps the
    # list length at k so callers can still unpack a fixed number of items.
    return [title[int(i)] if i >= 0 else None for i in neighbour_ids[0]]
|
|
|
|
|
def movie_recommandation(text):
    """Gradio handler: turn a free-text query into five recommended titles.

    The query is embedded, searched against the overview index, and the
    five nearest titles are returned as a tuple, one per ``Label`` output.

    Args:
        text: The user's query string from the input textbox.

    Returns:
        A 5-tuple of titles, nearest match first.
    """
    ranked_titles = find_top_k_matched(get_model_generated_vector(text))
    return tuple(ranked_titles[position] for position in range(5))
|
|
|
# Web UI: one text input and five Label outputs, one per title returned
# by movie_recommandation.
demo = gr.Interface(
    fn=movie_recommandation,
    # NOTE(review): queries are matched against plot overviews, not titles
    # (compare the examples below), so this placeholder may mislead users.
    inputs=gr.Textbox(placeholder="Enter the Movie Name"),
    outputs=[Label() for _ in range(5)],
    examples=[["Scarlet Macaw on Perch"], ["horror"]],
)

# Start the server only when this file runs as a script. Importing the
# module (e.g. from tests or the `gradio` CLI's reload mode, which picks up
# `demo` itself) then does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)