Spaces:
Build error
Build error
import argparse | |
import queue | |
import sys | |
import uuid | |
from functools import partial | |
import numpy as np | |
import tritonclient.grpc as grpcclient | |
from tritonclient.utils import InferenceServerException | |
import gradio as gr | |
from functools import wraps | |
#### | |
from PIL import Image | |
import base64 | |
import io | |
##### | |
from http.server import HTTPServer, SimpleHTTPRequestHandler | |
import socket | |
#### | |
import os | |
import uuid | |
#### | |
class UserData:
    """Holder for the queue that ``callback`` fills with results or errors."""

    def __init__(self):
        # Each instance gets its own FIFO; ``callback`` pushes onto it.
        pending = queue.Queue()
        self._completed_requests = pending
def callback(user_data, result, error):
    """Push the outcome of a request onto ``user_data._completed_requests``.

    When ``error`` is truthy it is queued; otherwise ``result`` is queued.
    Exactly one item is added per call. Presumably meant as a Triton
    ``async_infer`` completion callback (not called in this file's visible code).
    """
    outcome = error if error else result
    user_data._completed_requests.put(outcome)
def _to_bytes_tensor(value):
    """Encode ``value`` as UTF-8 and wrap it in a (1, 1) numpy bytes array."""
    return np.array(value.encode("utf-8"), dtype=bytes).reshape([1, -1])


def make_a_try(img_url, text, *, url="10.95.163.43:8001",
               model_name="ensemble_mllm", verbose=True):
    """Send one image-URL + prompt request to the Triton model and return its text.

    Args:
        img_url: URL of the image, sent as the ``IMAGE_URL`` BYTES input
            (presumably fetched by the ensemble on the server side).
        text: Prompt text, sent as the ``TEXT`` BYTES input.
        url: ``host:port`` of the Triton gRPC endpoint.
        model_name: Name of the model to query.
        verbose: Forwarded to the gRPC client's ``verbose`` flag.

    Returns:
        ``OUTPUT[0][0]`` decoded as UTF-8. Returns "" if the client cannot
        be created, the request raises ``InferenceServerException``, the
        statistics check fails, or the response has no ``OUTPUT`` tensor.
    """
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=url,
            verbose=verbose,
            ssl=False,
            root_certificates=None,
            private_key=None,
            certificate_chain=None,
        )
    except Exception as e:
        print("channel creation failed: " + str(e))
        return ""

    try:
        img_url_tensor = _to_bytes_tensor(img_url)
        text_tensor = _to_bytes_tensor(text)
        inputs = [
            grpcclient.InferInput("IMAGE_URL", img_url_tensor.shape, "BYTES"),
            grpcclient.InferInput("TEXT", text_tensor.shape, "BYTES"),
        ]
        inputs[0].set_data_from_numpy(img_url_tensor)
        inputs[1].set_data_from_numpy(text_tensor)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT")]

        # Server-side failures must not escape into the Gradio handler;
        # report them the same way as the other failure paths ("").
        try:
            results = triton_client.infer(
                model_name=model_name,
                inputs=inputs,
                outputs=outputs,
                client_timeout=None,
                compression_algorithm=None,
            )
            statistics = triton_client.get_inference_statistics(
                model_name=model_name
            )
        except InferenceServerException as e:
            print("inference failed: " + str(e))
            return ""

        print(statistics)
        if len(statistics.model_stats) != 1:
            print("FAILED: Inference Statistics")
            return ""

        output_data = results.as_numpy("OUTPUT")
        if output_data is None:
            # as_numpy returns None when the response lacks the tensor.
            print("FAILED: no OUTPUT tensor in response")
            return ""
        result_str = output_data[0][0].decode("utf-8")
        print("OUTPUT: " + result_str)
        return result_str
    finally:
        # A new client is created per call; close it so repeated requests
        # do not accumulate open gRPC channels.
        triton_client.close()
def greet(image, text):
    """Save the image under the static dir, then ask the model about it.

    The image is re-encoded as JPEG into the local static directory under a
    random UUID-based filename. It is then passed to ``make_a_try`` as a
    ``/file=static/<name>`` URL on this app's host.

    Args:
        image: PIL image from the Gradio ``Image`` component; ``None`` when
            nothing was uploaded.
        text: Prompt typed by the user.

    Returns:
        The model's text response. Returns "" if the image is missing or
        cannot be encoded; otherwise returns whatever ``make_a_try`` returns.
    """
    static_path = "/workdir/yanghandi/gradio_demo/static"
    if image is None:
        print("no image provided")
        return ""

    # Encode into memory first so a failed encode never leaves a partial file.
    img_byte_arr = io.BytesIO()
    try:
        # Pillow cannot write modes such as RGBA or P as JPEG; normalize so
        # such uploads are not rejected.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(img_byte_arr, format="JPEG")
    except Exception as e:
        print("image encoding failed: " + str(e))
        return ""

    # A unique name keeps concurrent requests from overwriting each other.
    filename = f"image_{uuid.uuid4()}.jpg"
    os.makedirs(static_path, exist_ok=True)
    with open(os.path.join(static_path, filename), "wb") as f:
        f.write(img_byte_arr.getvalue())

    img_url = "http://10.99.5.48:8080/file=static/" + filename
    return make_a_try(img_url, text)
def clear_output():
    """Return an empty string; wired to the clear button to reset the output box."""
    blank = ""
    return blank
def get_example():
    """Return the example rows for the ``gr.Examples`` table.

    Each row is ``[image_path, prompt]``, matching the ``[imagebox, promptbox]``
    inputs the examples are wired to.
    """
    image_path = "/workdir/yanghandi/gradio_demo/static/0001.jpg"
    prompt = "图中的人物是谁"  # "Who is the person in the picture?"
    row = [image_path, prompt]
    return [row]
if __name__ == "__main__":
    # Let Gradio serve files under static/ through its /file= route; greet()
    # builds image URLs of the form http://<host>:8080/file=static/<name>.
    # NOTE(review): this path is relative to the working directory while
    # greet() writes to an absolute path; they only agree when the app is
    # launched from /workdir/yanghandi/gradio_demo — confirm.
    gr.set_static_paths(paths=["static/"])

    with gr.Blocks(title='demo') as demo:
        gr.Markdown("# 自研模型测试demo")  # "In-house model test demo"
        # "Upload an image and start discussing it, or try the examples below"
        gr.Markdown("尝试使用该demo,上传图片并开始讨论它,或者尝试下面的例子")
        with gr.Row():
            with gr.Column():
                imagebox = gr.Image(type="pil")
                promptbox = gr.Textbox(label="prompt")
            with gr.Column():
                output = gr.Textbox(label="output")
        with gr.Row():
            submit = gr.Button("submit")
            clear = gr.Button("clear")
        submit.click(fn=greet, inputs=[imagebox, promptbox], outputs=[output])
        clear.click(fn=clear_output, inputs=[], outputs=[output])

        gr.Markdown("# example")
        gr.Examples(
            examples=get_example(),
            fn=greet,
            inputs=[imagebox, promptbox],
            outputs=[output],
            # Do not precompute at build time. Caching would call greet()
            # while the app is starting. That needs the Triton endpoint to
            # be reachable, and it hands the model a URL served by this app,
            # which presumably is not listening on :8080 yet.
            cache_examples=False,
        )

    demo.launch(server_name="0.0.0.0", server_port=8080, debug=True, share=True)