EdZ543 commited on
Commit
69d76de
1 Parent(s): 2c794be

Optimize recommendation calculation

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +30 -37
  3. data/{anime_indexes.csv → animes.csv} +2 -2
.gitignore CHANGED
@@ -158,3 +158,6 @@ cython_debug/
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
 
 
 
 
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
+
162
+ # Mac
163
+ *.DS_Store
app.py CHANGED
@@ -1,59 +1,52 @@
1
  """The main application file for the Gradio app."""
2
 
3
- import os
4
  import gradio as gr
5
  import pandas as pd
6
- import requests
7
  import torch
8
- import torch.nn as nn
9
 
10
- MAL_CLIENT_ID = os.getenv("MAL_CLIENT_ID")
 
11
 
12
- anime_indexes = pd.read_csv("./data/anime_indexes.csv")
13
- animes = anime_indexes["Anime"].values.tolist()
14
 
15
- anime_embeddings = pd.read_csv("./data/anime_embeddings.csv", header=None)
16
- anime_embeddings = torch.tensor(anime_embeddings.values)
17
 
 
 
18
 
19
- def fetch_anime_image_url(anime_id):
20
- url = f"https://api.myanimelist.net/v2/anime/{anime_id}?fields=main_picture"
21
- headers = {"X-MAL-CLIENT-ID": MAL_CLIENT_ID}
22
- response = requests.get(url, headers=headers)
23
- image_url = response["main_picture"]["large"]
24
- return image_url
25
 
 
 
 
 
 
 
26
 
27
- def recommend(anime):
28
- anime_index = anime_indexes[anime_indexes["Anime"] == anime].index[0]
29
- anime_embedding = anime_embeddings[anime_index][None]
30
 
31
- embedding_distances = nn.CosineSimilarity(dim=1)(anime_embeddings, anime_embedding)
32
- recommendation_indexes = embedding_distances.argsort(descending=True)[1:7].tolist()
33
- recommendations = [
34
- (
35
- "https://cdn.myanimelist.net/images/anime/1600/134703.jpg",
36
- anime_indexes.iloc[index]["Anime"],
37
- )
38
- for index in recommendation_indexes
39
- ]
40
 
41
- return recommendations
 
 
 
42
 
43
 
44
- with gr.Blocks() as space:
45
- gr.Markdown(
 
 
 
 
 
46
  """
47
- # Anime Collaborative Filtering System
48
- This is a Pytorch recommendation model that uses neural collaborative filtering.
49
- Enter an anime, and it will suggest similar shows!
50
- """
51
- )
52
 
53
- dropdown = gr.Dropdown(label="Enter an anime", choices=animes)
54
 
55
- gallery = gr.Gallery(label="Recommendations", rows=2, columns=3)
56
 
57
- dropdown.change(fn=recommend, inputs=dropdown, outputs=gallery)
58
 
59
  space.launch()
 
1
  """The main application file for the Gradio app."""
2
 
 
3
  import gradio as gr
4
  import pandas as pd
 
5
  import torch
 
6
 
7
+ animes_df = pd.read_csv("./data/animes.csv")
8
+ anime_embeddings_df = pd.read_csv("./data/anime_embeddings.csv", header=None)
9
 
10
+ title_list = animes_df["Title"].tolist()
11
+ embeddings = torch.tensor(anime_embeddings_df.values)
12
 
 
 
13
 
14
+ def recommend(index):
15
+ embedding = embeddings[index]
16
 
17
+ embedding_distances = torch.nn.CosineSimilarity(dim=1)(embeddings, embedding)
18
+ recommendation_indexes = embedding_distances.argsort(descending=True)[1:4]
 
 
 
 
19
 
20
+ recommendations = []
21
+ for rank, recommendation_index in enumerate(recommendation_indexes):
22
+ recommendation = animes_df.iloc[int(recommendation_index)]
23
+ value = recommendation["Image URL"]
24
+ label = f'{rank + 1}. {recommendation["Title"]}'
25
+ recommendations.append((value, label))
26
 
27
+ return recommendations
 
 
28
 
 
 
 
 
 
 
 
 
 
29
 
30
+ css = """
31
+ .gradio-container {align-items: center}
32
+ #container {max-width: 800px}
33
+ """
34
 
35
 
36
+ with gr.Blocks(css=css) as space:
37
+ with gr.Column(elem_id="container"):
38
+ gr.Markdown(
39
+ """
40
+ # Anime Collaborative Filtering System
41
+ This is a Pytorch recommendation model that uses neural collaborative filtering.
42
+ Enter an anime, and it will suggest similar shows!
43
  """
44
+ )
 
 
 
 
45
 
46
+ dropdown = gr.Dropdown(label="Enter an anime", choices=title_list, type="index")
47
 
48
+ gallery = gr.Gallery(label="Recommendations", rows=1, columns=3)
49
 
50
+ dropdown.change(fn=recommend, inputs=dropdown, outputs=gallery)
51
 
52
  space.launch()
data/{anime_indexes.csv → animes.csv} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:58f89061083a803c723125c3db497ee672f8236f7525aff943b637e0654bd463
3
- size 257312
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dcc6118380dc53f88732a9811bb537357604ccacdf04aa409449fa9d5b9c45ee
3
+ size 646379