Spaces:
Running
on
A10G
Running
on
A10G
import gradio as gr | |
from urllib.parse import urlparse | |
import requests | |
import time | |
from PIL import Image | |
import base64 | |
import io | |
import uuid | |
import os | |
def extract_property_info(prop):
    """Flatten a JSON-schema property that uses ``allOf``/``anyOf``/``oneOf``.

    Every sub-schema listed under those keywords is merged (later entries
    win) into a single dict. If nothing was merged, the result is a shallow
    copy of ``prop`` without the combinator keywords. A top-level
    ``description`` or ``default`` on ``prop`` always overrides whatever the
    merged sub-schemas provided.

    Args:
        prop: The property schema as a dict. It is NOT modified.

    Returns:
        A new dict describing the flattened property.
    """
    merge_keywords = ("allOf", "anyOf", "oneOf")

    combined_prop = {}
    for keyword in merge_keywords:
        for subprop in prop.get(keyword, ()):
            combined_prop.update(subprop)

    if not combined_prop:
        # Nothing to merge: keep the property as-is, minus any (empty)
        # combinator keywords, without touching the caller's dict.
        combined_prop = {
            key: value for key, value in prop.items() if key not in merge_keywords
        }

    # Top-level annotations take precedence over those from sub-schemas.
    for key in ("description", "default"):
        if key in prop:
            combined_prop[key] = prop[key]
    return combined_prop
def detect_file_type(filename):
    """Classify a value by its file extension.

    Returns "audio", "image" or "video" when ``filename`` is a string whose
    text from the last "." onward (case-insensitive) is a known extension,
    "string" for any other string, "list" for a list, and ``None`` for
    every other type.
    """
    if isinstance(filename, list):
        return "list"
    if not isinstance(filename, str):
        return None

    extensions_by_kind = {
        "audio": (".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"),
        "image": (
            ".jpg",
            ".jpeg",
            ".png",
            ".gif",
            ".bmp",
            ".tiff",
            ".svg",
            ".webp",
        ),
        "video": (
            ".mp4",
            ".mov",
            ".wmv",
            ".flv",
            ".avi",
            ".avchd",
            ".mkv",
            ".webm",
        ),
    }

    # Everything from the last dot to the end, lower-cased.
    suffix = filename[filename.rfind(".") :].lower()
    for kind, known_suffixes in extensions_by_kind.items():
        if suffix in known_suffixes:
            return kind
    return "string"
def _render_input(component, **kwargs):
    """Create ``gr.<component>(**kwargs)`` plus the source line that rebuilds it.

    Arguments are rendered with ``repr()`` so that quotes, newlines, ``None``
    and non-string values survive the round trip into generated code.
    """
    field = getattr(gr, component)(**kwargs)
    arguments = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
    return field, f"inputs.append(gr.{component}({arguments}))\n"


def build_gradio_inputs(ordered_input_schema, example_inputs=None):
    """Build Gradio input components from an ordered input schema.

    Component selection per property (after ``extract_property_info``):
      * has ``enum``                      -> Dropdown
      * ``integer`` / ``number``          -> Slider when both ``minimum`` and
        ``maximum`` are present (integers use ``step=1``), otherwise Number
      * ``boolean``                       -> Checkbox
      * ``string`` with ``format == "uri"`` (only when ``example_inputs`` is
        given) -> Image / Audio / Video according to the example value's
        file extension, otherwise File
      * anything else, including properties without a ``type`` -> Textbox

    Args:
        ordered_input_schema: Iterable of ``(name, property_schema)`` pairs.
        example_inputs: Optional mapping of input name to an example value,
            used to guess the media type of URI inputs.

    Returns:
        A tuple ``(inputs, input_field_strings, names)``: the live
        components, Python source that recreates them (and ``names``), and
        the list of input names in schema order.
    """
    # Media type reported by detect_file_type -> (component, extra kwargs).
    media_components = {
        "image": ("Image", {"type": "filepath"}),
        "audio": ("Audio", {"type": "filepath"}),
        "video": ("Video", {}),
    }

    inputs = []
    input_field_strings = "inputs = []\n"
    names = []

    for name, prop in ordered_input_schema:
        names.append(name)
        prop = extract_property_info(prop)

        prop_type = prop.get("type")
        label = prop.get("title")
        info = prop.get("description")
        default = prop.get("default")
        # Compare against None: a bound of 0 is a perfectly valid range limit.
        has_range = (
            prop.get("minimum") is not None and prop.get("maximum") is not None
        )

        if "enum" in prop:
            rendered = _render_input(
                "Dropdown",
                choices=prop["enum"],
                label=label,
                info=info,
                value=default,
            )
        elif prop_type in ("integer", "number"):
            if has_range:
                slider_kwargs = {
                    "label": label,
                    "info": info,
                    "value": default,
                    "minimum": prop.get("minimum"),
                    "maximum": prop.get("maximum"),
                }
                if prop_type == "integer":
                    slider_kwargs["step"] = 1
                rendered = _render_input("Slider", **slider_kwargs)
            else:
                rendered = _render_input(
                    "Number", label=label, info=info, value=default
                )
        elif prop_type == "boolean":
            rendered = _render_input(
                "Checkbox", label=label, info=info, value=default
            )
        elif prop_type == "string" and prop.get("format") == "uri" and example_inputs:
            example = example_inputs.get(name)
            input_type = detect_file_type(example) if example else None
            component, extra = media_components.get(input_type, ("File", {}))
            rendered = _render_input(component, label=label, **extra)
        else:
            rendered = _render_input("Textbox", label=label, info=info)

        input_field, input_field_string = rendered
        inputs.append(input_field)
        input_field_strings += f"{input_field_string}\n"

    input_field_strings += f"names = {names}\n"
    return inputs, input_field_strings, names
def build_gradio_outputs_replicate(output_types):
    """Build Gradio output components for a list of detected output types.

    Recognised types are "image", "audio", "video", "string", "json" and
    "list"; any other entry falls back to a JSON component. When
    ``output_types`` is empty or ``None`` a single JSON component is used.

    Args:
        output_types: Iterable of output type names, or a falsy value.

    Returns:
        A tuple ``(outputs, output_field_strings)``: the live components and
        Python source that recreates the same list, one statement per line.
    """
    # Output type -> (component class name, constructor kwargs).
    known_components = {
        "image": ("Image", {}),
        "audio": ("Audio", {"type": "filepath"}),
        "video": ("Video", {}),
        "string": ("Textbox", {}),
        "json": ("JSON", {}),
        "list": ("JSON", {}),
    }
    fallback = ("JSON", {})

    outputs = []
    output_field_strings = "outputs = []\n"

    # No declared outputs: show the raw response in a single JSON component.
    requested = output_types if output_types else [None]
    for output in requested:
        if isinstance(output, str):
            component, kwargs = known_components.get(output, fallback)
        else:
            component, kwargs = fallback
        outputs.append(getattr(gr, component)(**kwargs))
        arguments = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
        # Keep the generated source in lockstep with the live components.
        output_field_strings += f"outputs.append(gr.{component}({arguments}))\n"

    return outputs, output_field_strings
def build_gradio_outputs_cog():
    """Unimplemented placeholder: takes no arguments, does nothing, returns None.

    NOTE(review): the name suggests a Cog-specific counterpart to
    ``build_gradio_outputs_replicate`` -- confirm intent before relying on it.
    """
    pass
def process_outputs(outputs):
    """Convert raw prediction outputs into values Gradio components accept.

    Falsy entries are dropped. String entries that are base64 data URIs are
    decoded: ``data:image`` becomes a PIL image, while ``data:audio`` and
    ``data:video`` are written to ``<uuid>.wav`` / ``<uuid>.mp4`` in the
    current working directory and replaced by that filename. Everything else
    is passed through unchanged.
    """
    # Data-URI prefix -> extension of the file the payload is saved to.
    saved_media = (("data:audio", "wav"), ("data:video", "mp4"))

    results = []
    for item in outputs:
        if not item:
            continue
        if not isinstance(item, str):
            results.append(item)
            continue

        if item.startswith("data:image"):
            decoded = base64.b64decode(item.split(",", 1)[1])
            results.append(Image.open(io.BytesIO(decoded)))
            continue

        for prefix, extension in saved_media:
            if item.startswith(prefix):
                decoded = base64.b64decode(item.split(",", 1)[1])
                path = f"{uuid.uuid4()}.{extension}"
                with open(path, "wb") as handle:
                    handle.write(decoded)
                results.append(path)
                break
        else:
            # Plain string (text, URL, ...): forward as-is.
            results.append(item)
    return results
def parse_outputs(data):
    """Flatten a nested prediction output into a list of values.

    * A list is flattened recursively into its leaf values.
    * A dict contributes its values in order; a value that is itself a list
      is kept together as one (flattened) sub-list, any other value is
      flattened into the result.
    * Any other object is returned as a one-element list.
    """
    if isinstance(data, dict):
        flattened = []
        for member in data.values():
            leaves = parse_outputs(member)
            if isinstance(member, list):
                # Lists inside objects stay grouped as a single entry.
                flattened.append(leaves)
            else:
                flattened.extend(leaves)
        return flattened

    if isinstance(data, list):
        return [leaf for element in data for leaf in parse_outputs(element)]

    return [data]
def create_dynamic_gradio_app(
    inputs,
    outputs,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    names=None,
    local_base=False,
    hostname="0.0.0.0",
):
    """Create a ``gr.Interface`` that forwards its inputs to a prediction API.

    The interface's ``predict`` function builds a ``{"input": {...}}`` payload
    (plus ``"version"`` when ``api_id`` is set), POSTs it to ``api_url`` and,
    if the API answers 201, polls the returned ``urls.get`` endpoint once per
    second until the prediction succeeds.

    Args:
        inputs: Gradio input components.
        outputs: Gradio output components (at least one).
        api_url: Endpoint that receives the prediction request.
        api_id: Optional model version sent as ``payload["version"]``.
        replicate_token: Optional token sent as ``Authorization: Token ...``.
        title: Interface title.
        model_description: Interface description.
        names: Payload keys, positionally matching ``inputs``.
        local_base: When true, file inputs are exposed through
            ``http://{hostname}:7860`` instead of the request's own origin.
        hostname: Host used for file URLs when ``local_base`` is true.

    Returns:
        The configured ``gr.Interface``.

    Raises (inside ``predict``):
        gr.Error: if the prediction fails or is canceled, or the API answers
            with an unexpected status code.
    """
    names = [] if names is None else names
    expected_outputs = len(outputs)
    # States after which polling can never reach "succeeded".
    terminal_failures = ("failed", "canceled")

    def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
        payload = {"input": {}}
        if api_id:
            payload["version"] = api_id

        if local_base:
            base_url = f"http://{hostname}:7860"
        else:
            parsed_url = urlparse(str(request.url))
            base_url = parsed_url.scheme + "://" + parsed_url.netloc

        for i, key in enumerate(names):
            value = args[i]
            # Local files are handed to the API as URLs served by this app.
            if value and os.path.exists(str(value)):
                value = f"{base_url}/file={value}"
            if value is not None and value != "":
                payload["input"][key] = value
        print(payload)

        # The headers may carry the API token, so they are never logged.
        headers = {"Content-Type": "application/json"}
        if replicate_token:
            headers["Authorization"] = f"Token {replicate_token}"

        response = requests.post(api_url, headers=headers, json=payload)
        if response.status_code == 201:
            follow_up_url = response.json()["urls"]["get"]
            response = requests.get(follow_up_url, headers=headers)
            # TODO: Add a failing mechanism if the API gets stuck
            while True:
                if response.status_code != 200:
                    # Let the status-code handling below report the error.
                    break
                status = response.json()["status"]
                if status == "succeeded":
                    break
                if status in terminal_failures:
                    raise gr.Error("The submission failed!")
                time.sleep(1)
                response = requests.get(follow_up_url, headers=headers)

        if response.status_code == 200:
            json_response = response.json()
            # If the output component is JSON return the entire output response
            if outputs[0].get_config()["name"] == "json":
                return json_response["output"]

            predict_outputs = parse_outputs(json_response["output"])
            processed_outputs = process_outputs(predict_outputs)
            difference_outputs = expected_outputs - len(processed_outputs)
            if difference_outputs > 0:
                # Fewer outputs than components: hide the unused components.
                extra_outputs = [gr.update(visible=False)] * difference_outputs
                processed_outputs.extend(extra_outputs)
            elif difference_outputs < 0:
                # More outputs than components: keep only the first ones.
                processed_outputs = processed_outputs[:difference_outputs]

            return (
                tuple(processed_outputs)
                if len(processed_outputs) > 1
                else processed_outputs[0]
            )

        if response.status_code == 409:
            raise gr.Error(
                "Sorry, the Cog image is still processing. Try again in a bit."
            )
        raise gr.Error(f"The submission failed! Error: {response.status_code}")

    app = gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=model_description,
        allow_flagging="never",
    )
    return app
def create_gradio_app_script(
    inputs_string,
    outputs_string,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    local_base=False,
    hostname="0.0.0.0"
):
    """Generate the source of a standalone Gradio app for a prediction API.

    The script defines ``inputs``/``outputs`` from the given source snippets,
    a ``predict`` function that posts to ``api_url`` and polls for the
    result, and launches a ``gr.Interface`` with ``share=True``.

    All caller-supplied values are embedded with ``repr()`` so that quotes or
    newlines in them cannot break the generated code. Note that
    ``replicate_token`` is written into the script in plain text.

    Args:
        inputs_string: Source defining ``inputs`` and ``names``.
        outputs_string: Source defining ``outputs``.
        api_url: Endpoint that receives the prediction request.
        api_id: Optional model version sent as ``payload["version"]``.
        replicate_token: Optional token for the ``Authorization`` header.
        title: Interface title.
        model_description: Interface description.
        local_base: When true, file inputs are exposed through
            ``http://{hostname}:7860`` instead of the request's own origin.
        hostname: Host used for file URLs when ``local_base`` is true.

    Returns:
        The generated Python source as a string.
    """
    headers = {"Content-Type": "application/json"}
    if replicate_token:
        headers["Authorization"] = f"Token {replicate_token}"

    version_lines = []
    if api_id is not None:
        version_lines.append(f'payload["version"] = {str(api_id)!r}')

    if local_base:
        local_url = f"http://{hostname}:7860"
        base_url_lines = [f"base_url = {local_url!r}"]
    else:
        base_url_lines = [
            "parsed_url = urlparse(str(request.url))",
            'base_url = parsed_url.scheme + "://" + parsed_url.netloc',
        ]

    # Body of the generated predict(); indentation is relative to the def.
    predict_body = [
        f"headers = {headers!r}",
        'payload = {"input": {}}',
        *version_lines,
        *base_url_lines,
        "for i, key in enumerate(names):",
        "    value = args[i]",
        "    if value and (os.path.exists(str(value))):",
        '        value = f"{base_url}/file=" + value',
        '    if value is not None and value != "":',
        '        payload["input"][key] = value',
        f"response = requests.post({str(api_url)!r}, headers=headers, json=payload)",
        "if response.status_code == 201:",
        '    follow_up_url = response.json()["urls"]["get"]',
        "    response = requests.get(follow_up_url, headers=headers)",
        '    while response.json()["status"] != "succeeded":',
        # A canceled prediction never succeeds; without this the loop spins forever.
        '        if response.json()["status"] in ("failed", "canceled"):',
        '            raise gr.Error("The submission failed!")',
        "        response = requests.get(follow_up_url, headers=headers)",
        "        time.sleep(1)",
        "if response.status_code == 200:",
        "    json_response = response.json()",
        "    # If the output component is JSON return the entire output response",
        '    if outputs[0].get_config()["name"] == "json":',
        '        return json_response["output"]',
        '    predict_outputs = parse_outputs(json_response["output"])',
        "    processed_outputs = process_outputs(predict_outputs)",
        "    difference_outputs = expected_outputs - len(processed_outputs)",
        "    # If less outputs than expected, hide the extra ones",
        "    if difference_outputs > 0:",
        "        extra_outputs = [gr.update(visible=False)] * difference_outputs",
        "        processed_outputs.extend(extra_outputs)",
        "    # If more outputs than expected, cap the outputs to the expected number",
        "    elif difference_outputs < 0:",
        "        processed_outputs = processed_outputs[:difference_outputs]",
        "    return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]",
        "else:",
        "    if response.status_code == 409:",
        '        raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")',
        '    raise gr.Error(f"The submission failed! Error: {response.status_code}")',
    ]

    script_lines = [
        "import gradio as gr",
        "from urllib.parse import urlparse",
        "import requests",
        "import time",
        "import os",
        "",
        "from utils.gradio_helpers import parse_outputs, process_outputs",
        "",
        inputs_string,
        outputs_string,
        "expected_outputs = len(outputs)",
        "",
        "",
        "def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):",
        *("    " + line for line in predict_body),
        "",
        "",
        f"title = {str(title)!r}",
        f"model_description = {str(model_description)!r}",
        "",
        "app = gr.Interface(",
        "    fn=predict,",
        "    inputs=inputs,",
        "    outputs=outputs,",
        "    title=title,",
        "    description=model_description,",
        '    allow_flagging="never",',
        ")",
        "app.launch(share=True)",
    ]
    app_string = "\n".join(script_lines) + "\n"
    return app_string