fffiloni commited on
Commit
42ca3ea
1 Parent(s): e9375f2

add Tango 2 model

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -136,6 +136,21 @@ def get_tango(prompt):
136
  print(result)
137
  return result
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def infer(image_in, chosen_model):
140
  caption = get_caption(image_in)
141
  if chosen_model == "MAGNet" :
@@ -150,6 +165,9 @@ def infer(image_in, chosen_model):
150
  elif chosen_model == "Tango" :
151
  tango_result = get_tango(caption)
152
  return tango_result
 
 
 
153
 
154
  css="""
155
  #col-container{
@@ -172,7 +190,7 @@ with gr.Blocks(css=css) as demo:
172
  with gr.Column():
173
  image_in = gr.Image(sources=["upload"], type="filepath", label="Image input", value="oiseau.png")
174
  with gr.Row():
175
- chosen_model = gr.Dropdown(label="Choose a model", choices=["MAGNet", "AudioLDM-2", "AudioGen", "Tango"], value="AudioLDM-2")
176
  submit_btn = gr.Button("Submit")
177
  with gr.Column():
178
  audio_o = gr.Audio(label="Audio output")
 
136
  print(result)
137
  return result
138
 
139
+ def get_tango2(prompt):
140
+ try:
141
+ client = Client("declare-lab/tango2")
142
+ except:
143
+ raise gr.Error("Tango2 space API is not ready, please try again in few minutes ")
144
+
145
+ result = client.predict(
146
+ prompt=prompt,
147
+ steps=100,
148
+ guidance=4,
149
+ api_name="/predict"
150
+ )
151
+ print(result)
152
+ return result
153
+
154
  def infer(image_in, chosen_model):
155
  caption = get_caption(image_in)
156
  if chosen_model == "MAGNet" :
 
165
  elif chosen_model == "Tango" :
166
  tango_result = get_tango(caption)
167
  return tango_result
168
+ elif chosen_model == "Tango 2" :
169
+ tango2_result = get_tango2(caption)
170
+ return tango2_result
171
 
172
  css="""
173
  #col-container{
 
190
  with gr.Column():
191
  image_in = gr.Image(sources=["upload"], type="filepath", label="Image input", value="oiseau.png")
192
  with gr.Row():
193
+ chosen_model = gr.Dropdown(label="Choose a model", choices=["MAGNet", "AudioLDM-2", "AudioGen", "Tango", "Tango 2"], value="AudioLDM-2")
194
  submit_btn = gr.Button("Submit")
195
  with gr.Column():
196
  audio_o = gr.Audio(label="Audio output")