craftgamesnetwork commited on
Commit
c9b0479
1 Parent(s): ab6aec7

Create utils/gradio_helpers.py

Browse files
Files changed (1) hide show
  1. utils/gradio_helpers.py +469 -0
utils/gradio_helpers.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from urllib.parse import urlparse
3
+ import requests
4
+ import time
5
+ from PIL import Image
6
+ import base64
7
+ import io
8
+ import uuid
9
+ import os
10
+
11
+
12
+ def extract_property_info(prop):
13
+ combined_prop = {}
14
+ merge_keywords = ["allOf", "anyOf", "oneOf"]
15
+
16
+ for keyword in merge_keywords:
17
+ if keyword in prop:
18
+ for subprop in prop[keyword]:
19
+ combined_prop.update(subprop)
20
+ del prop[keyword]
21
+
22
+ if not combined_prop:
23
+ combined_prop = prop.copy()
24
+
25
+ for key in ["description", "default"]:
26
+ if key in prop:
27
+ combined_prop[key] = prop[key]
28
+
29
+ return combined_prop
30
+
31
+
32
+ def detect_file_type(filename):
33
+ audio_extensions = [".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"]
34
+ image_extensions = [
35
+ ".jpg",
36
+ ".jpeg",
37
+ ".png",
38
+ ".gif",
39
+ ".bmp",
40
+ ".tiff",
41
+ ".svg",
42
+ ".webp",
43
+ ]
44
+ video_extensions = [
45
+ ".mp4",
46
+ ".mov",
47
+ ".wmv",
48
+ ".flv",
49
+ ".avi",
50
+ ".avchd",
51
+ ".mkv",
52
+ ".webm",
53
+ ]
54
+
55
+ # Extract the file extension
56
+ if isinstance(filename, str):
57
+ extension = filename[filename.rfind(".") :].lower()
58
+
59
+ # Check the extension against each list
60
+ if extension in audio_extensions:
61
+ return "audio"
62
+ elif extension in image_extensions:
63
+ return "image"
64
+ elif extension in video_extensions:
65
+ return "video"
66
+ else:
67
+ return "string"
68
+ elif isinstance(filename, list):
69
+ return "list"
70
+
71
+
72
+ def build_gradio_inputs(ordered_input_schema, example_inputs=None):
73
+ inputs = []
74
+ input_field_strings = """inputs = []\n"""
75
+ names = []
76
+ for index, (name, prop) in enumerate(ordered_input_schema):
77
+ names.append(name)
78
+ prop = extract_property_info(prop)
79
+ if "enum" in prop:
80
+ input_field = gr.Dropdown(
81
+ choices=prop["enum"],
82
+ label=prop.get("title"),
83
+ info=prop.get("description"),
84
+ value=prop.get("default"),
85
+ )
86
+ input_field_string = f"""inputs.append(gr.Dropdown(
87
+ choices={prop["enum"]}, label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value="{prop.get("default")}"
88
+ ))\n"""
89
+ elif prop["type"] == "integer":
90
+ if prop.get("minimum") and prop.get("maximum"):
91
+ input_field = gr.Slider(
92
+ label=prop.get("title"),
93
+ info=prop.get("description"),
94
+ value=prop.get("default"),
95
+ minimum=prop.get("minimum"),
96
+ maximum=prop.get("maximum"),
97
+ step=1,
98
+ )
99
+ input_field_string = f"""inputs.append(gr.Slider(
100
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")},
101
+ minimum={prop.get("minimum")}, maximum={prop.get("maximum")}, step=1,
102
+ ))\n"""
103
+ else:
104
+ input_field = gr.Number(
105
+ label=prop.get("title"),
106
+ info=prop.get("description"),
107
+ value=prop.get("default"),
108
+ )
109
+ input_field_string = f"""inputs.append(gr.Number(
110
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
111
+ ))\n"""
112
+ elif prop["type"] == "number":
113
+ if prop.get("minimum") and prop.get("maximum"):
114
+ input_field = gr.Slider(
115
+ label=prop.get("title"),
116
+ info=prop.get("description"),
117
+ value=prop.get("default"),
118
+ minimum=prop.get("minimum"),
119
+ maximum=prop.get("maximum"),
120
+ )
121
+ input_field_string = f"""inputs.append(gr.Slider(
122
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")},
123
+ minimum={prop.get("minimum")}, maximum={prop.get("maximum")}
124
+ ))\n"""
125
+ else:
126
+ input_field = gr.Number(
127
+ label=prop.get("title"),
128
+ info=prop.get("description"),
129
+ value=prop.get("default"),
130
+ )
131
+ input_field_string = f"""inputs.append(gr.Number(
132
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
133
+ ))\n"""
134
+ elif prop["type"] == "boolean":
135
+ input_field = gr.Checkbox(
136
+ label=prop.get("title"),
137
+ info=prop.get("description"),
138
+ value=prop.get("default"),
139
+ )
140
+ input_field_string = f"""inputs.append(gr.Checkbox(
141
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
142
+ ))\n"""
143
+ elif (
144
+ prop["type"] == "string" and prop.get("format") == "uri" and example_inputs
145
+ ):
146
+ input_type_example = example_inputs.get(name, None)
147
+ if input_type_example:
148
+ input_type = detect_file_type(input_type_example)
149
+ else:
150
+ input_type = None
151
+ if input_type == "image":
152
+ input_field = gr.Image(label=prop.get("title"), type="filepath")
153
+ input_field_string = f"""inputs.append(gr.Image(
154
+ label="{prop.get("title")}", type="filepath"
155
+ ))\n"""
156
+ elif input_type == "audio":
157
+ input_field = gr.Audio(label=prop.get("title"), type="filepath")
158
+ input_field_string = f"""inputs.append(gr.Audio(
159
+ label="{prop.get("title")}", type="filepath"
160
+ ))\n"""
161
+ elif input_type == "video":
162
+ input_field = gr.Video(label=prop.get("title"))
163
+ input_field_string = f"""inputs.append(gr.Video(
164
+ label="{prop.get("title")}"
165
+ ))\n"""
166
+ else:
167
+ input_field = gr.File(label=prop.get("title"))
168
+ input_field_string = f"""inputs.append(gr.File(
169
+ label="{prop.get("title")}"
170
+ ))\n"""
171
+ else:
172
+ input_field = gr.Textbox(
173
+ label=prop.get("title"),
174
+ info=prop.get("description"),
175
+ )
176
+ input_field_string = f"""inputs.append(gr.Textbox(
177
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}
178
+ ))\n"""
179
+ inputs.append(input_field)
180
+ input_field_strings += f"{input_field_string}\n"
181
+
182
+ input_field_strings += f"names = {names}\n"
183
+
184
+ return inputs, input_field_strings, names
185
+
186
+
187
+ def build_gradio_outputs_replicate(output_types):
188
+ outputs = []
189
+ output_field_strings = """outputs = []\n"""
190
+ if output_types:
191
+ for output in output_types:
192
+ if output == "image":
193
+ output_field = gr.Image()
194
+ output_field_string = "outputs.append(gr.Image())"
195
+ elif output == "audio":
196
+ output_field = gr.Audio(type="filepath")
197
+ output_field_string = "outputs.append(gr.Audio(type='filepath'))"
198
+ elif output == "video":
199
+ output_field = gr.Video()
200
+ output_field_string = "outputs.append(gr.Video())"
201
+ elif output == "string":
202
+ output_field = gr.Textbox()
203
+ output_field_string = "outputs.append(gr.Textbox())"
204
+ elif output == "json":
205
+ output_field = gr.JSON()
206
+ output_field_string = "outputs.append(gr.JSON())"
207
+ elif output == "list":
208
+ output_field = gr.JSON()
209
+ output_field_string = "outputs.append(gr.JSON())"
210
+ outputs.append(output_field)
211
+ output_field_strings += f"{output_field_string}\n"
212
+ else:
213
+ output_field = gr.JSON()
214
+ output_field_string = "outputs.append(gr.JSON())"
215
+ outputs.append(output_field)
216
+
217
+ return outputs, output_field_strings
218
+
219
+
220
+ def build_gradio_outputs_cog():
221
+ pass
222
+
223
+
224
+ def process_outputs(outputs):
225
+ output_values = []
226
+ for output in outputs:
227
+ if not output:
228
+ continue
229
+ if isinstance(output, str):
230
+ if output.startswith("data:image"):
231
+ base64_data = output.split(",", 1)[1]
232
+ image_data = base64.b64decode(base64_data)
233
+ image_stream = io.BytesIO(image_data)
234
+ image = Image.open(image_stream)
235
+ output_values.append(image)
236
+ elif output.startswith("data:audio"):
237
+ base64_data = output.split(",", 1)[1]
238
+ audio_data = base64.b64decode(base64_data)
239
+ audio_stream = io.BytesIO(audio_data)
240
+ filename = f"{uuid.uuid4()}.wav" # Change format as needed
241
+ with open(filename, "wb") as audio_file:
242
+ audio_file.write(audio_stream.getbuffer())
243
+ output_values.append(filename)
244
+ elif output.startswith("data:video"):
245
+ base64_data = output.split(",", 1)[1]
246
+ video_data = base64.b64decode(base64_data)
247
+ video_stream = io.BytesIO(video_data)
248
+ # Here you can save the audio or return the stream for further processing
249
+ filename = f"{uuid.uuid4()}.mp4" # Change format as needed
250
+ with open(filename, "wb") as video_file:
251
+ video_file.write(video_stream.getbuffer())
252
+ output_values.append(filename)
253
+ else:
254
+ output_values.append(output)
255
+ else:
256
+ output_values.append(output)
257
+ return output_values
258
+
259
+
260
+ def parse_outputs(data):
261
+ if isinstance(data, dict):
262
+ # Handle case where data is an object
263
+ dict_values = []
264
+ for value in data.values():
265
+ extracted_values = parse_outputs(value)
266
+ # For dict, we append instead of extend to maintain list structure within objects
267
+ if isinstance(value, list):
268
+ dict_values += [extracted_values]
269
+ else:
270
+ dict_values += extracted_values
271
+ return dict_values
272
+ elif isinstance(data, list):
273
+ # Handle case where data is an array
274
+ list_values = []
275
+ for item in data:
276
+ # Here we extend to flatten the list since we're already in an array context
277
+ list_values += parse_outputs(item)
278
+ return list_values
279
+ else:
280
+ # Handle primitive data types directly
281
+ return [data]
282
+
283
+
284
+ def create_dynamic_gradio_app(
285
+ inputs,
286
+ outputs,
287
+ api_url,
288
+ api_id=None,
289
+ replicate_token=None,
290
+ title="",
291
+ model_description="",
292
+ names=[],
293
+ local_base=False,
294
+ hostname="0.0.0.0",
295
+ ):
296
+ expected_outputs = len(outputs)
297
+
298
+ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
299
+ payload = {"input": {}}
300
+ if api_id:
301
+ payload["version"] = api_id
302
+ parsed_url = urlparse(str(request.url))
303
+ if local_base:
304
+ base_url = f"http://{hostname}:7860"
305
+ else:
306
+ base_url = parsed_url.scheme + "://" + parsed_url.netloc
307
+ for i, key in enumerate(names):
308
+ value = args[i]
309
+ if value and (os.path.exists(str(value))):
310
+ value = f"{base_url}/file=" + value
311
+ if value is not None and value != "":
312
+ payload["input"][key] = value
313
+ print(payload)
314
+ headers = {"Content-Type": "application/json"}
315
+ if replicate_token:
316
+ headers["Authorization"] = f"Token {replicate_token}"
317
+ print(headers)
318
+ response = requests.post(api_url, headers=headers, json=payload)
319
+ if response.status_code == 201:
320
+ follow_up_url = response.json()["urls"]["get"]
321
+ response = requests.get(follow_up_url, headers=headers)
322
+ while response.json()["status"] != "succeeded":
323
+ if response.json()["status"] == "failed":
324
+ raise gr.Error("The submission failed!")
325
+ response = requests.get(follow_up_url, headers=headers)
326
+ time.sleep(1)
327
+ # TODO: Add a failing mechanism if the API gets stuck
328
+ if response.status_code == 200:
329
+ json_response = response.json()
330
+ # If the output component is JSON return the entire output response
331
+ if outputs[0].get_config()["name"] == "json":
332
+ return json_response["output"]
333
+ predict_outputs = parse_outputs(json_response["output"])
334
+ processed_outputs = process_outputs(predict_outputs)
335
+ difference_outputs = expected_outputs - len(processed_outputs)
336
+ # If less outputs than expected, hide the extra ones
337
+ if difference_outputs > 0:
338
+ extra_outputs = [gr.update(visible=False)] * difference_outputs
339
+ processed_outputs.extend(extra_outputs)
340
+ # If more outputs than expected, cap the outputs to the expected number if
341
+ elif difference_outputs < 0:
342
+ processed_outputs = processed_outputs[:difference_outputs]
343
+
344
+ return (
345
+ tuple(processed_outputs)
346
+ if len(processed_outputs) > 1
347
+ else processed_outputs[0]
348
+ )
349
+
350
+ else:
351
+ if response.status_code == 409:
352
+ raise gr.Error(
353
+ f"Sorry, the Cog image is still processing. Try again in a bit."
354
+ )
355
+ raise gr.Error(f"The submission failed! Error: {response.status_code}")
356
+
357
+ app = gr.Interface(
358
+ fn=predict,
359
+ inputs=inputs,
360
+ outputs=outputs,
361
+ title=title,
362
+ description=model_description,
363
+ allow_flagging="never",
364
+ )
365
+ return app
366
+
367
+
368
+ def create_gradio_app_script(
369
+ inputs_string,
370
+ outputs_string,
371
+ api_url,
372
+ api_id=None,
373
+ replicate_token=None,
374
+ title="",
375
+ model_description="",
376
+ local_base=False,
377
+ hostname="0.0.0.0"
378
+ ):
379
+ headers = {"Content-Type": "application/json"}
380
+ if replicate_token:
381
+ headers["Authorization"] = f"Token {replicate_token}"
382
+
383
+ if local_base:
384
+ base_url = f'base_url = "http://{hostname}:7860"'
385
+ else:
386
+ base_url = """parsed_url = urlparse(str(request.url))
387
+ base_url = parsed_url.scheme + "://" + parsed_url.netloc"""
388
+ headers_string = f"""headers = {headers}\n"""
389
+ api_id_value = f'payload["version"] = "{api_id}"' if api_id is not None else ""
390
+ definition_string = """expected_outputs = len(outputs)
391
+ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):"""
392
+ payload_string = f"""payload = {{"input": {{}}}}
393
+ {api_id_value}
394
+
395
+ {base_url}
396
+ for i, key in enumerate(names):
397
+ value = args[i]
398
+ if value and (os.path.exists(str(value))):
399
+ value = f"{{base_url}}/file=" + value
400
+ if value is not None and value != "":
401
+ payload["input"][key] = value\n"""
402
+
403
+ request_string = (
404
+ f"""response = requests.post("{api_url}", headers=headers, json=payload)\n"""
405
+ )
406
+
407
+ result_string = f"""
408
+ if response.status_code == 201:
409
+ follow_up_url = response.json()["urls"]["get"]
410
+ response = requests.get(follow_up_url, headers=headers)
411
+ while response.json()["status"] != "succeeded":
412
+ if response.json()["status"] == "failed":
413
+ raise gr.Error("The submission failed!")
414
+ response = requests.get(follow_up_url, headers=headers)
415
+ time.sleep(1)
416
+ if response.status_code == 200:
417
+ json_response = response.json()
418
+ #If the output component is JSON return the entire output response
419
+ if(outputs[0].get_config()["name"] == "json"):
420
+ return json_response["output"]
421
+ predict_outputs = parse_outputs(json_response["output"])
422
+ processed_outputs = process_outputs(predict_outputs)
423
+ difference_outputs = expected_outputs - len(processed_outputs)
424
+ # If less outputs than expected, hide the extra ones
425
+ if difference_outputs > 0:
426
+ extra_outputs = [gr.update(visible=False)] * difference_outputs
427
+ processed_outputs.extend(extra_outputs)
428
+ # If more outputs than expected, cap the outputs to the expected number
429
+ elif difference_outputs < 0:
430
+ processed_outputs = processed_outputs[:difference_outputs]
431
+
432
+ return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
433
+ else:
434
+ if(response.status_code == 409):
435
+ raise gr.Error(f"Sorry, the Cog image is still processing. Try again in a bit.")
436
+ raise gr.Error(f"The submission failed! Error: {{response.status_code}}")\n"""
437
+
438
+ interface_string = f"""title = "{title}"
439
+ model_description = "{model_description}"
440
+
441
+ app = gr.Interface(
442
+ fn=predict,
443
+ inputs=inputs,
444
+ outputs=outputs,
445
+ title=title,
446
+ description=model_description,
447
+ allow_flagging="never",
448
+ )
449
+ app.launch(share=True)
450
+ """
451
+
452
+ app_string = f"""import gradio as gr
453
+ from urllib.parse import urlparse
454
+ import requests
455
+ import time
456
+ import os
457
+
458
+ from utils.gradio_helpers import parse_outputs, process_outputs
459
+
460
+ {inputs_string}
461
+ {outputs_string}
462
+ {definition_string}
463
+ {headers_string}
464
+ {payload_string}
465
+ {request_string}
466
+ {result_string}
467
+ {interface_string}
468
+ """
469
+ return app_string