Spaces:
Runtime error
Runtime error
File size: 5,761 Bytes
fd6baf3 f416fa7 0c2802d afb269c fd6baf3 c146eca 4e929c8 fd6baf3 5979534 a602488 5979534 713fc75 efd2261 e11db98 713fc75 fd6baf3 a7f90e5 fd6baf3 a7f90e5 fd6baf3 a602488 fd6baf3 713fc75 fd6baf3 969df28 fd6baf3 9bf9a47 0a44459 fd6baf3 742e2ed c146eca fd6baf3 742e2ed c146eca fd6baf3 a0e6621 fd6baf3 a7f90e5 fd6baf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import functools

import gradio as gr
import pandas as pd
import torch
import whisper
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util, CrossEncoder
# Columns of the Netflix catalogue that are kept for the result tables.
_NETFLIX_COLUMNS = [
    'type', 'title', 'country', 'description', 'release_year',
    'rating', 'duration', 'listed_in', 'cast',
]

# Netflix catalogue -> pandas, restricted to the columns above.
netflix = load_dataset('hugginglearners/netflix-shows', use_auth_token=True)
netflix_df = netflix['train'].to_pandas()[_NETFLIX_COLUMNS]
# Description texts; list position i corresponds to row position i of netflix_df.
passages = netflix_df['description'].tolist()

# Bi-encoder used to embed incoming queries.
model = SentenceTransformer('all-mpnet-base-v2')

# Precomputed description embeddings, converted to a float32 tensor.
flix_ds = load_dataset("nickmuchi/netflix-shows-mpnet-embeddings", use_auth_token=True)
dataset_embeddings = (
    torch.from_numpy(flix_ds["train"].to_pandas().to_numpy())
    .to(torch.float)
)

# Cross-encoder that re-scores (query, description) pairs for re-ranking.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
def display_df_as_table(model, top_k, score='score', data=None):
    """Build a results table from a ranked list of search hits.

    Args:
        model: ranked list of hit dicts (despite the name, not a model), as
            produced by ``util.semantic_search``. Each hit has a
            ``'corpus_id'`` key (row position in the catalogue) and a score
            stored under the key named by *score*.
        top_k: number of leading hits to include.
        score: key of the score to display, e.g. ``'score'`` for the
            bi-encoder or ``'cross-score'`` for the cross-encoder.
        data: catalogue DataFrame whose row positions match ``corpus_id``.
            Defaults to the module-level ``netflix_df``.

    Returns:
        A DataFrame with a ``Score`` column (rounded to 2 decimals) followed
        by the catalogue columns. There is exactly one row per hit, in hit order.
    """
    if data is None:
        data = netflix_df
    top_hits = model[:top_k]
    # Select rows by position instead of joining on the description text.
    # A text join duplicates rows whenever several titles share a description.
    row_ids = [hit['corpus_id'] for hit in top_hits]
    df = data.iloc[row_ids].reset_index(drop=True)
    scores = pd.Series([hit[score] for hit in top_hits], dtype=float).round(2)
    df.insert(0, 'Score', scores.to_numpy())
    return df
@functools.lru_cache(maxsize=1)
def _load_asr_model(name="small"):
    """Load the Whisper checkpoint *name* once and reuse it on later calls."""
    return whisper.load_model(name)


def asr(audio):
    """Transcribe a recorded audio file to text with Whisper.

    Args:
        audio: path to the audio file, or None when there is no recording.

    Returns:
        The transcribed text, or an empty string if *audio* is None.
    """
    if audio is None:
        # Gradio may pass None, e.g. when the recording is cleared.
        return ""
    # Loading the model is expensive, so it is cached across calls.
    results = _load_asr_model().transcribe(audio)
    return results['text']
def semantic_search(query, top_k):
    '''Retrieve the shows most similar to *query*, then re-rank them.

    The query is embedded with the bi-encoder and compared against the
    precomputed description embeddings. The top_k hits are then re-scored
    with the cross-encoder.

    Args:
        query: free-text description of the desired show.
        top_k: number of suggestions to return. It is coerced to int because
            UI widgets may deliver numeric values as floats.

    Returns:
        (bi_df, cross_df): result tables ordered by bi-encoder score and by
        cross-encoder score respectively. Both tables are empty for a blank
        query.
    '''
    top_k = int(top_k)
    if not query or not query.strip():
        # A blank query would otherwise produce arbitrary "matches".
        return (display_df_as_table([], top_k),
                display_df_as_table([], top_k, 'cross-score'))

    question_embedding = model.encode(query, convert_to_tensor=True).cpu()
    # util.semantic_search returns one hit list per query; there is one query.
    hits = util.semantic_search(question_embedding, dataset_embeddings, top_k=top_k)[0]

    ##### Re-Ranking #####
    # Score every retrieved (query, description) pair with the cross-encoder.
    cross_inp = [[query, netflix_df['description'].iloc[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Bi-encoder table.
    bi_hits = sorted(hits, key=lambda h: h['score'], reverse=True)
    bi_df = display_df_as_table(bi_hits, top_k)
    # Cross-encoder table. display_df_as_table already rounds the scores.
    cross_hits = sorted(hits, key=lambda h: h['cross-score'], reverse=True)
    cross_df = display_df_as_table(cross_hits, top_k, 'cross-score')
    return bi_df, cross_df
# Static text and styling shown in the Gradio page.
title = """<h1 id="title">Voice Activated Netflix Shows Semantic Search</h1>"""
description = """
Semantic Search is a way to generate search results based on the actual meaning of the query instead of a standard keyword search. I believe this way of searching provides more meaningful results when trying to find a good show to watch on Netflix. For example, one could search for "Success, rags to riches" as provided in the example below to generate shows or movies with a description that is semantically similar to the query.
- The App generates embeddings using [All-Mpnet-Base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model from Sentence Transformers.
- The model encodes the query and the description field from the [Netflix-Shows](https://huggingface.co/datasets/hugginglearners/netflix-shows) dataset which contains 8800 shows and movies currently on Netflix scraped from the web using Selenium.
- Similarity scores are then generated, from highest to lowest. The user can select how many suggestions they need from the results.
- A Cross Encoder then re-ranks the top selections to further improve on the similarity scores.
- You will see 2 tables generated, one from the bi-encoder and the other from the cross encoder which further enhances the similarity score rankings.

Enjoy and Search like you mean it!!
"""
# Example queries offered under the transcription textbox.
example_queries = ["Success, rags to riches","murder, crime scene investigation thriller"]
twitter_link = """
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""
css = '''
h1#title {
text-align: center;
}
'''
# Column headers for both result tables: the score column followed by the
# catalogue columns, in the order display_df_as_table emits them.
RESULT_HEADERS = ['Similarity Score','Type','Title','Country','Description','Release Year','Rating','Duration','Category Listing','Cast']

demo = gr.Blocks(css=css)
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(twitter_link)
    top_k = gr.Slider(minimum=3,maximum=10,value=5,step=1,label='Number of Suggestions to Generate')
    with gr.Row():
        audio = gr.Audio(source='microphone',type='filepath',label='Audio Input: Describe the Netflix show you would like to watch..')
    with gr.Row():
        query = gr.Textbox(label='Transcribed Text')
    # Transcribe whenever the recording changes and put the text in the query box.
    audio.change(asr,audio,query)
    # top_k is a Slider component, not a number, so it cannot be interpolated
    # into static labels; the labels say "Top-k" instead.
    with gr.Row():
        bi_output = gr.DataFrame(headers=RESULT_HEADERS,
                                 label='Top-k Bi-Encoder Retrieval hits', wrap=True)
    with gr.Row():
        cross_output = gr.DataFrame(headers=RESULT_HEADERS,
                                    label='Top-k Cross-Encoder Re-ranker hits', wrap=True)
    with gr.Row():
        examples = gr.Examples(examples=example_queries,inputs=[query])
        sem_but = gr.Button('Search')
    sem_but.click(semantic_search,inputs=[query,top_k],outputs=[bi_output,cross_output],queue=True)
    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-netflix-shows-semantic-search)")

demo.launch(debug=True,enable_queue=True)