EdZ543 commited on
Commit
6ef1d8c
1 Parent(s): 5a5774a
Files changed (2) hide show
  1. app.py +25 -4
  2. flagged/log.csv +0 -3
app.py CHANGED
@@ -1,29 +1,48 @@
 
 
1
  import os
2
  import gradio as gr
3
  import pandas as pd
4
- import numpy as np
5
  import requests
 
 
6
  from urllib import parse
7
  from dotenv import load_dotenv
8
 
9
  load_dotenv()
10
 
 
 
11
  anime_indexes = pd.read_csv("./data/anime_indexes.csv")
12
  animes = anime_indexes["Anime"].values.tolist()
13
 
14
- MAL_CLIENT_ID = os.getenv("MAL_CLIENT_ID")
 
15
 
16
 
17
  def fetch_anime_image(anime):
18
  query_url = f"https://api.myanimelist.net/v2/anime?q={parse.quote(anime)}&limit=1"
19
  headers = {"X-MAL-CLIENT-ID": MAL_CLIENT_ID}
20
  query_response = requests.get(query_url, headers=headers)
 
21
  image_url = query_response.json()["data"][0]["node"]["main_picture"]["large"]
22
  return image_url
23
 
24
 
25
  def recommend(anime):
26
- return None
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  css = """
@@ -46,7 +65,7 @@ with gr.Blocks(css=css) as space:
46
  dropdown = gr.Dropdown(container=False, choices=animes)
47
  selection_image = gr.Image(show_label=False, width=225, visible=False)
48
 
49
- gallery = gr.Gallery(label="Recommendations")
50
 
51
  def submit(anime):
52
  if anime is None:
@@ -56,9 +75,11 @@ with gr.Blocks(css=css) as space:
56
  }
57
 
58
  selection_image_url = fetch_anime_image(anime)
 
59
 
60
  return {
61
  selection_image: gr.update(visible=True, value=selection_image_url),
 
62
  }
63
 
64
  dropdown.change(fn=submit, inputs=dropdown, outputs=[selection_image, gallery])
 
1
+ """The main application file for the Gradio app."""
2
+
3
  import os
4
  import gradio as gr
5
  import pandas as pd
 
6
  import requests
7
+ import torch
8
+ import torch.nn as nn
9
  from urllib import parse
10
  from dotenv import load_dotenv
11
 
12
  load_dotenv()
13
 
14
+ MAL_CLIENT_ID = os.getenv("MAL_CLIENT_ID")
15
+
16
  anime_indexes = pd.read_csv("./data/anime_indexes.csv")
17
  animes = anime_indexes["Anime"].values.tolist()
18
 
19
+ anime_embeddings = pd.read_csv("./data/anime_embeddings.csv", header=None)
20
+ anime_embeddings = torch.tensor(anime_embeddings.values)
21
 
22
 
23
  def fetch_anime_image(anime):
24
  query_url = f"https://api.myanimelist.net/v2/anime?q={parse.quote(anime)}&limit=1"
25
  headers = {"X-MAL-CLIENT-ID": MAL_CLIENT_ID}
26
  query_response = requests.get(query_url, headers=headers)
27
+
28
  image_url = query_response.json()["data"][0]["node"]["main_picture"]["large"]
29
  return image_url
30
 
31
 
32
  def recommend(anime):
33
+ anime_index = anime_indexes[anime_indexes["Anime"] == anime].index[0]
34
+ anime_embedding = anime_embeddings[anime_index][None]
35
+
36
+ embedding_distances = nn.CosineSimilarity(dim=1)(anime_embeddings, anime_embedding)
37
+ recommendation_indexes = embedding_distances.argsort(descending=True)[1:6].tolist()
38
+
39
+ recommendations = []
40
+ for recommendation_index in recommendation_indexes:
41
+ recommendation_anime = anime_indexes.iloc[recommendation_index]["Anime"]
42
+ recommendation_url = fetch_anime_image(recommendation_anime)
43
+ recommendations.append((recommendation_url, recommendation_anime))
44
+
45
+ return recommendations
46
 
47
 
48
  css = """
 
65
  dropdown = gr.Dropdown(container=False, choices=animes)
66
  selection_image = gr.Image(show_label=False, width=225, visible=False)
67
 
68
+ gallery = gr.Gallery(label="Recommendations", object_fit="scale-down")
69
 
70
  def submit(anime):
71
  if anime is None:
 
75
  }
76
 
77
  selection_image_url = fetch_anime_image(anime)
78
+ recommendations = recommend(anime)
79
 
80
  return {
81
  selection_image: gr.update(visible=True, value=selection_image_url),
82
+ gallery: gr.update(value=recommendations),
83
  }
84
 
85
  dropdown.change(fn=submit, inputs=dropdown, outputs=[selection_image, gallery])
flagged/log.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c58b7e483af0b78bf9bfb21459a1f1eccfba65651d24ba7f236531c195b352f3
3
- size 137