fffiloni commited on
Commit
c90f947
Ā·
1 Parent(s): 9da3dd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -7
app.py CHANGED
@@ -52,6 +52,9 @@ def get_pat_token():
52
  return pat
53
 
54
  def get_music(pat, prompt, track_duration, gen_intensity, gen_mode):
 
 
 
55
 
56
  r = httpx.post('https://api-b2b.mubert.com/v2/TTMRecordTrack',
57
  json={
@@ -100,19 +103,81 @@ def get_results(text_prompt,track_duration,gen_intensity,gen_mode):
100
  music = get_music(pat_token, text_prompt, track_duration, gen_intensity, gen_mode)
101
  return pat_token, music
102
 
103
- def get_prompts(uploaded_image, track_duration, gen_intensity, gen_mode):
104
  print("calling clip interrogator")
105
  #prompt = img_to_text(uploaded_image, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
 
106
  prompt = img_to_text(uploaded_image, 'best', 4, fn_index=1)[0]
107
  print(prompt)
108
- music_result = get_results(prompt, track_duration, gen_intensity, gen_mode)
 
 
 
 
 
 
 
109
  wave_file = convert_mp3_to_wav(music_result[1])
110
- #music_result = generate_track_by_prompt(pat, prompt, track_duration, gen_intensity, gen_mode)
111
- #print(pat
112
  time.sleep(1)
113
  return wave_file, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
114
 
115
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def get_track_by_tags(tags, pat, duration, gen_intensity, gen_mode, maxit=20):
117
 
118
  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
@@ -223,16 +288,17 @@ with gr.Blocks(css="style.css") as demo:
223
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
224
 
225
  with gr.Accordion(label="Music Generation Options", open=False):
 
226
  track_duration = gr.Slider(minimum=20, maximum=120, value=30, step=5, label="Track duration", elem_id="duration-inp")
227
  with gr.Row():
228
  gen_intensity = gr.Dropdown(choices=["low", "medium", "high"], value="medium", label="Intensity")
229
- gen_mode = gr.Radio(label="mode", choices=["track", "loop"], value="track")
230
 
231
  generate = gr.Button("Generate Music from Image")
232
 
233
  gr.HTML(article)
234
 
235
- generate.click(get_prompts, inputs=[input_img,track_duration,gen_intensity,gen_mode], outputs=[music_output, share_button, community_icon, loading_icon], api_name="i2m")
236
  share_button.click(None, [], [], _js=share_js)
237
 
238
  demo.queue(max_size=32, concurrency_count=20).launch()
 
52
  return pat
53
 
54
  def get_music(pat, prompt, track_duration, gen_intensity, gen_mode):
55
+
56
+ if len(prompt) > 200:
57
+ prompt = prompt[:200]
58
 
59
  r = httpx.post('https://api-b2b.mubert.com/v2/TTMRecordTrack',
60
  json={
 
103
  music = get_music(pat_token, text_prompt, track_duration, gen_intensity, gen_mode)
104
  return pat_token, music
105
 
106
+ def get_prompts(uploaded_image, track_duration, gen_intensity, gen_mode, openai_api_key):
107
  print("calling clip interrogator")
108
  #prompt = img_to_text(uploaded_image, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
109
+
110
  prompt = img_to_text(uploaded_image, 'best', 4, fn_index=1)[0]
111
  print(prompt)
112
+ if openai_api_key != None:
113
+ gpt_adaptation = try_api(prompt, openai_api_key)
114
+ if gpt_adaptation[0] != "oups":
115
+ musical_prompt = gpt_adaptation[0]
116
+ else:
117
+ musical_prompt = prompt
118
+ music_result = get_results(musical_prompt, track_duration, gen_intensity, gen_mode)
119
+
120
  wave_file = convert_mp3_to_wav(music_result[1])
121
+
 
122
  time.sleep(1)
123
  return wave_file, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
124
 
125
+ def try_api(message, openai_api_key):
126
+
127
+ try:
128
+ response = call_api(message, openai_api_key)
129
+ return response, "<span class='openai_clear'>no error</span>"
130
+ except openai.error.Timeout as e:
131
+ #Handle timeout error, e.g. retry or log
132
+ print(f"OpenAI API request timed out: {e}")
133
+ return "oups", f"<span class='openai_error'>OpenAI API request timed out: <br />{e}</span>"
134
+ except openai.error.APIError as e:
135
+ #Handle API error, e.g. retry or log
136
+ print(f"OpenAI API returned an API Error: {e}")
137
+ return "oups", f"<span class='openai_error'>OpenAI API returned an API Error: <br />{e}</span>"
138
+ except openai.error.APIConnectionError as e:
139
+ #Handle connection error, e.g. check network or log
140
+ print(f"OpenAI API request failed to connect: {e}")
141
+ return "oups", f"<span class='openai_error'>OpenAI API request failed to connect: <br />{e}</span>"
142
+ except openai.error.InvalidRequestError as e:
143
+ #Handle invalid request error, e.g. validate parameters or log
144
+ print(f"OpenAI API request was invalid: {e}")
145
+ return "oups", f"<span class='openai_error'>OpenAI API request was invalid: <br />{e}</span>"
146
+ except openai.error.AuthenticationError as e:
147
+ #Handle authentication error, e.g. check credentials or log
148
+ print(f"OpenAI API request was not authorized: {e}")
149
+ return "oups", f"<span class='openai_error'>OpenAI API request was not authorized: <br />{e}</span>"
150
+ except openai.error.PermissionError as e:
151
+ #Handle permission error, e.g. check scope or log
152
+ print(f"OpenAI API request was not permitted: {e}")
153
+ return "oups", f"<span class='openai_error'>OpenAI API request was not permitted: <br />{e}</span>"
154
+ except openai.error.RateLimitError as e:
155
+ #Handle rate limit error, e.g. wait or log
156
+ print(f"OpenAI API request exceeded rate limit: {e}")
157
+ return "oups", f"<span class='openai_error'>OpenAI API request exceeded rate limit: <br />{e}</span>"
158
+
159
+ def call_api(message, openai_api_key):
160
+
161
+ print("starting open ai")
162
+ augmented_prompt = message + prevent_code_gen
163
+ openai.api_key = openai_api_key
164
+
165
+ response = openai.Completion.create(
166
+ model="text-davinci-003",
167
+ prompt=augmented_prompt,
168
+ temperature=0.5,
169
+ max_tokens=2048,
170
+ top_p=1,
171
+ frequency_penalty=0,
172
+ presence_penalty=0.6
173
+ )
174
+
175
+ print(response)
176
+
177
+ #return str(response.choices[0].text).split("\n",2)[2]
178
+ return str(response.choices[0].text)
179
+
180
+
181
  def get_track_by_tags(tags, pat, duration, gen_intensity, gen_mode, maxit=20):
182
 
183
  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
 
288
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
289
 
290
  with gr.Accordion(label="Music Generation Options", open=False):
291
+ openai_api_key = gr.Textbox(label="OpenAI key", info="You can use you OpenAI key to adapt CLIP Interrogator caption to a musical translation.")
292
  track_duration = gr.Slider(minimum=20, maximum=120, value=30, step=5, label="Track duration", elem_id="duration-inp")
293
  with gr.Row():
294
  gen_intensity = gr.Dropdown(choices=["low", "medium", "high"], value="medium", label="Intensity")
295
+ gen_mode = gr.Radio(label="mode", choices=["track", "loop"], value="loop")
296
 
297
  generate = gr.Button("Generate Music from Image")
298
 
299
  gr.HTML(article)
300
 
301
+ generate.click(get_prompts, inputs=[input_img,track_duration,gen_intensity,gen_mode, openai_api_key], outputs=[music_output, share_button, community_icon, loading_icon], api_name="i2m")
302
  share_button.click(None, [], [], _js=share_js)
303
 
304
  demo.queue(max_size=32, concurrency_count=20).launch()