File size: 1,724 Bytes
6ef1d8c
 
b2200f4
d6eaf98
6ef1d8c
b2200f4
69d76de
 
6ef1d8c
69d76de
 
b2200f4
 
69d76de
 
d6eaf98
69d76de
 
d6eaf98
69d76de
 
 
 
 
 
d6eaf98
69d76de
6ef1d8c
 
69d76de
 
e7139ac
69d76de
d6eaf98
 
69d76de
 
 
 
 
 
d17c670
 
d6eaf98
69d76de
d6eaf98
69d76de
d6eaf98
e7139ac
d6eaf98
69d76de
d6eaf98
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""The main application file for the Gradio app."""

import gradio as gr
import pandas as pd
import torch

# Anime metadata table; this app reads its "Title" and "Image URL" columns.
animes_df = pd.read_csv("./data/animes.csv")
title_list = animes_df["Title"].to_list()

# The embeddings file is read with header=None, so every line is a data row.
anime_embeddings_df = pd.read_csv("./data/anime_embeddings.csv", header=None)
embeddings = torch.tensor(anime_embeddings_df.to_numpy())


def recommend(index):
    """Return gallery entries for the three anime most similar to the chosen one.

    Similarity is the cosine similarity between the chosen anime's row in
    ``embeddings`` and every row of ``embeddings``.

    Args:
        index: Row position of the chosen anime. The same position is used to
            index both ``embeddings`` and ``animes_df``. ``None`` means that
            nothing is selected.

    Returns:
        A list of up to three ``(image_url, label)`` tuples ordered from most
        to least similar, where ``label`` is ``"<rank>. <title>"``. The list
        is empty when ``index`` is ``None``.
    """
    if index is None:
        # An index-typed dropdown can report None when no choice is selected.
        # Without this guard ``embeddings[None]`` would be evaluated, which
        # adds a dimension instead of selecting a row.
        return []

    query = embeddings[index]
    similarities = torch.nn.CosineSimilarity(dim=1)(embeddings, query)
    ranked = similarities.argsort(descending=True)

    # Drop the query by position instead of assuming it sorts first: another
    # row with an identical (or parallel) embedding can tie with it, and then
    # skipping rank 0 could discard a real match and keep the query itself.
    ranked = ranked[ranked != index][:3]

    recommendations = []
    for rank, row in enumerate(ranked.tolist(), start=1):
        anime = animes_df.iloc[row]
        recommendations.append((anime["Image URL"], f'{rank}. {anime["Title"]}'))

    return recommendations


# Custom CSS passed to gr.Blocks below: sets ``align-items: center`` on
# ``.gradio-container`` and caps the element with id "container" (the main
# gr.Column) at a width of 795px.
css = """
.gradio-container {align-items: center}
#container {max-width: 795px}
"""


# Page layout: one column holding the intro text, the anime picker and the
# recommendations gallery, in that order.
with gr.Blocks(css=css) as space:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            value="""
        # Anime Collaborative Filtering System
        This is a Pytorch recommendation model that uses neural collaborative filtering.
        Enter an anime, and it will suggest similar shows! \
        Source code: [https://github.com/EdZ543/anime-collaborative-filtering-system](https://github.com/EdZ543/anime-collaborative-filtering-system)
        """
        )

        # type="index" makes the dropdown pass the position of the chosen
        # title (not the title text) to the callback.
        dropdown = gr.Dropdown(
            choices=title_list,
            label="Enter an anime",
            type="index",
        )

        gallery = gr.Gallery(
            label="Recommendations",
            columns=3,
            rows=1,
            height="265",
        )

        # Refresh the gallery every time the selection changes.
        dropdown.change(recommend, inputs=dropdown, outputs=gallery)

# Start the app as soon as this module is executed.
space.launch()