fffiloni committed on
Commit
1bd1938
•
1 Parent(s): 90431e0

added mode options

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -18,12 +18,12 @@ img_to_text = gr.Blocks.load(name="spaces/fffiloni/CLIP-Interrogator-2")
18
 
19
  from share_btn import community_icon_html, loading_icon_html, share_js
20
 
21
- def get_prompts(uploaded_image, track_duration, gen_intensity):
22
  print("calling clip interrogator")
23
  #prompt = img_to_text(uploaded_image, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
24
  prompt = img_to_text(uploaded_image, 'fast', 4, fn_index=1)[0]
25
  print(prompt)
26
- music_result = generate_track_by_prompt(prompt, track_duration, gen_intensity)
27
  print(music_result)
28
  return music_result[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
29
 
@@ -33,11 +33,8 @@ minilm = SentenceTransformer('all-MiniLM-L6-v2')
33
  mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
34
 
35
 
36
- def get_track_by_tags(tags, pat, duration, gen_intensity, maxit=20, loop=False):
37
- if loop:
38
- mode = "loop"
39
- else:
40
- mode = "track"
41
  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
42
  json={
43
  "method": "RecordTrackTTM",
@@ -47,7 +44,7 @@ def get_track_by_tags(tags, pat, duration, gen_intensity, maxit=20, loop=False):
47
  "format": "wav",
48
  "intensity":gen_intensity,
49
  "tags": tags,
50
- "mode": mode
51
  }
52
  })
53
 
@@ -63,11 +60,11 @@ def get_track_by_tags(tags, pat, duration, gen_intensity, maxit=20, loop=False):
63
  time.sleep(1)
64
 
65
 
66
- def generate_track_by_prompt(prompt, duration, gen_intensity):
67
  try:
68
  pat = get_pat("prodia@prodia.com")
69
  _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
70
- result = get_track_by_tags(tags, pat, int(duration), gen_intensity, loop=False)
71
  print(result)
72
  return result, ",".join(tags), "Success"
73
  except Exception as e:
@@ -179,7 +176,8 @@ with gr.Blocks(css=css) as demo:
179
  input_img = gr.Image(type="filepath", elem_id="input-img")
180
  with gr.Row():
181
  track_duration = gr.Slider(minimum=20, maximum=120, value=30, step=5, label="🎅 Track duration", elem_id="duration-inp")
182
- gen_intensity = gr.Radio(choices=["low", "medium", "high"], value="high", label="Complexity")
 
183
  generate = gr.Button("Generate Music from Image")
184
 
185
  music_output = gr.Audio(label="Result", type="filepath", elem_id="music-output")
@@ -190,7 +188,7 @@ with gr.Blocks(css=css) as demo:
190
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
191
 
192
  gr.HTML(article)
193
- generate.click(get_prompts, inputs=[input_img,track_duration,gen_intensity], outputs=[music_output, share_button, community_icon, loading_icon], api_name="i2m")
194
  share_button.click(None, [], [], _js=share_js)
195
 
196
  demo.queue(max_size=32, concurrency_count=20).launch()
 
18
 
19
  from share_btn import community_icon_html, loading_icon_html, share_js
20
 
21
+ def get_prompts(uploaded_image, track_duration, gen_intensity, gen_mode):
22
  print("calling clip interrogator")
23
  #prompt = img_to_text(uploaded_image, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
24
  prompt = img_to_text(uploaded_image, 'fast', 4, fn_index=1)[0]
25
  print(prompt)
26
+ music_result = generate_track_by_prompt(prompt, track_duration, gen_intensity, gen_mode)
27
  print(music_result)
28
  return music_result[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
29
 
 
33
  mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
34
 
35
 
36
+ def get_track_by_tags(tags, pat, duration, gen_intensity, gen_mode, maxit=20):
37
+
 
 
 
38
  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
39
  json={
40
  "method": "RecordTrackTTM",
 
44
  "format": "wav",
45
  "intensity":gen_intensity,
46
  "tags": tags,
47
+ "mode": gen_mode
48
  }
49
  })
50
 
 
60
  time.sleep(1)
61
 
62
 
63
+ def generate_track_by_prompt(prompt, duration, gen_intensity, gen_mode):
64
  try:
65
  pat = get_pat("prodia@prodia.com")
66
  _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
67
+ result = get_track_by_tags(tags, pat, int(duration), gen_intensity, gen_mode)
68
  print(result)
69
  return result, ",".join(tags), "Success"
70
  except Exception as e:
 
176
  input_img = gr.Image(type="filepath", elem_id="input-img")
177
  with gr.Row():
178
  track_duration = gr.Slider(minimum=20, maximum=120, value=30, step=5, label="🎅 Track duration", elem_id="duration-inp")
179
+ gen_intensity = gr.Dropdown(choices=["low", "medium", "high"], value="high", label="Complexity")
180
+ gen_mode = gr.Radio(label="mode", choices=["track", "loop"], value="track")
181
  generate = gr.Button("Generate Music from Image")
182
 
183
  music_output = gr.Audio(label="Result", type="filepath", elem_id="music-output")
 
188
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
189
 
190
  gr.HTML(article)
191
+ generate.click(get_prompts, inputs=[input_img,track_duration,gen_intensity,gen_mode], outputs=[music_output, share_button, community_icon, loading_icon], api_name="i2m")
192
  share_button.click(None, [], [], _js=share_js)
193
 
194
  demo.queue(max_size=32, concurrency_count=20).launch()