# Spaces status banner captured with this file (page residue, not program code): Runtime error
import json | |
import logging | |
import gradio as gr | |
import requests | |
import base64 | |
from typing import Union | |
from enum import Enum | |
# Base URL of the REST backend; the per-task paths in `internal_apis` are
# appended to it when a request is built.
restapi_server = "https://23ztiveq.fn.bytedance.net"
class Task(str, Enum):
    """Image-understanding tasks this app can request.

    Members are also `str` instances, so they compare equal to their plain
    string values.
    """

    # Tagging and captioning.
    tag = "tag"
    caption = "caption"

    # Three description variants.
    desc_basic = "desc_basic"
    desc_plus = "desc_plus"
    desc_max = "desc_max"
# Request-body templates, one per task. Callers copy the template and fill in
# "image_base64"; the three description tasks differ only in their "task" value.
args_placeholder = {
    # Image-only requests.
    Task.tag: {"image_base64": ""},
    Task.caption: {"image_base64": ""},
    # Description requests.
    Task.desc_basic: {"image_base64": "", "task": "basic"},
    Task.desc_plus: {"image_base64": "", "task": "plus"},
    Task.desc_max: {"image_base64": "", "task": "max"},
}
# URL path (relative to `restapi_server`) used for each task. All three
# description tasks share the "/img_desc" path.
internal_apis = {
    Task.tag: "/img_tag",
    Task.caption: "/img_caption",
    # Shared description path.
    Task.desc_basic: "/img_desc",
    Task.desc_plus: "/img_desc",
    Task.desc_max: "/img_desc",
}
def execute_task(img_path, task, params):
    """Run one task against the backend and return a summary record.

    Args:
        img_path: Image path; only recorded in the result and the log line.
        task: Task to run; selects the endpoint and the post-processing.
        params: Request body passed to `make_interal_api_call`.

    Returns:
        A dict with keys "task", "image" and "resp", where "resp" is the
        post-processed response.
    """
    raw_resp = make_interal_api_call(task, params)
    record = {
        "task": task,
        "image": img_path,
        "resp": post_process(task, raw_resp),
    }
    logging.info(f"executing {task}, with img:{img_path} with res: {record['resp']}, ")
    return record
def post_process(task, resp):
    """Extract the displayable result from a backend response.

    Args:
        task: The task that produced `resp`.
        resp: Response returned by `make_interal_api_call`, or None.

    Returns:
        None when `resp` is None; `resp["text"]` for desc_max; the first
        element of `resp["text"]` for desc_plus / desc_basic; the formatted
        tag text for tag; otherwise `resp` unchanged.
    """
    if resp is None:
        return None

    if task == Task.desc_max:
        return resp["text"]

    # todo, after server unifing 3 output, should keep same with desc max
    if task in (Task.desc_plus, Task.desc_basic):
        return resp["text"][0]

    if task == Task.tag:
        return format_img_tag_resp(resp)

    return resp
def format_img_tag_resp(resp_json_str):
    """Render an image-tag response as one text line per category.

    The payload is a sequence whose first element maps a category name to a
    list of tag dicts carrying "chn_name" and "score". Each category becomes
    a line of the form `|category|:#name#:0.123;#name#:0.456`.

    Args:
        resp_json_str: Either JSON text (str/bytes) or the already-decoded
            payload. Accepting the decoded form matters because the caller
            passes the result of `Response.json()`, which is only a string
            if the server double-encodes its reply.

    Returns:
        The formatted multi-line string, or None if the payload is missing,
        is not valid JSON, or does not have the expected shape.
    """
    try:
        if isinstance(resp_json_str, (str, bytes, bytearray)):
            payload = json.loads(resp_json_str)
        else:
            payload = resp_json_str

        # Only the first element of the payload is used.
        tags_by_category = payload[0]
        lines = []
        for category, tag_entries in tags_by_category.items():
            tags = ";".join(
                f"#{entry['chn_name']}#:{entry['score']:.3f}" for entry in tag_entries
            )
            lines.append(f"|{category}|:{tags}")
        return "\n".join(lines)
    except (TypeError, ValueError, LookupError, AttributeError) as e:
        # Best-effort formatting: a malformed payload yields None, not a crash.
        logging.error(e)
        return None
def make_interal_api_call(task, req_args_d, timeout=120):
    """POST a task request to the backend and return the decoded JSON reply.

    Args:
        task: Task whose path is looked up in `internal_apis`.
        req_args_d: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout `requests` can block indefinitely on an unresponsive
            server, which would hang the UI callback.

    Returns:
        The decoded JSON body on HTTP 200, otherwise None (request failure,
        non-200 status, or a body that is not valid JSON).

    Raises:
        KeyError: If `task` has no entry in `internal_apis`.
    """
    headers = {"Content-Type": "application/json"}
    url = restapi_server + internal_apis[task]

    try:
        resp = requests.post(url, json=req_args_d, headers=headers, timeout=timeout)
    except Exception as e:
        # Deliberately broad: any transport/serialisation failure means "no result".
        logging.error(e)
        return None

    if resp.status_code != 200:
        logging.error("request to %s failed with status %s", url, resp.status_code)
        return None

    try:
        return resp.json()
    except ValueError as e:
        # A 200 reply whose body is not JSON; treated like any other failure.
        logging.error(e)
        return None
def greet(name):
    """Return the greeting "Hello <name>!" for the given name string."""
    greeting = "Hello " + name
    return greeting + "!"
def img_2_base64(img_path):
    """Read a file and return its contents as a base64 text string.

    Args:
        img_path: Path of the file to read, or None.

    Returns:
        The base64 encoding of the file's bytes as a str, or None when
        `img_path` is None.
    """
    if img_path is None:
        return None

    with open(img_path, "rb") as fh:
        encoded = base64.b64encode(fh.read())
    return encoded.decode("utf-8")
def get_desc_basic(img_path: Union[str, None]):
    """Request the "basic" description for an image.

    Args:
        img_path: Path of the image file, or None when no image is set.

    Returns:
        The post-processed description, or None when there is no image or
        the backend call fails.
    """
    # The UI's change event also fires when the image is cleared; skip the
    # backend call instead of posting a null image.
    if img_path is None:
        return None

    task = Task.desc_basic
    args = {**args_placeholder[task], "image_base64": img_2_base64(img_path)}
    return post_process(task, make_interal_api_call(task, args))
def get_desc_plus(img_path: Union[str, None]):
    """Request the "plus" description for an image.

    Args:
        img_path: Path of the image file, or None when no image is set.

    Returns:
        The post-processed description, or None when there is no image or
        the backend call fails.
    """
    # The UI's change event also fires when the image is cleared; skip the
    # backend call instead of posting a null image.
    if img_path is None:
        return None

    task = Task.desc_plus
    args = {**args_placeholder[task], "image_base64": img_2_base64(img_path)}
    return post_process(task, make_interal_api_call(task, args))
def get_desc_max(img_path: Union[str, None], top_k, top_p, temp, max_new_tokens):
    """Request the "max" description for an image with sampling settings.

    Args:
        img_path: Path of the image file, or None when no image is set.
        top_k: Sent to the backend as "top_k".
        top_p: Sent to the backend as "top_p".
        temp: Sent to the backend as "temperature".
        max_new_tokens: Sent to the backend as "max_new_tokens".

    Returns:
        The post-processed description, or None when there is no image or
        the backend call fails.
    """
    # Submit can be clicked with no image loaded; skip the backend call
    # instead of posting a null image.
    if img_path is None:
        return None

    task = Task.desc_max
    args = {
        **args_placeholder[task],
        "image_base64": img_2_base64(img_path),
        "max_new_tokens": max_new_tokens,
        "temperature": temp,
        "top_k": top_k,
        "top_p": top_p,
    }
    return post_process(task, make_interal_api_call(task, args))
# Gradio UI: an image input with three description outputs, sampling controls
# for the "max" description, and Submit / Clear buttons.
# NOTE(review): the original indentation was lost, so the nesting of rows and
# columns below is a reconstruction — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Image-Understanding
        """
    )
    with gr.Row():
        input_image = gr.Image(type="filepath", label="Image", sources=["upload"])
        with gr.Column():
            output_basic = gr.Textbox(label="Desc Basic")
            output_plus = gr.Textbox(label="Desc Plus")
            output_max = gr.Textbox(label="Desc Max")
    with gr.Row():
        top_k = gr.Slider(1, 200, label="Top K", value=40)
        top_p = gr.Slider(0, 1, label="Top P", value=1)
    with gr.Row():
        temp = gr.Slider(0.001, 5, label="Temperature", value=1)
        max_new_tokens = gr.Slider(1, 2000, label="Max New Tokens", value=300)
    with gr.Row():
        # Button labels had mis-decoded emoji; restored here. The Submit
        # emoji is presumed to be the rocket — confirm.
        submit_btn = gr.Button("🚀 Submit Max")
        empty_bin = gr.Button("🧹 Clear")

    # Basic and plus descriptions are refreshed whenever the image changes.
    input_image.change(get_desc_basic, inputs=input_image, outputs=output_basic)
    input_image.change(get_desc_plus, inputs=input_image, outputs=output_plus)

    # Components reset by both the image's own clear control and the Clear button.
    clear_targets = [input_image, output_basic, output_plus, output_max]
    input_image.clear(
        lambda: [None] * 4,
        None,
        clear_targets,
    )
    # The Clear button was previously created but never wired to a handler.
    empty_bin.click(
        lambda: [None] * 4,
        None,
        clear_targets,
        queue=False,
    )

    # The max description is only requested on explicit submit.
    submit_btn.click(
        get_desc_max,
        inputs=[input_image, top_k, top_p, temp, max_new_tokens],
        outputs=[output_max],
    )
''' | |
gr.Markdown( | |
""" | |
## Mixtral Moe 8 x 7B in 4bit | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
instruct = gr.Textbox(label="Instruction") | |
input = gr.Textbox(label="Input") | |
with gr.Row(): | |
top_k = gr.Slider(1, 200, label="Top K", value=40) | |
top_p = gr.Slider(0, 1, label="Top P", value=1) | |
with gr.Row(): | |
temp = gr.Slider(0.001, 5, label="Temperature", value=1) | |
max_new_tokens = gr.Slider(1, 2000, label="Max New Tokens", value=300) | |
with gr.Row(): | |
submit_btn = gr.Button("π Submit") | |
empty_bin = gr.Button("π§Ή Clear") | |
with gr.Column(): | |
output = gr.Textbox(label="Output", lines=16) | |
submit_btn.click( | |
complete, [instruct, input, top_k, top_p, temp, max_new_tokens], [output] | |
) | |
empty_bin.click( | |
lambda: [None] * 3, None, [instruct, input, output], queue=False | |
) | |
''' | |
# Script entry point: start the Gradio app only when run directly, with
# `share=True` passed to `launch`.
if __name__ == "__main__":
    demo.launch(share=True)
""" | |
top_k = gr.Slider(1, 200, label="Top K", value=40) | |
top_p = gr.Slider(0, 1, label="Top P", value=1) | |
temp = gr.Slider(0.001, 5, label="Temperature", value=1) | |
output = gr.Textbox(label="Output", lines=10) | |
max_new_tokens = gr.Slider(1, 2000, label="Max New Tokens", value=300) | |
iface = gr.Interface( | |
fn=complete, | |
inputs=["text", "text", top_k, top_p, temp, max_new_tokens], | |
outputs=output, | |
) | |
iface.launch() | |
""" | |