|
import os |
|
import json |
|
|
|
import gradio as gr |
|
|
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
# Sentence-embedding model shared by the whole module: constructed once at
# import time and used by get_n_weighted_scores to encode each query.
model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
)
|
|
|
|
|
def get_n_weighted_scores(embeddings, query, n, objective_weight, subjective_weight):
    """Rank entries by weighted similarity to ``query`` and return the top ``n``.

    The query is encoded with the module-level ``model`` and compared, via
    cosine similarity, with each entry's objective embedding and with every
    one of its subjective embeddings.  Each subjective embedding produces a
    candidate score::

        objective_score * objective_weight + subjective_score * subjective_weight

    and the entry is represented by its highest candidate.

    Args:
        embeddings: Mapping of key -> dict containing ``'objective_embedding'``
            and ``'subjective_embeddings'``.  Both values are handed directly
            to ``cosine_similarity``; the objective comparison must yield
            exactly one value (``.item()`` is called on it).  Presumably the
            subjective embeddings hold one row per user review -- confirm
            against the data file.
        query: Text to encode.
        n: Maximum number of results to return.
        objective_weight: Multiplier applied to the objective similarity.
        subjective_weight: Multiplier applied to each subjective similarity.

    Returns:
        A list of at most ``n`` tuples ``(key, best_score, best_index)``,
        sorted by ``best_score`` in descending order.  ``best_index`` is the
        position of the subjective embedding that produced ``best_score``
        (the first one on ties).
    """
    query_embedding = [model.encode(query)]

    ranked = []
    for key, entry in embeddings.items():
        objective_score = cosine_similarity(
            query_embedding, entry['objective_embedding']
        ).item()
        review_scores = cosine_similarity(
            query_embedding, entry['subjective_embeddings']
        )[0].tolist()

        # The objective contribution is identical for every review of an entry.
        objective_part = objective_score * objective_weight

        # Start below any reachable score.  Cosine similarity can be negative,
        # so a floor of 0 would hide the true best score (and always report
        # index 0) for entries whose candidates are all <= 0.
        best_score = float('-inf')
        best_index = 0
        for index, review_score in enumerate(review_scores):
            candidate = objective_part + review_score * subjective_weight
            if candidate > best_score:
                best_score = candidate
                best_index = index

        ranked.append((key, best_score, best_index))

    # list.sort is stable, so entries with equal scores keep mapping order.
    ranked.sort(key=lambda item: item[1], reverse=True)
    return ranked[:n]
|
|
|
def filter_anime(embeddings, genres, themes, rating):
    """Return the entries of ``embeddings`` that pass all three filters.

    A filter passes for an entry when the selection shares at least one value
    with the entry, or when the selection contains the wildcard ``'ALL'``.
    An empty selection therefore matches nothing.

    Args:
        embeddings: Mapping of key -> entry dict with ``'genres'`` and
            ``'themes'`` (iterables) and ``'rating'`` (a single value).
        genres: Selected genres.
        themes: Selected themes.
        rating: Selected ratings.

    Returns:
        A new dict holding the surviving ``key -> entry`` pairs in their
        original order.  The input mapping is left untouched and the entry
        objects themselves are shared, not copied.
    """
    wanted_genres = set(genres)
    wanted_themes = set(themes)
    wanted_ratings = set(rating)

    # The wildcard checks do not depend on the entry, so evaluate them once.
    any_genre = 'ALL' in wanted_genres
    any_theme = 'ALL' in wanted_themes
    any_rating = 'ALL' in wanted_ratings

    kept = {}
    for title, entry in embeddings.items():
        entry_genres = set(entry['genres'])
        entry_themes = set(entry['themes'])
        entry_rating = {entry['rating']}

        genre_ok = any_genre or not wanted_genres.isdisjoint(entry_genres)
        theme_ok = any_theme or not wanted_themes.isdisjoint(entry_themes)
        rating_ok = any_rating or not wanted_ratings.isdisjoint(entry_rating)

        if genre_ok and theme_ok and rating_ok:
            kept[title] = entry

    return kept
|
|
|
def get_recommendation(query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight):
    """Produce the Gradio output components for one recommendation request.

    The module-level ``embeddings`` are narrowed with ``filter_anime``, ranked
    with ``get_n_weighted_scores``, and each hit is rendered as three visible
    components: an image (labelled with the English title, or the Japanese
    one when the English title is empty), the synopsis, and the user review
    that scored highest for the query.

    Args:
        query: Free-text description of what the user is looking for.
        number_of_recommendations: How many results were requested.
        genres: Selected genres (``'ALL'`` disables the filter).
        themes: Selected themes (``'ALL'`` disables the filter).
        rating: Selected ratings (``'ALL'`` disables the filter).
        objective_weight: Weight of the objective similarity; cast to float.
        subjective_weight: Weight of the review similarity; cast to float.

    Returns:
        A flat list of ``gr.Image, gr.Textbox, gr.Textbox`` triples, always
        ten triples long: one visible triple per hit followed by hidden
        triples for the unused slots.
    """
    # Number of (image, synopsis, review) triples this handler must return.
    # Keep in sync with the number of output rows created in the UI.
    total_slots = 10

    candidates = filter_anime(embeddings, genres, themes, rating)
    top_matches = get_n_weighted_scores(
        candidates,
        query,
        number_of_recommendations,
        float(objective_weight),
        float(subjective_weight),
    )

    results = []
    for position, (key, _score, review_index) in enumerate(top_matches, start=1):
        entry = embeddings[key]
        # Fall back to the Japanese title when the English one is empty.
        name = entry['english'] or entry['japanese']

        results.append(gr.Image(label=f"Recommendation {position}: {name}", value=entry['image'], height=435, width=500, visible=True))
        results.append(gr.Textbox(label="Synopsis", value=entry['description'], max_lines=7, visible=True))
        results.append(gr.Textbox(label="Most Relevant User Review", value=entry['reviews'][review_index]['text'], max_lines=7, visible=True))

    # Pad by the number of hits actually produced, not the number requested:
    # the filters may leave fewer matches than requested, and the handler
    # must still return one value per wired output component.
    for _ in range(total_slots - len(top_matches)):
        results.append(gr.Image(visible=False))
        results.append(gr.Textbox(visible=False))
        results.append(gr.Textbox(visible=False))

    return results
|
|
|
if __name__ == '__main__':
    # Load the precomputed embeddings and the filter choices.  JSON text is
    # read as UTF-8 explicitly so decoding does not depend on the platform's
    # default encoding.
    with open('./embeddings/data.json', encoding='utf-8') as f:
        data = json.load(f)

    # `embeddings` is read as a global by get_recommendation.
    embeddings = data['embeddings']
    filters = data['filters']

    # NOTE(review): the layout nesting below was reconstructed from a
    # flattened source -- confirm it against the intended page layout.
    with gr.Blocks(theme=gr.themes.Soft(primary_hue='red')) as demo:
        # Header row: introduction text next to an empty column.
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    '''
                    # Welcome to the Nuanced Recommendation System!
                    ### This system **combines** both objective (synopsis, episode count, themes) and subjective (user reviews) data, in order to recommend the most appropriate anime. Feel free to refine using the **optional** filters below!
                    '''
                )
            with gr.Column():
                pass

        with gr.Row():
            # Input controls; their order matches get_recommendation's
            # parameter order.
            with gr.Column() as input_col:
                query = gr.Textbox(label="What are you looking for?")
                number_of_recommendations = gr.Slider(label= "# of Recommendations", minimum=1, maximum=10, value=3, step=1)
                genres = gr.Dropdown(label='Genres',multiselect=True,choices=filters['genres'], value=['ALL'])
                themes = gr.Dropdown(label='Themes',multiselect=True,choices=filters['themes'], value=['ALL'])
                rating = gr.Dropdown(label='Rating',multiselect=True,choices=filters['rating'], value=['ALL'])
                objective_weight = gr.Slider(label= "Objective Weight", minimum=0, maximum=1, value=.5, step=.1)
                subjective_weight = gr.Slider(label= "Subjective Weight", minimum=0, maximum=1, value=.5, step=.1)
                submit_btn = gr.Button("Submit")

                # Clickable sample requests that pre-fill the inputs above.
                examples = gr.Examples(
                    examples=[
                        ['A sci-fi anime set in a future where AI and robots have become self-aware', 3, ['Action', 'Sci-Fi', 'Fantasy'], ['ALL'], ['PG-13 - Teens 13 or older'], .8, .2],
                        ['An anime where a group of students form a band, and the story focuses on their personal growth and struggles with adulthood', 5, ['ALL'], ['Music'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .3, .7],
                        ['An anime where the main character starts as a villain but slowly redeems themselves', 3, ['Suspense', 'Action'], ['ALL'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .2, .8],
                    ],
                    inputs=[query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
                )

            # Ten hidden result rows, each contributing an image and two text
            # boxes to `outputs` in the order get_recommendation returns them.
            outputs = []
            with gr.Column():
                for _ in range(10):
                    with gr.Row():
                        with gr.Column():
                            outputs.append(gr.Image(height=435, width=500, visible=False))
                        with gr.Column():
                            outputs.append(gr.Textbox(max_lines=7, visible=False))
                            outputs.append(gr.Textbox(max_lines=7, visible=False))

        submit_btn.click(
            get_recommendation,
            [query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
            outputs
        )

    demo.launch()
|
|
|
|
|
|
|
|