EdZ543
commited on
Commit
•
6ef1d8c
1
Parent(s):
5a5774a
Add selections
Browse filesFormer-commit-id: cd3341c2e571b90ecfb34ffea9ab056c1e7f3e45
- app.py +25 -4
- flagged/log.csv +0 -3
app.py
CHANGED
@@ -1,29 +1,48 @@
|
|
|
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
-
import numpy as np
|
5 |
import requests
|
|
|
|
|
6 |
from urllib import parse
|
7 |
from dotenv import load_dotenv
|
8 |
|
9 |
load_dotenv()
|
10 |
|
|
|
|
|
11 |
anime_indexes = pd.read_csv("./data/anime_indexes.csv")
|
12 |
animes = anime_indexes["Anime"].values.tolist()
|
13 |
|
14 |
-
|
|
|
15 |
|
16 |
|
17 |
def fetch_anime_image(anime):
|
18 |
query_url = f"https://api.myanimelist.net/v2/anime?q={parse.quote(anime)}&limit=1"
|
19 |
headers = {"X-MAL-CLIENT-ID": MAL_CLIENT_ID}
|
20 |
query_response = requests.get(query_url, headers=headers)
|
|
|
21 |
image_url = query_response.json()["data"][0]["node"]["main_picture"]["large"]
|
22 |
return image_url
|
23 |
|
24 |
|
25 |
def recommend(anime):
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
css = """
|
@@ -46,7 +65,7 @@ with gr.Blocks(css=css) as space:
|
|
46 |
dropdown = gr.Dropdown(container=False, choices=animes)
|
47 |
selection_image = gr.Image(show_label=False, width=225, visible=False)
|
48 |
|
49 |
-
gallery = gr.Gallery(label="Recommendations")
|
50 |
|
51 |
def submit(anime):
|
52 |
if anime is None:
|
@@ -56,9 +75,11 @@ with gr.Blocks(css=css) as space:
|
|
56 |
}
|
57 |
|
58 |
selection_image_url = fetch_anime_image(anime)
|
|
|
59 |
|
60 |
return {
|
61 |
selection_image: gr.update(visible=True, value=selection_image_url),
|
|
|
62 |
}
|
63 |
|
64 |
dropdown.change(fn=submit, inputs=dropdown, outputs=[selection_image, gallery])
|
|
|
1 |
+
"""The main application file for the Gradio app."""
|
2 |
+
|
3 |
import os
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
|
|
6 |
import requests
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
from urllib import parse
|
10 |
from dotenv import load_dotenv
|
11 |
|
12 |
load_dotenv()
|
13 |
|
14 |
+
MAL_CLIENT_ID = os.getenv("MAL_CLIENT_ID")
|
15 |
+
|
16 |
anime_indexes = pd.read_csv("./data/anime_indexes.csv")
|
17 |
animes = anime_indexes["Anime"].values.tolist()
|
18 |
|
19 |
+
anime_embeddings = pd.read_csv("./data/anime_embeddings.csv", header=None)
|
20 |
+
anime_embeddings = torch.tensor(anime_embeddings.values)
|
21 |
|
22 |
|
23 |
def fetch_anime_image(anime):
|
24 |
query_url = f"https://api.myanimelist.net/v2/anime?q={parse.quote(anime)}&limit=1"
|
25 |
headers = {"X-MAL-CLIENT-ID": MAL_CLIENT_ID}
|
26 |
query_response = requests.get(query_url, headers=headers)
|
27 |
+
|
28 |
image_url = query_response.json()["data"][0]["node"]["main_picture"]["large"]
|
29 |
return image_url
|
30 |
|
31 |
|
32 |
def recommend(anime):
|
33 |
+
anime_index = anime_indexes[anime_indexes["Anime"] == anime].index[0]
|
34 |
+
anime_embedding = anime_embeddings[anime_index][None]
|
35 |
+
|
36 |
+
embedding_distances = nn.CosineSimilarity(dim=1)(anime_embeddings, anime_embedding)
|
37 |
+
recommendation_indexes = embedding_distances.argsort(descending=True)[1:6].tolist()
|
38 |
+
|
39 |
+
recommendations = []
|
40 |
+
for recommendation_index in recommendation_indexes:
|
41 |
+
recommendation_anime = anime_indexes.iloc[recommendation_index]["Anime"]
|
42 |
+
recommendation_url = fetch_anime_image(recommendation_anime)
|
43 |
+
recommendations.append((recommendation_url, recommendation_anime))
|
44 |
+
|
45 |
+
return recommendations
|
46 |
|
47 |
|
48 |
css = """
|
|
|
65 |
dropdown = gr.Dropdown(container=False, choices=animes)
|
66 |
selection_image = gr.Image(show_label=False, width=225, visible=False)
|
67 |
|
68 |
+
gallery = gr.Gallery(label="Recommendations", object_fit="scale-down")
|
69 |
|
70 |
def submit(anime):
|
71 |
if anime is None:
|
|
|
75 |
}
|
76 |
|
77 |
selection_image_url = fetch_anime_image(anime)
|
78 |
+
recommendations = recommend(anime)
|
79 |
|
80 |
return {
|
81 |
selection_image: gr.update(visible=True, value=selection_image_url),
|
82 |
+
gallery: gr.update(value=recommendations),
|
83 |
}
|
84 |
|
85 |
dropdown.change(fn=submit, inputs=dropdown, outputs=[selection_image, gallery])
|
flagged/log.csv
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c58b7e483af0b78bf9bfb21459a1f1eccfba65651d24ba7f236531c195b352f3
|
3 |
-
size 137
|
|
|
|
|
|
|
|