Bokanovskii committed on
Commit
952a8f6
·
1 Parent(s): 01b4f11

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +364 -0
  2. style.css +42 -0
app.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import base64
import os
import secrets
from io import BytesIO

import gradio as gr
import spotipy
from spotipy import oauth2

from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from torch.nn import functional as F
from torchvision.io import read_image

import tensorflow as tf

from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import HTMLResponse, RedirectResponse
from starlette.requests import Request
import uvicorn
from fastapi.responses import HTMLResponse
from fastapi.responses import RedirectResponse

import numpy as np
from PIL import Image

import shred_model
27
+
# Xception fine tuned from pretrained imagenet weights for identifying Sraddha
SRADDHA_MODEL_PATH = "shred_model"  # relative path handed to Keras' load_model below
SHRED_MODEL = tf.keras.models.load_model(SRADDHA_MODEL_PATH)  # loaded once, at import time

# NOTE(review): nothing visible in this file assigns this global; the homepage
# route stores the Spotify token in the per-user session instead. Confirm it is
# unused elsewhere before removing it.
SPOTIPY_TOKEN = None
33
+
def main(img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Gradio click handler: turn an uploaded image into a Spotify playlist.

    Args:
        img: File path of the uploaded image, or ``None`` when nothing was uploaded.
        playlist_length: Requested number of tracks.
        privacy: "Public" or "Private".
        gen_mode: Recommendation base selected in the UI.
        genre_choice: Genre selected in the UI (only relevant in genre mode).
        request: Gradio request, forwarded so the Spotify token can be read
            from the session.

    Returns:
        A pair of ``gr.update`` objects, one for the playlist HTML output and
        one for the "Sraddha" message box.
    """
    if img is None:
        # The handler is wired to two outputs, so a bare ``None`` is not a
        # valid return value; report the problem in the playlist box instead.
        return (
            gr.update(value="Please upload an image first", visible=True),
            gr.update(visible=False),
        )

    mood_dict = get_image_mood_dict_from_transformer(img)
    sraddha_found = get_sraddha(img)

    playlist = get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request)
    if playlist is None:
        # get_playlist returns None when the session carries no Spotify token
        playlist = "Spotipy account token not set"

    playlist_update = gr.update(value=playlist, visible=True)
    if not sraddha_found:
        return playlist_update, gr.update(visible=False)

    sraddha_msg = """Sraddha, you are the love of my life and seeing you always lifts my spirits. Hopefully these tunes can do the same for you.
    <p>
    </p>
    - With Love, Scoob"""
    return playlist_update, gr.update(value=sraddha_msg, visible=True)
53
+
def get_image_mood_dict_from_transformer(img):
    """Classify the emotion shown in an image with a pretrained ViT model.

    Args:
        img: Path of the image file to classify.

    Returns:
        dict mapping each label in the model's ``id2label`` config to its
        softmax probability (as a Python float).
    """
    # Function-scope import keeps this block self-contained; torchvision.io is
    # already a dependency of this file.
    from torchvision.io import ImageReadMode

    model_name = "jayanta/google-vit-base-patch16-224-cartoon-emotion-detection"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Force three channels so RGBA or grayscale uploads reach the processor
    # in the same layout as ordinary RGB images.
    img = read_image(img, mode=ImageReadMode.RGB)

    # NOTE(review): model and processor are re-loaded on every call; consider
    # caching them at module level if request latency matters.
    model = ViTForImageClassification.from_pretrained(model_name)
    model.eval()
    model.to(device)

    feature_extractor = ViTImageProcessor.from_pretrained(model_name)
    encoding = feature_extractor(images=img, return_tensors="pt")
    pixel_values = encoding['pixel_values'].to(device)

    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        logits = model(pixel_values).logits

    # .cpu() is required before .numpy() when the model ran on a CUDA device
    probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Pair probabilities with labels by class index rather than dict order
    labels = [label for _, label in sorted(model.config.id2label.items())]
    return {label: float(prob) for label, prob in zip(labels, probabilities)}
71
+
def get_sraddha(img):
    """Report whether the fine-tuned classifier recognises Sraddha in the image.

    Args:
        img: Image path, passed unchanged to ``shred_model.prepare_image``.

    Returns:
        ``True`` when the model's score is at least 0.5, otherwise ``False``
        (previously the function fell off the end and returned ``None``).
    """
    fixed_img = shred_model.prepare_image(img)
    prob = SHRED_MODEL.predict(fixed_img)[0]
    # assumes the model emits a single score per image — TODO confirm;
    # ravel()[0] reads that score explicitly instead of relying on the
    # truthiness of a one-element array.
    return bool(np.ravel(prob)[0] >= .5)
77
+
def compute_mood(mood_dict):
    """Collapse per-emotion probabilities into a single mood score.

    The score is ``happy + 0.5 * angry + 0.1 * sad``. Labels missing from
    ``mood_dict`` count as 0 instead of raising ``KeyError``. The debug
    ``print`` of the input that used to run on every call has been removed.

    Args:
        mood_dict: Mapping of emotion label to probability.

    Returns:
        The weighted mood score.
    """
    happy = mood_dict.get('happy', 0.0)
    angry = mood_dict.get('angry', 0.0)
    sad = mood_dict.get('sad', 0.0)
    return happy + angry * .5 + sad * .1
81
+
def get_playlist(mood_dict, img, playlist_length, privacy, gen_mode, genre_choice, request: gr.Request):
    """Build a Spotify playlist for the logged-in user that matches the image mood.

    Args:
        mood_dict: Mapping of emotion label to probability, as consumed by
            ``compute_mood``.
        img: Path of the uploaded image; forwarded to ``create_playlist``.
        playlist_length: Maximum number of tracks to put in the playlist.
        privacy: "Public" or "Private"; forwarded to ``create_playlist``.
        gen_mode: "Recently Played", "By a Chosen Genre", or anything else for
            the favourite-artists mode.
        genre_choice: Key into the genre map; only read in genre mode.
        request: Gradio request whose underlying session holds the Spotify
            access token under ``'token'``.

    Returns:
        The value returned by ``create_playlist``, or ``None`` when the
        session carries no Spotify token.
    """
    token = request.request.session.get('token')
    if not token:
        return None

    genre_map = {'Rock': ['alt-rock', 'alternative', 'indie', 'r-n-b', 'rock'], 'Hip-hop': ['hip-hop'], 'Party': ['club', 'dance', 'house', 'pop', 'party'], 'Mellow': ['blues', 'chill', 'classical', 'jazz', 'happy'], 'Indian': ['idm', 'indian'], 'Pop': ['pop', 'new-age'], 'Study': ['study', 'classical', 'jazz', 'happy', 'chill'], 'Romance': ['romance', 'happy', 'pop']}

    # Computed once up front: every branch below, including the fallback,
    # needs it (the fallback previously hit a NameError outside genre mode).
    mood = float(compute_mood(mood_dict))
    playlist_name = "Mood " + str(round(mood * 100, 1))
    playlist_length = int(playlist_length)
    sp = spotipy.Spotify(token)
    by_genre = gen_mode == 'By a Chosen Genre'

    top_tracks_uri = []
    if by_genre:
        genres = genre_map[genre_choice]
        final_track_list = _recommended_uris(
            sp, seed_genres=genres, limit=playlist_length,
            max_valence=_clamp_unit(mood + .15), min_valence=_clamp_unit(mood - .15),
            min_danceability=_clamp_unit(mood / 1.75), max_danceability=_clamp_unit(mood * 8),
            min_energy=_clamp_unit(mood / 2))
    else:
        if gen_mode == 'Recently Played':
            top_tracks_uri = _recent_and_saved_track_uris(sp)
        else:
            top_artists_uri = aggregate_favorite_artists(sp)
            top_tracks_uri = aggregate_top_tracks(sp, top_artists_uri)
        final_track_list = filter_tracks(sp, top_tracks_uri, mood_dict, playlist_length)

    # If no tracks fit the filter: generate some results anyways
    if not final_track_list:
        relaxed = dict(limit=playlist_length,
                       min_valence=_clamp_unit(mood - .3),
                       min_energy=_clamp_unit(mood / 3))
        if by_genre:
            final_track_list = _recommended_uris(sp, seed_genres=genres, **relaxed)
        else:
            recent = sp.current_user_recently_played(limit=5)['items']
            # Seeds must be track URIs, not the raw API response
            seeds = list(dict.fromkeys(x['track']['uri'] for x in recent)) or top_tracks_uri[:5]
            if seeds:
                final_track_list = _recommended_uris(sp, seed_tracks=seeds[:5], **relaxed)

    return create_playlist(sp, img, final_track_list, playlist_name, privacy)


def _clamp_unit(value):
    """Clamp a tunable track attribute to the [0, 1] interval."""
    return min(max(value, 0.0), 1.0)


def _recommended_uris(sp, **kwargs):
    """Return the track URIs of one ``sp.recommendations`` call."""
    return [x['uri'] for x in sp.recommendations(**kwargs)['tracks']]


def _recent_and_saved_track_uris(sp):
    """Collect candidate track URIs from recent plays, saved tracks and recommendations."""
    recent = [x['track']['uri'] for x in sp.current_user_recently_played(limit=50)['items']]
    uris = set(recent)

    first_few = []
    for offset in (0, 50, 100, 150):
        page = [x['track']['uri'] for x in sp.current_user_saved_tracks(limit=50, offset=offset)['items']]
        uris.update(page)
        if offset == 0:
            first_few = page

    # Seed recommendations from saved tracks, or from recent plays when the
    # library is empty; skip the call when there is nothing to seed with.
    seed_pool = first_few or recent
    for seeds in (seed_pool[:5], seed_pool[5:10]):
        if seeds:
            uris.update(_recommended_uris(sp, seed_tracks=seeds, limit=50))
    return list(uris)
132
+
def create_playlist(sp, img, tracks, playlist_name, privacy):
    """Create the playlist on Spotify and return an embeddable player for it.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        img: Path of the image used as the playlist cover.
        tracks: Track URIs to add.
        playlist_name: Name of the new playlist.
        privacy: "Public" makes the playlist public; anything else is private.

    Returns:
        An ``<iframe>`` HTML snippet embedding the playlist, or a plain
        message when ``tracks`` is empty.
    """
    # Check first, so an empty result no longer leaves an empty playlist
    # behind in the user's account.
    if not tracks:
        return """No tracks could be generated from this image"""

    is_public = privacy == "Public"
    user_id = sp.current_user()['id']
    playlist_description = "This playlist was created using the img-to-music application built by the best boyfriend there ever was and ever will be"
    playlist_data = sp.user_playlist_create(user_id, playlist_name, public=is_public,
                                            description=playlist_description)
    playlist_id = playlist_data['id']
    sp.user_playlist_add_tracks(user_id, playlist_id, tracks)

    # The cover is best-effort: a failure here must not lose the playlist.
    try:
        with Image.open(img) as im_file:
            # JPEG cannot store alpha or palette modes; convert() also returns
            # a loaded copy that stays usable after the file is closed.
            cover = im_file.convert("RGB")
        cover.thumbnail((300, 300))
        buffered = BytesIO()
        cover.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        sp.playlist_upload_cover_image(playlist_id, img_str)
    except (spotipy.exceptions.SpotifyException, OSError) as err:
        print(f"Couldn't upload playlist cover image: {err}")

    iframe_embedding = f"""<iframe style="border-radius:12px" src="https://open.spotify.com/embed/playlist/{playlist_id}" width="100%" height="352" frameBorder="0" allowfullscreen="" allow="autoplay; clipboard-write; encrypted-media; fullscreen; picture-in-picture" loading="lazy"></iframe>"""
    return iframe_embedding
156
+
def aggregate_favorite_artists(sp):
    """Collect artist URIs from the user's top, followed and related artists.

    Artists are de-duplicated by name and returned in discovery order: top
    artists for each time range, then followed artists, then artists related
    to those already collected until about 200 are known or no collected
    artist is left to expand.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.

    Returns:
        List of artist URIs; empty when the user has no top or followed artists.
    """
    top_artists_name = set()
    top_artists_uri = []

    def remember(artists):
        # Record every artist whose name has not been seen yet
        for artist_data in artists:
            if artist_data['name'] not in top_artists_name:
                top_artists_name.add(artist_data['name'])
                top_artists_uri.append(artist_data['uri'])

    for time_range in ('short_term', 'medium_term', 'long_term'):
        remember(sp.current_user_top_artists(limit=50, time_range=time_range)['items'])

    remember(sp.current_user_followed_artists(limit=50)['artists']['items'])

    # Attempt to guarantee 200 artists. The index check comes before the
    # lookup so an empty list no longer raises IndexError.
    i = 0
    while len(top_artists_uri) < 200 and i < len(top_artists_uri):
        remember(sp.artist_related_artists(top_artists_uri[i])['artists'])
        i += 1

    return top_artists_uri
193
+
def aggregate_top_tracks(sp, top_artists_uri):
    """Gather the top-track URIs of every given artist.

    A track that appears in the top tracks of several artists is returned
    once, at the position where it was first seen, so the same song cannot be
    picked twice later on.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        top_artists_uri: Iterable of artist URIs.

    Returns:
        List of unique track URIs in first-seen order.
    """
    # A dict keeps insertion order while dropping repeats
    unique_uris = {}
    for artist in top_artists_uri:
        for track_data in sp.artist_top_tracks(artist)['tracks']:
            unique_uris.setdefault(track_data['uri'], None)
    return list(unique_uris)
202
+
def filter_tracks(sp, top_tracks_uri, mood_dict, playlist_length):
    """Pick random candidate tracks whose audio features fit the image mood.

    Candidates are shuffled, their audio features are requested in batches of
    100, and matching tracks are collected until ``playlist_length`` is
    reached. Batches are fetched lazily, so no further requests are made once
    the playlist is full.

    Args:
        sp: Authenticated ``spotipy.Spotify`` client.
        top_tracks_uri: Candidate track URIs. The list is no longer shuffled
            in place; a copy is shuffled instead.
        mood_dict: Mapping of emotion label to probability.
        playlist_length: Number of tracks after which selection stops.

    Returns:
        List of selected track URIs (possibly empty).
    """
    mood = compute_mood(mood_dict)
    selected_tracks_uri = []

    candidates = list(top_tracks_uri)
    if not candidates:
        # Nothing to look up; avoids an audio-features request with no ids
        return selected_tracks_uri
    np.random.shuffle(candidates)

    # Batch network requests
    BATCH_SIZE = 100
    for start in range(0, len(candidates), BATCH_SIZE):
        batch = candidates[start:start + BATCH_SIZE]
        features = sp.audio_features(batch) or []
        for track, track_data in zip(batch, features):
            if track_data is None:
                continue

            if _track_fits_mood(mood, track_data['valence'],
                                track_data['danceability'], track_data['energy']):
                selected_tracks_uri.append(track)

            if len(selected_tracks_uri) >= playlist_length:
                return selected_tracks_uri
    return selected_tracks_uri


def _track_fits_mood(mood, valence, danceability, energy):
    """Return True when a track's audio features fall in the band for ``mood``."""
    if mood < .1:
        return valence <= mood + .15 and danceability <= mood * 8 and energy <= mood * 10
    if mood < .25:
        return (mood - .1) <= valence <= (mood + .1) and danceability <= mood * 4 and energy <= mood * 5
    if mood < .5:
        return mood - .05 <= valence <= mood + .05 and danceability <= mood * 1.75 and energy <= mood * 1.75
    if mood < .75:
        return mood - .1 <= valence <= mood + .1 and danceability >= mood / 2.5 and energy >= mood / 2
    if mood < .9:
        return mood - .1 <= valence <= mood + .1 and danceability >= mood / 2 and energy >= mood / 1.75
    return mood - .15 <= valence <= 1 and danceability >= mood / 1.75 and energy >= mood / 1.5
259
+
# Define login and frontend
PORT_NUMBER = 8080

# SECURITY: the literal fallbacks below were committed to source control and
# must be treated as compromised. Rotate the Spotify client secret and supply
# the real values through the environment (e.g. deployment secrets); the
# fallbacks only keep an unconfigured deployment behaving as before.
SPOTIPY_CLIENT_ID = os.environ.get('SPOTIPY_CLIENT_ID', '2320153024d042c8ba138a108066246c')
SPOTIPY_CLIENT_SECRET = os.environ.get('SPOTIPY_CLIENT_SECRET', 'da2746490f6542a3b0cfcff50893e8e8')
# For local development set SPOTIPY_REDIRECT_URI=http://localhost:7860
SPOTIPY_REDIRECT_URI = os.environ.get('SPOTIPY_REDIRECT_URI',
                                      "https://huggingface.co/spaces/Bokanovskii/Image-to-music")
SCOPE = 'ugc-image-upload playlist-read-private playlist-read-collaborative playlist-modify-private playlist-modify-public user-top-read user-read-recently-played user-library-modify user-library-read user-read-email user-read-private'

sp_oauth = oauth2.SpotifyOAuth(SPOTIPY_CLIENT_ID, SPOTIPY_CLIENT_SECRET, SPOTIPY_REDIRECT_URI, scope=SCOPE)

app = FastAPI()
# The session cookie carries the user's Spotify token, so it must be signed
# with an unguessable key. Without SESSION_SECRET_KEY a random key is created
# per process, which simply invalidates sessions on restart.
app.add_middleware(SessionMiddleware,
                   secret_key=os.environ.get('SESSION_SECRET_KEY') or secrets.token_urlsafe(32))
272
+
@app.get('/', response_class=HTMLResponse)
async def homepage(request: Request):
    """Serve the login page and complete the Spotify OAuth redirect.

    When the request URL carries an authorisation code, it is exchanged for
    an access token, the token is stored in the session under ``'token'`` and
    the browser is redirected to ``/gradio``. Otherwise (or when the exchange
    fails) the login page with the Spotify authorisation link is returned.
    """
    # Function-scope import: stdlib only, keeps this block self-contained
    from html import escape

    url = str(request.url)
    try:
        # parse_response_code hands back the URL itself when no code is present
        code = sp_oauth.parse_response_code(url)
        if code != url:
            # check_cache=False: always exchange this visitor's own code
            # instead of reusing a token cached for a previous visitor.
            request.session['token'] = sp_oauth.get_access_token(
                code, as_dict=False, check_cache=False)
            return RedirectResponse("/gradio")
    except oauth2.SpotifyOauthError as err:
        # e.g. the user denied access; fall through to the login page
        print(f"Spotify authorisation failed: {err}")

    auth_url = escape(sp_oauth.get_authorize_url(), quote=True)
    return f"""<div style="text-align: center; max-width: 1000px; margin: 0 auto;">
    <div
        style="
        align-items: center;
        gap: 0.8rem;
        font-size: 1.75rem;
        "
    >
        <h3 style="font-weight: 900; margin-bottom: 30px; margin-top: 20px;">
        Image to Music Generator
        </h3>
        <a href='{auth_url}'>Login to Spotify</a>
        <p>
        </p>
        <div
        style="
        align-items: center;
        gap: 0.8rem;
        font-size: 1rem;
        "
        >
        <small>
        This applet requires a whitelisted Spotify account (contact Charlie Ward)
        </small>
        </div>
    </div>
</div>"""
306
+
# Gradio front end: image upload, generation options and the two HTML outputs.
# NOTE(review): nesting was reconstructed from a whitespace-stripped source;
# confirm the container structure against the original file.
with gr.Blocks(css="style.css") as demo:
    with gr.Column(elem_id="col-container"):
        # Page header (the two <div> elements are now closed)
        gr.HTML("""<div style="text-align: center; max-width: 700px; margin: 0 auto;">
        <div
        style="
        display: inline-flex;
        align-items: center;
        gap: 0.8rem;
        font-size: 1.75rem;
        "
        >
        <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
        Image to Music Generator
        </h1>
        </div>
        </div>""")

        input_img = gr.Image(type="filepath", elem_id="input-img")
        sraddhas_box = gr.HTML(label="Sraddha's Box", elem_id="sraddhas-box", visible=False)
        playlist_output = gr.HTML(label="Generated Playlist", elem_id="app-output", visible=True)

        with gr.Accordion(label="Playlist Generation Options", open=False):
            playlist_length = gr.Slider(minimum=5, maximum=100, value=30, step=5,
                                        label="Playlist Length", elem_id="playlist-length")
            with gr.Row():
                privacy = gr.Radio(label="Playlist Privacy Level", choices=["Public", "Private"],
                                   value="Private")
                gen_mode = gr.Radio(label="Recommendation Base", choices=["Favorites", "Recently Played", "By a Chosen Genre"], value="By a Chosen Genre")
            with gr.Row(visible=True) as genre_choice_row:
                genre_choice = gr.Dropdown(label='Choose a Genre', choices=['Rock', 'Pop', 'Hip-hop', 'Party', 'Mellow', 'Indian', 'Study', 'Romance'], value='Pop')

        def sraddha_box_hide():
            """Hide the message box as soon as a new generation starts."""
            return {sraddhas_box: gr.update(visible=False)}

        def genre_dropdown_toggle(gen_mode):
            """Show the genre dropdown only in 'By a Chosen Genre' mode."""
            return {genre_choice_row: gr.update(visible=gen_mode == 'By a Chosen Genre')}

        generate = gr.Button("Generate Playlist from Image")

        article = """
        <div class="footer">
        <p>
        Built for Sraddha: playlist generation from image inference
        </p>
        <p>
        Sending Love 🤗
        </p>
        </div>
        """
        gr.HTML(article)

        gen_mode.change(genre_dropdown_toggle, inputs=[gen_mode], outputs=[genre_choice_row])
        generate.click(sraddha_box_hide, outputs=[sraddhas_box])
        generate.click(main, inputs=[input_img, playlist_length, privacy, gen_mode, genre_choice],
                       outputs=[playlist_output, sraddhas_box], api_name="img-to-music")
362
+
# Serve the Gradio UI under /gradio on the FastAPI app that handles login
gradio_app = gr.mount_gradio_app(app, demo, "/gradio")

# Start the server only when executed as a script, so importing this module
# (e.g. `uvicorn app:app`, tests) does not block on a running server.
if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=7860)
style.css ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
/* Centre the main Gradio column and cap its width */
#col-container {
  max-width: 510px;
  margin-left: auto;
  margin-right: auto;
}
/* Links: underlined and semi-bold */
a {
  text-decoration-line: underline;
  font-weight: 600;
}
/* Keep a minimum height on the playlist output area */
div#app-output .h-full {
  min-height: 5rem;
}
/* Footer: a horizontal rule with the caption text sitting on top of it */
.footer {
  margin-bottom: 45px;
  margin-top: 10px;
  text-align: center;
  border-bottom: 1px solid #e5e5e5;
}
/* The caption is shifted down onto the rule; its background masks the line */
.footer > p {
  font-size: 0.8rem;
  display: inline-block;
  padding: 0 10px;
  transform: translateY(10px);
  background: white;
}
/* Dark-theme variants of the footer colours */
.dark .footer {
  border-color: #303030;
}
.dark .footer > p {
  background: #0b0f19;
}
/* Continuous rotation, e.g. for a loading spinner */
.animate-spin {
  animation: spin 1s linear infinite;
}
@keyframes spin {
  from {
    transform: rotate(0deg);
  }
  to {
    transform: rotate(360deg);
  }
}