# picardle / app.py
# m-butler's picture
# Upload folder using huggingface_hub
# f266d24 verified
import json
import numpy as np
from wordllama import WordLlama
import gradio as gr
from numpy.linalg import norm
import os
# Load episodes data. JSON is read as UTF-8 explicitly so the result does not
# depend on the platform's default locale encoding.
with open('processed_episodes.json', 'r', encoding='utf-8') as f:
    episodes = json.load(f)

# Load the WordLlama model with 256 dimensions
wl = WordLlama.load(trunc_dim=256)

# Reuse cached summary embeddings when they are present AND still line up with
# the episode list: row i of the matrix is looked up as episodes[i] later on,
# so a cache built from a different number of episodes must be discarded.
summary_embeddings = None
if os.path.exists('summary_embeddings.npy'):
    _cached_embeddings = np.load('summary_embeddings.npy')
    if _cached_embeddings.ndim == 2 and _cached_embeddings.shape[0] == len(episodes):
        summary_embeddings = _cached_embeddings

if summary_embeddings is None:
    # Compute embeddings for all summaries and refresh the on-disk cache
    summaries = [episode['summary'] for episode in episodes]
    summary_embeddings = np.array(wl.embed(summaries))
    np.save('summary_embeddings.npy', summary_embeddings)
# Define the function to find matching episodes
def find_matching_episodes(query, top_k=5):
    """Return the episodes whose summaries best match *query*.

    The query is embedded with the module-level WordLlama model and compared
    against the precomputed summary embeddings using cosine similarity.

    Args:
        query: Free-text search string.
        top_k: Maximum number of rows to return. UI sliders may deliver this
            as a float, so it is truncated to an integer. Values below 1
            produce an empty result.

    Returns:
        A list of ``[episode_number, title, score]`` rows ordered from most
        to least similar; ``score`` is the cosine similarity formatted as a
        string with four decimal places.
    """
    # Slice bounds must be integers; a float (e.g. 5.0 or 3.47 from a slider)
    # would raise TypeError in the slicing below.
    top_k = int(top_k)
    if top_k < 1:
        # ``[-0:]`` would select *every* row rather than none, so handle the
        # "no results requested" case explicitly.
        return []

    # Compute the embedding for the query (first and only row of the batch)
    query_embedding = wl.embed([query])[0]

    # Normalize embeddings; the small epsilon guards against division by zero
    # for all-zero vectors.
    query_norm = query_embedding / (norm(query_embedding) + 1e-10)
    summaries_norm = summary_embeddings / (
        norm(summary_embeddings, axis=1, keepdims=True) + 1e-10
    )

    # Compute cosine similarities (one score per summary)
    similarities = summaries_norm @ query_norm

    # argsort is ascending, so take the last top_k indices and reverse them
    # to get the best match first.
    top_k_indices = np.argsort(similarities)[-top_k:][::-1]

    # Each row is a list (not a dict) so it maps directly onto table columns
    return [
        [
            episodes[idx]['episode_number'],
            episodes[idx]['title'],
            f"{similarities[idx]:.4f}",
        ]
        for idx in top_k_indices
    ]
# Create the Gradio interface: a text query plus a result-count slider in,
# a three-column table of matching episodes out.
interface = gr.Interface(
    fn=find_matching_episodes,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="Search Query"),
        # step=1 keeps the slider on whole numbers; without an explicit step
        # the control may emit fractional values, which are not a valid
        # result count.
        gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results"),
    ],
    outputs=gr.Dataframe(
        headers=["Episode Number", "Title", "Similarity Score"],
        label="Matching Episodes",
    ),
    title="Picardle",
    description="Enter a query to find matching ST:TNG episodes based on their summaries.",
)
# Launch the app
def main():
    """Start the Gradio server for ``interface`` with sharing enabled."""
    interface.launch(share=True)


# Only start the server when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()