hd0013 committed on
Commit
7f119fd
1 Parent(s): 713b79b

Upload folder using huggingface_hub

Files changed (39)
  1. .gitignore +2 -0
  2. .vscode/launch.json +17 -0
  3. 0000.jpeg +0 -0
  4. 0001.jpg +0 -0
  5. README.md +3 -9
  6. gradio_cached_examples/15/log.csv +2 -0
  7. main.py +83 -0
  8. multi_main.py +105 -0
  9. requirements.txt +78 -0
  10. run.sh +9 -0
  11. setup.sh +2 -0
  12. static/0000.jpeg +0 -0
  13. static/0001.jpg +0 -0
  14. static/image_1149d78e-43f6-4d5b-8e63-e24c2641012b.jpg +0 -0
  15. static/image_11bc11b6-57a9-42d3-81fc-97d957e62f28.jpg +0 -0
  16. static/image_33685dfc-b96f-401c-847d-3dd537186ec2.jpg +0 -0
  17. static/image_4a883844-b1b2-4154-bbfe-0b108d93dca6.jpg +0 -0
  18. static/image_4e43dfc4-9362-4e8d-b20f-d9816a9b684d.jpg +0 -0
  19. static/image_51e314f6-0935-4430-a53f-bfdc22ca7cc4.jpg +0 -0
  20. static/image_6c63b8a2-8bfc-485e-a238-94e76f22e7db.jpg +0 -0
  21. static/image_82c514bf-bdc1-4d60-b073-a97b5816ea56.jpg +0 -0
  22. static/image_845cd65a-30a7-43c3-be8c-819bc04de98d.jpg +0 -0
  23. static/image_86804064-1f0e-4130-bcc3-30dd839a3c0a.jpg +0 -0
  24. static/image_9afab29f-77d9-4451-8a52-092a5c37625d.jpg +0 -0
  25. static/image_9c5c694a-acef-4545-8759-75ccffa39f6d.jpg +0 -0
  26. static/image_b73c1efa-8112-404c-ab9e-eddb870f43af.jpg +0 -0
  27. static/image_bbacb8a4-89d9-45ac-9da2-28717009e750.jpg +0 -0
  28. static/image_bd087d82-9ed6-4b68-8ac1-c1417cfbb995.jpg +0 -0
  29. static/image_ca9e5436-bc9b-4242-9a02-261d2209feb3.jpg +0 -0
  30. static/image_dd018613-a2be-4eef-8183-a01929c835fd.jpg +0 -0
  31. static/image_e63e5e5c-41cc-4f42-b896-95644d6b28f3.jpg +0 -0
  32. static/image_e6c53092-7b49-4ba5-96dc-38e9479d3c6a.jpg +0 -0
  33. static/image_f9bef0de-bc72-415d-b527-6c068594af02.jpg +0 -0
  34. test_hd.py +187 -0
  35. try_demo.py +224 -0
  36. try_demo_demo.py +206 -0
  37. try_grpc.py +139 -0
  38. try_hd.py +137 -0
  39. try_hd_v2.py +218 -0
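
The commit message above indicates this folder was pushed with the huggingface_hub library. For reference, a minimal sketch of such an upload (the repo_id "hd0013/gradio_demo" is an assumption based on the author and Space title, not something the diff confirms):

from huggingface_hub import HfApi

api = HfApi()
# Push the local working directory to the Space in a single commit.
api.upload_folder(
    folder_path=".",
    repo_id="hd0013/gradio_demo",  # assumed repo id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)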
.gitignore ADDED
@@ -0,0 +1,2 @@
+ flagged
+ *.log
.vscode/launch.json ADDED
@@ -0,0 +1,17 @@
+ {
+     // Use IntelliSense to learn about possible attributes.
+     // Hover to view descriptions of existing attributes.
+     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "Python: Current File",
+             "type": "python",
+             "request": "launch",
+             "program": "test_hd.py",
+             "console": "integratedTerminal",
+             "python": "/home/hadoop-automl/tianrunhe/anaconda3/envs/gradio_env_3.9/bin/python"
+             // "python": "/home/hadoop-automl/tianrunhe/anaconda3/envs/gradio/bin/python"
+         }
+     ]
+ }
0000.jpeg ADDED
0001.jpg ADDED
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Gradio Demo
- emoji:
- colorFrom: indigo
- colorTo: indigo
+ title: gradio_demo
+ app_file: try_grpc.py
  sdk: gradio
- sdk_version: 4.31.5
- app_file: app.py
- pinned: false
+ sdk_version: 4.31.4
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
gradio_cached_examples/15/log.csv ADDED
@@ -0,0 +1,2 @@
+ output,flag,username,timestamp
+ ,,,2024-05-21 21:35:52.230145
main.py ADDED
@@ -0,0 +1,83 @@
+ import gradio as gr
+ from TritonServerClient import TritonServerClient, InferInput, InferRequestedOutput
+ from TritonServerClient.utils import np_to_triton_dtype
+ from functools import wraps
+ import numpy as np
+
+ def prepare_params(query):
+     bs = 1
+ # query = "#角色:直播间文案写作专家。#任务:根据用户输入的商品信息写出用于推广营销的直播文案。#目标:引导用户购买直播间商品,提高销量。#要求:- 如果没有输入商品信息,则邀请输入。- 文案写作 - 按照“引出话题”、“产品讲解”、“购买方式”三个方面写作文案。 - 在“产品讲解”部分详细讲解产品细节: - 【重要】充分扩充内容,吸引用户购买。 - 详细描述商品的优惠活动信息(如用户输入没有,不要询问,直接补充)。 - 在“购买方式”部分: - 仅包括引导用户点击链接购买的内容。 - 不要包括快递物流承诺。 - 【重要】每个部分多于5句话。 - 使用口语化的文字。 - 表达方式有亲和力和感染力。 - 对用户们使用“宝子们”、“宝宝们”、“家人们”等有亲和力的称呼。 - 多使用自问自答的语句。- 当你准备好了回复“请输入商品描述和优惠信息”我们就开始。#输出:- 包含在""里。- 写出每一部分内容的标题,包括【引出话题】、【产品讲解】、【购买方式】。- 示例:""【引出话题】【产品讲解】【购买方式】""- 输入商品信息:%s" % prompt
+ # query = '请你扮演直播间直播卖货的主播角色,根据用户提出的问题,检索知识库中的信息,如果查找到相关的信息,使用口语化的文字,用非常简短的一句话回答用户的问题;如果查找不到相关的信息,直接回答没有,并引导用户关注本场直播商品。你的目的是激发用户的购买欲、提高直播商品销量。回答对用户们使用“宝子”、“宝宝”、“亲”等有亲和力的称呼。\n\n# 用户提出的问题:\n这里面有可乐吗\n\n# 知识库:\n商品名称:外送欢乐桶(券)兑换券\n直播链接序号:64号链接号链接\n商品现价:66元\n商品原价:142元\n折扣:5\n品牌名称:麦当劳麦乐送\n券包张数:1\n套餐详情:["麦辣鸡腿汉堡1个","板烧鸡腿堡1个","中可乐3杯","四拼小食桶A"]\n补充:四拼小食桶A(麦麦脆汁鸡(鸡腿)1块+中薯条1份+麦辣鸡翅2块+麦乐鸡5块)+麦辣鸡腿堡1个+板烧鸡腿堡1个+中可乐3杯 原材料:面包\n\n\n# 输出格式:\n答:\n\n# 要求\n对于用户问到了知识库中未提及的信息不要编造,直接不回答'
+
+     # Wrap the UTF-8 query as a [bs, 1] BYTES tensor for Triton.
+     title_text = np.array([query.encode('utf-8')], dtype=np.string_)
+     title_text = np.tile(title_text, (bs, 1))
+
+     data_batch = {}
+     data_batch['query'] = title_text
+
+     inputs = [
+         InferInput("query", data_batch['query'].shape,
+                    np_to_triton_dtype(data_batch['query'].dtype)),
+     ]
+
+     inputs[0].set_data_from_numpy(data_batch['query'])
+
+     return inputs
+
+
+ def make_a_try(inputs, outputs='response', model_name='qwen', model_version='1'):
+     outputs_list = []
+     ori_outputs_list = outputs.strip().split(",")
+     for out_ele in ori_outputs_list:
+         outputs_list.append(out_ele.strip())
+     outputs = [InferRequestedOutput(x) for x in outputs_list]
+
+     response = my_client.predict(model_name=model_name, inputs=inputs, model_version=model_version, outputs=outputs)
+
+     rsp_info = {}
+     if outputs_list == []:
+         # Fall back to the output names the server actually returned.
+         for out_name_ele in response._result.outputs:
+             outputs_list.append(out_name_ele.name)
+     for output_name in outputs_list:
+         res = response.as_numpy(output_name)
+         response = np.expand_dims(res, axis=0)
+         response = response[0].decode('utf-8')
+         rsp_info[output_name] = response
+
+     return rsp_info['response']
+
+
+ def greet(prompt):
+     """Send the prompt to the model server and return its response."""
+     # print(prompt)
+     inputs = prepare_params(prompt)
+     result = make_a_try(inputs)
+
+     return result
+
+
+ if __name__ == "__main__":
+     param_info = {}
+     param_info['appkey'] = "com.sankuai.automl.serving"
+     # param_info['appkey'] = "com.sankuai.automl.streamvlm"
+
+     param_info['remote_appkey'] = "com.sankuai.automl.chat3"
+     param_info['model_name'] = "qwen"
+     param_info['model_version'] = "1"
+     param_info['time_out'] = 60000
+     param_info['server_targets'] = []
+     param_info['outputs'] = 'response'
+
+     appkey, remote_appkey, model_name, model_version, time_out, server_targets = param_info['appkey'], param_info['remote_appkey'], param_info['model_name'], param_info['model_version'], param_info['time_out'], param_info['server_targets']
+
+     # greet() depends on my_client, so the client must be constructed here;
+     # leaving this line commented out would raise a NameError at request time.
+     my_client = TritonServerClient(appkey=appkey, remote_appkey=remote_appkey, time_out=time_out, server_targets=server_targets)
+
+     # Users may modify the code above this point.
+     demo = gr.Interface(
+         fn=greet,
+         inputs=["textbox"],
+         outputs=["textbox"],
+     )
+
+     demo.launch(server_name="0.0.0.0", server_port=8088, debug=True, share=True)
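
Once main.py is serving on port 8088, the Interface can also be exercised programmatically. A minimal sketch using gradio_client (the host and endpoint are assumptions: the address follows from the launch call above, and "/predict" is the usual default api_name for a gr.Interface):

from gradio_client import Client

client = Client("http://127.0.0.1:8088")
# Send one prompt through the single textbox input and print the reply.
print(client.predict("你好", api_name="/predict"))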
multi_main.py ADDED
@@ -0,0 +1,105 @@
+ import gradio as gr
+ from TritonServerClient import TritonServerClient, InferInput, InferRequestedOutput
+ from TritonServerClient.utils import np_to_triton_dtype
+ from functools import wraps
+ import numpy as np
+
+
+ def prepare_params(query):
+     bs = 1
+ # query = "#角色:直播间文案写作专家。#任务:根据用户输入的商品信息写出用于推广营销的直播文案。#目标:引导用户购买直播间商品,提高销量。#要求:- 如果没有输入商品信息,则邀请输入。- 文案写作 - 按照“引出话题”、“产品讲解”、“购买方式”三个方面写作文案。 - 在“产品讲解”部分详细讲解产品细节: - 【重要】充分扩充内容,吸引用户购买。 - 详细描述商品的优惠活动信息(如用户输入没有,不要询问,直接补充)。 - 在“购买方式”部分: - 仅包括引导用户点击链接购买的内容。 - 不要包括快递物流承诺。 - 【重要】每个部分多于5句话。 - 使用口语化的文字。 - 表达方式有亲和力和感染力。 - 对用户们使用“宝子们”、“宝宝们”、“家人们”等有亲和力的称呼。 - 多使用自问自答的语句。- 当你准备好了回复“请输入商品描述和优惠信息”我们就开始。#输出:- 包含在""里。- 写出每一部分内容的标题,包括【引出话题】、【产品讲解】、【购买方式】。- 示例:""【引出话题】【产品讲解】【购买方式】""- 输入商品信息:%s" % prompt
+ # query = '请你扮演直播间直播卖货的主播角色,根据用户提出的问题,检索知识库中的信息,如果查找到相关的信息,使用口语化的文字,用非常简短的一句话回答用户的问题;如果查找不到相关的信息,直接回答没有,并引导用户关注本场直播商品。你的目的是激发用户的购买欲、提高直播商品销量。回答对用户们使用“宝子”、“宝宝”、“亲”等有亲和力的称呼。\n\n# 用户提出的问题:\n这里面有可乐吗\n\n# 知识库:\n商品名称:外送欢乐桶(券)兑换券\n直播链接序号:64号链接号链接\n商品现价:66元\n商品原价:142元\n折扣:5\n品牌名称:麦当劳麦乐送\n券包张数:1\n套餐详情:["麦辣鸡腿汉堡1个","板烧鸡腿堡1个","中可乐3杯","四拼小食桶A"]\n补充:四拼小食桶A(麦麦脆汁鸡(鸡腿)1块+中薯条1份+麦辣鸡翅2块+麦乐鸡5块)+麦辣鸡腿堡1个+板烧鸡腿堡1个+中可乐3杯 原材料:面包\n\n\n# 输出格式:\n答:\n\n# 要求\n对于用户问到了知识库中未提及的信息不要编造,直接不回答'
+
+     # Wrap the UTF-8 query as a [bs, 1] BYTES tensor for Triton.
+     title_text = np.array([query.encode('utf-8')], dtype=np.string_)
+     title_text = np.tile(title_text, (bs, 1))
+
+     data_batch = {}
+     data_batch['query'] = title_text
+
+     inputs = [
+         InferInput("query", data_batch['query'].shape,
+                    np_to_triton_dtype(data_batch['query'].dtype)),
+     ]
+
+     inputs[0].set_data_from_numpy(data_batch['query'])
+
+     return inputs
+
+
+ def make_a_try(inputs, outputs='response', model_name='qwen', model_version='1'):
+     outputs_list = []
+     ori_outputs_list = outputs.strip().split(",")
+     for out_ele in ori_outputs_list:
+         outputs_list.append(out_ele.strip())
+     outputs = [InferRequestedOutput(x) for x in outputs_list]
+
+     response = my_client.predict(model_name=model_name, inputs=inputs, model_version=model_version, outputs=outputs)
+
+     rsp_info = {}
+     if outputs_list == []:
+         for out_name_ele in response._result.outputs:
+             outputs_list.append(out_name_ele.name)
+     for output_name in outputs_list:
+         res = response.as_numpy(output_name)
+         response = np.expand_dims(res, axis=0)
+         response = response[0].decode('utf-8')
+         rsp_info[output_name] = response
+     print("response:", rsp_info)
+     return rsp_info['response']
+
+
+ def greet(prompt):
+     """Send the prompt to the model server and return its response."""
+     # print(prompt)
+     print("prompt:", prompt)
+     inputs = prepare_params(prompt)
+     print(inputs)
+     result = make_a_try(inputs)
+
+     return result
+
+ def clear_input():
+     return ""
+
+ if __name__ == "__main__":
+     param_info = {}
+     # param_info['appkey'] = "com.sankuai.automl.serving"
+     param_info['appkey'] = "com.sankuai.automl.streamvlm"
+
+     param_info['remote_appkey'] = "com.sankuai.automl.chat3"
+     param_info['model_name'] = "qwen"
+     param_info['model_version'] = "1"
+     param_info['time_out'] = 60000
+     param_info['server_targets'] = []
+     param_info['outputs'] = 'response'
+
+     appkey, remote_appkey, model_name, model_version, time_out, server_targets = param_info['appkey'], param_info['remote_appkey'], param_info['model_name'], param_info['model_version'], param_info['time_out'], param_info['server_targets']
+
+     my_client = TritonServerClient(appkey=appkey, remote_appkey=remote_appkey, time_out=time_out, server_targets=server_targets)
+
+     # # Users may modify the code above this point.
+     # demo = gr.Interface(
+     #     fn=greet,
+     #     inputs=["textbox"],
+     #     outputs=["textbox"],
+     # )
+     with gr.Blocks(title='demo') as demo:
+         with gr.Row():
+             with gr.Column():
+                 promptbox = gr.Textbox(label="prompt")
+
+             with gr.Column():
+                 output = gr.Textbox(label="output")
+                 with gr.Row():
+                     submit = gr.Button("submit")
+                     clear = gr.Button("clear")
+
+         submit.click(fn=greet, inputs=[promptbox], outputs=[output])
+         clear.click(fn=clear_input, inputs=[], outputs=[output])
+
+     # demo.launch(server_name="0.0.0.0", server_port=8088, debug=True, share=True)
+     # Without a launch call the script builds the UI and exits immediately,
+     # so the 8080 launch line is restored (the trailing URL note points at 8080).
+     demo.launch(server_name="0.0.0.0", server_port=8080, debug=True, share=True)
+
+     # http://10.99.5.48:8080/
requirements.txt ADDED
@@ -0,0 +1,78 @@
+ aiofiles==23.2.1
+ aiohttp==3.8.6
+ aiosignal==1.3.1
+ altair==5.0.1
+ annotated-types==0.5.0
+ anyio==3.7.1
+ async-timeout==4.0.3
+ asynctest==0.13.0
+ attrs==23.2.0
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
+ charset-normalizer==3.3.2
+ click==8.1.7
+ cycler==0.11.0
+ exceptiongroup==1.2.1
+ fastapi==0.99.1
+ ffmpy==0.3.2
+ filelock==3.12.2
+ fonttools==4.38.0
+ frozenlist==1.3.3
+ fsspec==2023.1.0
+ gradio==3.34.0
+ gradio_client==0.2.6
+ grpcio==1.62.2
+ h11==0.14.0
+ httpcore==0.17.3
+ httpx==0.24.1
+ huggingface-hub==0.16.4
+ idna==3.7
+ importlib-metadata==6.7.0
+ importlib-resources==5.12.0
+ Jinja2==3.1.3
+ jsonschema==4.17.3
+ kiwisolver==1.4.5
+ linkify-it-py==2.0.3
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.5
+ matplotlib==3.5.3
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ multidict==6.0.5
+ numpy==1.21.6
+ octo-rpc==0.4.7
+ orjson==3.9.7
+ packaging==24.0
+ pandas==1.3.5
+ Pillow==9.5.0
+ pkgutil_resolve_name==1.3.10
+ ply==3.11
+ protobuf==3.20.1
+ psutil==5.9.8
+ pydantic==1.10.11
+ pydantic_core==2.14.6
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ pyrsistent==0.19.3
+ python-cat==0.0.11
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.8
+ pytz==2024.1
+ PyYAML==6.0.1
+ requests==2.31.0
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.27.0
+ thrift==0.20.0
+ thriftpy2==0.4.20
+ toolz==0.12.1
+ tqdm==4.66.2
+ TritonServerClient==0.0.7
+ typing_extensions==4.7.1
+ uc-micro-py==1.0.3
+ urllib3==2.0.7
+ uvicorn==0.22.0
+ websockets==11.0.3
+ yarl==1.9.4
+ zipp==3.15.0
run.sh ADDED
@@ -0,0 +1,9 @@
+ source /workdir/yanghandi/gradio_demo/setup.sh
+ # source setup.sh
+ python main.py &> gradio.log &
+
+ # NOTE: try_hd_v2.py is started twice below; given the log name
+ # multi_gradio.log, the first invocation was likely meant to be multi_main.py.
+ python try_hd_v2.py &> multi_gradio.log &
+
+ python try_hd_v2.py &> try_hd_v2.log &
+
+ python test_hd.py &> test_hd.log &
setup.sh ADDED
@@ -0,0 +1,2 @@
+ # NOTE: "conda activate" only works in a non-interactive shell if conda has
+ # been initialized first (e.g. by sourcing <conda_root>/etc/profile.d/conda.sh).
+ conda deactivate
+ conda activate gradio
static/0000.jpeg ADDED
static/0001.jpg ADDED
static/image_1149d78e-43f6-4d5b-8e63-e24c2641012b.jpg ADDED
static/image_11bc11b6-57a9-42d3-81fc-97d957e62f28.jpg ADDED
static/image_33685dfc-b96f-401c-847d-3dd537186ec2.jpg ADDED
static/image_4a883844-b1b2-4154-bbfe-0b108d93dca6.jpg ADDED
static/image_4e43dfc4-9362-4e8d-b20f-d9816a9b684d.jpg ADDED
static/image_51e314f6-0935-4430-a53f-bfdc22ca7cc4.jpg ADDED
static/image_6c63b8a2-8bfc-485e-a238-94e76f22e7db.jpg ADDED
static/image_82c514bf-bdc1-4d60-b073-a97b5816ea56.jpg ADDED
static/image_845cd65a-30a7-43c3-be8c-819bc04de98d.jpg ADDED
static/image_86804064-1f0e-4130-bcc3-30dd839a3c0a.jpg ADDED
static/image_9afab29f-77d9-4451-8a52-092a5c37625d.jpg ADDED
static/image_9c5c694a-acef-4545-8759-75ccffa39f6d.jpg ADDED
static/image_b73c1efa-8112-404c-ab9e-eddb870f43af.jpg ADDED
static/image_bbacb8a4-89d9-45ac-9da2-28717009e750.jpg ADDED
static/image_bd087d82-9ed6-4b68-8ac1-c1417cfbb995.jpg ADDED
static/image_ca9e5436-bc9b-4242-9a02-261d2209feb3.jpg ADDED
static/image_dd018613-a2be-4eef-8183-a01929c835fd.jpg ADDED
static/image_e63e5e5c-41cc-4f42-b896-95644d6b28f3.jpg ADDED
static/image_e6c53092-7b49-4ba5-96dc-38e9479d3c6a.jpg ADDED
static/image_f9bef0de-bc72-415d-b527-6c068594af02.jpg ADDED
test_hd.py ADDED
@@ -0,0 +1,187 @@
+ import argparse
+ import queue
+ import sys
+ import uuid
+ from functools import partial
+
+ import numpy as np
+ import tritonclient.grpc as grpcclient
+ from tritonclient.utils import InferenceServerException
+ import gradio as gr
+ from functools import wraps
+
+ ####
+ from PIL import Image
+ import base64
+ import io
+ #####
+ from http.server import HTTPServer, SimpleHTTPRequestHandler
+ import socket
+ ####
+ import os
+ import uuid
+ ####
+
+ class UserData:
+     def __init__(self):
+         self._completed_requests = queue.Queue()
+
+ def callback(user_data, result, error):
+     if error:
+         user_data._completed_requests.put(error)
+     else:
+         user_data._completed_requests.put(result)
+
+ def make_a_try(img_url, text):
+     model_name = 'ensemble_mllm'
+     user_data = UserData()
+     sequence_id = 100
+     int_sequence_id0 = sequence_id
+     result_list = []
+     try:
+         triton_client = grpcclient.InferenceServerClient(
+             url="10.95.163.43:8001",
+             # verbose=FLAGS.verbose,
+             verbose=True,  # False
+             ssl=False,
+             root_certificates=None,
+             private_key=None,
+             certificate_chain=None,
+         )
+     except Exception as e:
+         print("channel creation failed: " + str(e))
+         return ""
+     # Infer
+     inputs = []
+     img_url_bytes = img_url.encode("utf-8")
+     img_url_bytes = np.array(img_url_bytes, dtype=bytes)
+     img_url_bytes = img_url_bytes.reshape([1, -1])
+
+     inputs.append(grpcclient.InferInput('IMAGE_URL', img_url_bytes.shape, "BYTES"))
+     inputs[0].set_data_from_numpy(img_url_bytes)
+
+     text_bytes = text.encode("utf-8")
+     text_bytes = np.array(text_bytes, dtype=bytes)
+     text_bytes = text_bytes.reshape([1, -1])
+     # text_input = np.expand_dims(text_bytes, axis=0)
+     text_input = text_bytes
+
+     inputs.append(grpcclient.InferInput('TEXT', text_input.shape, "BYTES"))
+     inputs[1].set_data_from_numpy(text_input)
+
+     outputs = []
+     outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))
+     # Test with outputs
+     results = triton_client.infer(
+         model_name=model_name,
+         inputs=inputs,
+         outputs=outputs,
+         client_timeout=None,  # FLAGS.client_timeout,
+         # headers={"test": "1"},
+         compression_algorithm=None,  # FLAGS.grpc_compression_algorithm,
+     )
+
+     statistics = triton_client.get_inference_statistics(model_name=model_name)
+     print(statistics)
+     if len(statistics.model_stats) != 1:
+         print("FAILED: Inference Statistics")
+         return ""
+
+     # Get the output arrays from the results
+     output_data = results.as_numpy("OUTPUT")
+     result_str = output_data[0][0].decode('utf-8')
+
+     print("OUTPUT: " + result_str)
+     return result_str
+
+ def greet(image, text):
+     ### save img
+     static_path = f"/workdir/yanghandi/gradio_demo/static"
+     # Convert the image to a byte stream.
+     img_byte_arr = io.BytesIO()
+     try:
+         image.save(img_byte_arr, format='JPEG')
+     except Exception:
+         return ""
+     img_byte_arr = img_byte_arr.getvalue()
+
+     # Generate a unique filename for the image.
+     # filename = "image_" + str(os.getpid()) + ".jpg"  # uuid
+     unique_id = uuid.uuid4()
+     filename = f"image_{unique_id}.jpg"
+     filepath = os.path.join(static_path, filename)
+
+     # Write the byte stream to a file.
+     with open(filepath, 'wb') as f:
+         f.write(img_byte_arr)
+
+     img_url = f"http://10.99.5.48:8080/file=static/" + filename
+     # img_url = PIL_to_URL(img_url)
+     # img_url = "http://10.99.5.48:8080/file=static/0000.jpeg"
+     result = make_a_try(img_url, text)
+     # print(result)
+     return result
+
+
+ def clear_output():
+     return ""
+
+ def get_example():
+     return [
+         [f"/workdir/yanghandi/gradio_demo/static/0001.jpg", f"图中的人物是谁"]
+     ]
+
+ if __name__ == "__main__":
+
+     param_info = {}
+     # param_info['appkey'] = "com.sankuai.automl.serving"
+     param_info['appkey'] = "10.199.14.151:8001"
+
+     # param_info['remote_appkey'] = "com.sankuai.automl.chat3"
+     param_info['remote_appkey'] = "10.199.14.151:8001"
+     param_info['model_name'] = 'ensemble_mllm'
+     param_info['model_version'] = "1"
+     param_info['time_out'] = 60000
+     param_info['server_targets'] = []
+     param_info['outputs'] = 'response'
+
+     gr.set_static_paths(paths=["static/"])
+
+     with gr.Blocks(title='demo') as demo:
+         gr.Markdown("# 自研模型测试demo")
+         gr.Markdown("尝试使用该demo,上传图片并开始讨论它,或者尝试下面的例子")
+
+         with gr.Row():
+             with gr.Column():
+                 # imagebox = gr.Image(value="static/0000.jpeg", type="pil")
+                 imagebox = gr.Image(type="pil")
+                 promptbox = gr.Textbox(label="prompt")
+
+             with gr.Column():
+                 output = gr.Textbox(label="output")
+                 with gr.Row():
+                     submit = gr.Button("submit")
+                     clear = gr.Button("clear")
+
+         submit.click(fn=greet, inputs=[imagebox, promptbox], outputs=[output])
+         clear.click(fn=clear_output, inputs=[], outputs=[output])
+
+         gr.Markdown("# example")
+
+         gr.Examples(
+             examples=get_example(),
+             fn=greet,
+             inputs=[imagebox, promptbox],
+             outputs=[output],
+             cache_examples=True,
+         )
+
+     demo.launch(server_name="0.0.0.0", server_port=8080, debug=True, share=True)
+
+
+ # img_url = f"https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
+ # # img_url = f"http://10.99.5.48:8080/file=static/static/image_cff7077b-3506-4253-82b7-b6547f2f63c1.jpg"
+ # text = f"talk about this women"
+ # greet(img_url,text)
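
test_hd.py and the try_* scripts all repeat the same few lines to wrap a Python string as a [1, 1] BYTES tensor for tritonclient.grpc. A small helper distilled from that pattern (the name make_bytes_input is hypothetical, not part of the repo):

import numpy as np
import tritonclient.grpc as grpcclient

def make_bytes_input(name, value):
    """Wrap a UTF-8 string as a [1, 1] BYTES tensor named `name`."""
    data = np.array(value.encode("utf-8"), dtype=bytes).reshape([1, -1])
    infer_input = grpcclient.InferInput(name, data.shape, "BYTES")
    infer_input.set_data_from_numpy(data)
    return infer_input

# Usage, mirroring the input setup in make_a_try above:
# inputs = [make_bytes_input("IMAGE_URL", img_url), make_bytes_input("TEXT", text)]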
try_demo.py ADDED
@@ -0,0 +1,224 @@
+ #!/usr/bin/env python
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions
+ # are met:
+ #  * Redistributions of source code must retain the above copyright
+ #    notice, this list of conditions and the following disclaimer.
+ #  * Redistributions in binary form must reproduce the above copyright
+ #    notice, this list of conditions and the following disclaimer in the
+ #    documentation and/or other materials provided with the distribution.
+ #  * Neither the name of NVIDIA CORPORATION nor the names of its
+ #    contributors may be used to endorse or promote products derived
+ #    from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ import argparse
+ import queue
+ import sys
+ import uuid
+ from functools import partial
+
+ import numpy as np
+ import tritonclient.grpc as grpcclient
+ from tritonclient.utils import InferenceServerException
+
+ ##
+ import time
+ import threading
+ ###
+
+ FLAGS = None
+
+
+ class UserData:
+     def __init__(self):
+         self._completed_requests = queue.Queue()
+
+
+ # Define the callback function. Note the last two parameters should be
+ # result and error. InferenceServerClient would provide the results of an
+ # inference as grpcclient.InferResult in result. For successful
+ # inference, error will be None, otherwise it will be an object of
+ # tritonclientutils.InferenceServerException holding the error details
+ def callback(user_data, result, error):
+     if error:
+         user_data._completed_requests.put(error)
+     else:
+         user_data._completed_requests.put(result)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-v",
+         "--verbose",
+         action="store_true",
+         required=False,
+         default=False,
+         help="Enable verbose output",
+     )
+     # parser.add_argument(
+     #     "-u",
+     #     "--url",
+     #     type=str,
+     #     required=False,
+     #     default="localhost:8001",
+     #     help="Inference server URL and its gRPC port. Default is localhost:8001.",
+     # )
+     parser.add_argument(
+         "-u",
+         "--url",
+         type=str,
+         required=False,
+         default="10.199.14.151:8001",
+         help="Inference server URL and its gRPC port. Default is 10.199.14.151:8001.",
+     )
+     parser.add_argument(
+         "-t",
+         "--stream-timeout",
+         type=float,
+         required=False,
+         default=None,
+         help="Stream timeout in seconds. Default is None.",
+     )
+     # parser.add_argument(
+     #     "-d",
+     #     "--dyna",
+     #     action="store_true",
+     #     required=False,
+     #     default=False,
+     #     help="Assume dynamic sequence model",
+     # )
+     # parser.add_argument(
+     #     "-o",
+     #     "--offset",
+     #     type=int,
+     #     required=False,
+     #     default=0,
+     #     help="Add offset to sequence ID used",
+     # )
+
+     FLAGS = parser.parse_args()
+
+     # # We use custom "sequence" models which take 1 input
+     # # value. The output is the accumulated value of the inputs. See
+     # # src/custom/sequence.
+     # int_sequence_model_name = (
+     #     "simple_dyna_sequence" if FLAGS.dyna else "simple_sequence"
+     # )
+     # string_sequence_model_name = (
+     #     "simple_string_dyna_sequence" if FLAGS.dyna else "simple_sequence"
+     # )
+     model_name = 'ensemble_mllm'
+     model_version = ""
+     batch_size = 1
+     # img_url = f"https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
+     img_url = "/workdir/yanghandi/gradio_demo/static/0000.jpeg"
+     # img_url = f"https://s3plus.sankuai.com/automl-pkgs/0003.jpeg"
+     text = f"详细描述一下这张图片"
+     sequence_id = 100
+     int_sequence_id0 = sequence_id
+
+     result_list = []
+     user_data = UserData()
+
+     # It is advisable to use client object within with..as clause
+     # when sending streaming requests. This ensures the client
+     # is closed when the block inside with exits.
+     with grpcclient.InferenceServerClient(
+         url=FLAGS.url, verbose=FLAGS.verbose
+     ) as triton_client:
+         try:
+             # Establish stream
+             triton_client.start_stream(
+                 callback=partial(callback, user_data),
+                 stream_timeout=FLAGS.stream_timeout,
+             )
+
+             # Create the tensor for INPUT
+             inputs = []
+             img_url_bytes = img_url.encode("utf-8")
+             img_url_bytes = np.array(img_url_bytes, dtype=bytes)
+             img_url_bytes = img_url_bytes.reshape([1, -1])
+
+             inputs.append(grpcclient.InferInput('IMAGE_URL', img_url_bytes.shape, "BYTES"))
+             inputs[0].set_data_from_numpy(img_url_bytes)
+
+             text_bytes = text.encode("utf-8")
+             text_bytes = np.array(text_bytes, dtype=bytes)
+             text_bytes = text_bytes.reshape([1, -1])
+             # text_input = np.expand_dims(text_bytes, axis=0)
+             text_input = text_bytes
+
+             inputs.append(grpcclient.InferInput('TEXT', text_input.shape, "BYTES"))
+             inputs[1].set_data_from_numpy(text_input)
+
+             outputs = []
+             outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))
+             # Issue the asynchronous sequence inference.
+             triton_client.async_stream_infer(
+                 model_name=model_name,
+                 inputs=inputs,
+                 outputs=outputs,
+                 request_id="{}".format(sequence_id),
+                 sequence_id=sequence_id,
+                 sequence_start=True,
+                 sequence_end=True,
+             )
+
+         except InferenceServerException as error:
+             print(error)
+             sys.exit(1)
+
+         # Retrieve results...
+         recv_count = 0
+
+         #####
+
+         ####
+
+         while True:
+             # if len(result_list) == 80:
+             #     print("1")
+             data_item = user_data._completed_requests.get()
+             # try:
+             #     data_item = user_data._completed_requests.get(timeout=5)
+             # except Exception as e:
+             #     print("queue wrong")
+             #     break
+             if type(data_item) == InferenceServerException:
+                 print('InferenceServerException: ', data_item)
+                 sys.exit(1)
+             this_id = data_item.get_response().id.split("_")[0]
+             if int(this_id) != int_sequence_id0:
+                 print("unexpected sequence id returned by the server: {}".format(this_id))
+                 sys.exit(1)
+
+             # An empty chunk marks the end of the stream.
+             result = data_item.as_numpy("OUTPUT")
+             if len(result[0][0]) == 0:
+                 break
+
+             result_list.append(data_item.as_numpy("OUTPUT"))
+
+             recv_count = recv_count + 1
+             result_str = ''.join([item[0][0].decode('utf-8') for item in result_list])
+             print(f"{len(result_list)}: {result_str}")
+
+
+     print("hd", result_str)
+     print("PASS: Sequence")
+     print("hd", result_str)
try_demo_demo.py ADDED
@@ -0,0 +1,206 @@
+ #!/usr/bin/env python
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions
+ # are met:
+ #  * Redistributions of source code must retain the above copyright
+ #    notice, this list of conditions and the following disclaimer.
+ #  * Redistributions in binary form must reproduce the above copyright
+ #    notice, this list of conditions and the following disclaimer in the
+ #    documentation and/or other materials provided with the distribution.
+ #  * Neither the name of NVIDIA CORPORATION nor the names of its
+ #    contributors may be used to endorse or promote products derived
+ #    from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ import argparse
+ import sys
+
+ import numpy as np
+ import tritonclient.grpc as grpcclient
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-v",
+         "--verbose",
+         action="store_true",
+         required=False,
+         default=False,
+         help="Enable verbose output",
+     )
+     parser.add_argument(
+         "-u",
+         "--url",
+         type=str,
+         required=False,
+         default="10.95.163.43:8001",
+         help="Inference server URL. Default is 10.95.163.43:8001.",
+     )
+     parser.add_argument(
+         "-s",
+         "--ssl",
+         action="store_true",
+         required=False,
+         default=False,
+         help="Enable SSL encrypted channel to the server",
+     )
+     parser.add_argument(
+         "-t",
+         "--client-timeout",
+         type=float,
+         required=False,
+         default=None,
+         help="Client timeout in seconds. Default is None.",
+     )
+     parser.add_argument(
+         "-r",
+         "--root-certificates",
+         type=str,
+         required=False,
+         default=None,
+         help="File holding PEM-encoded root certificates. Default is None.",
+     )
+     parser.add_argument(
+         "-p",
+         "--private-key",
+         type=str,
+         required=False,
+         default=None,
+         help="File holding PEM-encoded private key. Default is None.",
+     )
+     parser.add_argument(
+         "-x",
+         "--certificate-chain",
+         type=str,
+         required=False,
+         default=None,
+         help="File holding PEM-encoded certificate chain. Default is None.",
+     )
+     parser.add_argument(
+         "-C",
+         "--grpc-compression-algorithm",
+         type=str,
+         required=False,
+         default=None,
+         help="The compression algorithm to be used when sending request to server. Default is None.",
+     )
+
+     FLAGS = parser.parse_args()
+     try:
+         # triton_client = grpcclient.InferenceServerClient(
+         #     url=FLAGS.url,
+         #     verbose=FLAGS.verbose,
+         #     ssl=FLAGS.ssl,
+         #     root_certificates=FLAGS.root_certificates,
+         #     private_key=FLAGS.private_key,
+         #     certificate_chain=FLAGS.certificate_chain,
+         # )
+         triton_client = grpcclient.InferenceServerClient(
+             url=FLAGS.url,
+             # verbose=FLAGS.verbose,
+             verbose=True,
+             ssl=FLAGS.ssl,
+             root_certificates=None,
+             private_key=None,
+             certificate_chain=None,
+         )
+     except Exception as e:
+         print("channel creation failed: " + str(e))
+         sys.exit()
+
+     model_name = "ensemble_mllm"
+     img_url = f"https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
+     # img_url = f"https://s3plus.sankuai.com/automl-pkgs/0003.jpeg"
+     text = f"详细描述一下这张图片"
+
+     # Infer
+     inputs = []
+     img_url_bytes = img_url.encode("utf-8")
+     img_url_bytes = np.array(img_url_bytes, dtype=bytes)
+     img_url_bytes = img_url_bytes.reshape([1, -1])
+
+     inputs.append(grpcclient.InferInput('IMAGE_URL', img_url_bytes.shape, "BYTES"))
+     inputs[0].set_data_from_numpy(img_url_bytes)
+
+     text_bytes = text.encode("utf-8")
+     text_bytes = np.array(text_bytes, dtype=bytes)
+     text_bytes = text_bytes.reshape([1, -1])
+     # text_input = np.expand_dims(text_bytes, axis=0)
+     text_input = text_bytes
+
+     inputs.append(grpcclient.InferInput('TEXT', text_input.shape, "BYTES"))
+     inputs[1].set_data_from_numpy(text_input)
+
+     outputs = []
+     outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))
+
+     # Test with outputs
+     results = triton_client.infer(
+         model_name=model_name,
+         inputs=inputs,
+         outputs=outputs,
+         client_timeout=None,  # FLAGS.client_timeout,
+         # headers={"test": "1"},
+         compression_algorithm=None,  # FLAGS.grpc_compression_algorithm,
+     )
+
+     statistics = triton_client.get_inference_statistics(model_name=model_name)
+     print(statistics)
+     if len(statistics.model_stats) != 1:
+         print("FAILED: Inference Statistics")
+         sys.exit(1)
+
+     # Get the output arrays from the results
+     output_data = results.as_numpy("OUTPUT")
+     result_str = output_data[0][0].decode('utf-8')
+
+     print("OUTPUT: " + result_str)
+
+     # # Test with no outputs
+     # results = triton_client.infer(
+     #     model_name=model_name,
+     #     inputs=inputs,
+     #     outputs=None,
+     #     compression_algorithm=FLAGS.grpc_compression_algorithm,
+     # )
+
+     # # Get the output arrays from the results
+     # output0_data = results.as_numpy("OUTPUT0")
+     # output1_data = results.as_numpy("OUTPUT1")
+
+     # for i in range(16):
+     #     print(
+     #         str(input0_data[0][i])
+     #         + " + "
+     #         + str(input1_data[0][i])
+     #         + " = "
+     #         + str(output0_data[0][i])
+     #     )
+     #     print(
+     #         str(input0_data[0][i])
+     #         + " - "
+     #         + str(input1_data[0][i])
+     #         + " = "
+     #         + str(output1_data[0][i])
+     #     )
+     #     if (input0_data[0][i] + input1_data[0][i]) != output0_data[0][i]:
+     #         print("sync infer error: incorrect sum")
+     #         sys.exit(1)
+     #     if (input0_data[0][i] - input1_data[0][i]) != output1_data[0][i]:
+     #         print("sync infer error: incorrect difference")
+     #         sys.exit(1)
+
+     print("PASS: infer")
try_grpc.py ADDED
@@ -0,0 +1,139 @@
+ import gradio as gr
+ # These imports are required below: prepare_params and make_a_try use
+ # InferInput, InferRequestedOutput, and np_to_triton_dtype, and the main
+ # block constructs a TritonServerClient; leaving them commented out would
+ # raise a NameError at runtime.
+ from TritonServerClient import TritonServerClient, InferInput, InferRequestedOutput
+ from TritonServerClient.utils import np_to_triton_dtype
+ from functools import wraps
+ import numpy as np
+
+ import tritonclient.grpc as grpcclient
+ from tritonclient.utils import InferenceServerException
+
+ # def prepare_params(query):
+ #     bs = 1
+ # # query = "#角色:直播间文案写作专家。#任务:根据用户输入的商品信息写出用于推广营销的直播文案。#目标:引导用户购买直播间商品,提高销量。#要求:- 如果没有输入商品信息,则邀请输入。- 文案写作 - 按照“引出话题”、“产品讲解”、“购买方式”三个方面写作文案。 - 在“产品讲解”部分详细讲解产品细节: - 【重要】充分扩充内容,吸引用户购买。 - 详细描述商品的优惠活动信息(如用户输入没有,不要询问,直接补充)。 - 在“购买方式”部分: - 仅包括引导用户点击链接购买的内容。 - 不要包括快递物流承诺。 - 【重要】每个部分多于5句话。 - 使用口语化的文字。 - 表达方式有亲和力和感染力。 - 对用户们使用“宝子们”、“宝宝们”、“家人们”等有亲和力的称呼。 - 多使用自问自答的语句。- 当你准备好了回复“请输入商品描述和优惠信息”我们就开始。#输出:- 包含在""里。- 写出每一部分内容的标题,包括【引出话题】、【产品讲解】、【购买方式】。- 示例:""【引出话题】【产品讲解】【购买方式】""- 输入商品信息:%s" % prompt
+ # # query = '请你扮演直播间直播卖货的主播角色,根据用户提出的问题,检索知识库中的信息,如果查找到相关的信息,使用口语化的文字,用非常简短的一句话回答用户的问题;如果查找不到相关的信息,直接回答没有,并引导用户关注本场直播商品。你的目的是激发用户的购买欲、提高直播商品销量。回答对用户们使用“宝子”、“宝宝”、“亲”等有亲和力的称呼。\n\n# 用户提出的问题:\n这里面有可乐吗\n\n# 知识库:\n商品名称:外送欢乐桶(券)兑换券\n直播链接序号:64号链接号链接\n商品现价:66元\n商品原价:142元\n折扣:5\n品牌名称:麦当劳麦乐送\n券包张数:1\n套餐详情:["麦辣鸡腿汉堡1个","板烧鸡腿堡1个","中可乐3杯","四拼小食桶A"]\n补充:四拼小食桶A(麦麦脆汁鸡(鸡腿)1块+中薯条1份+麦辣鸡翅2块+麦乐鸡5块)+麦辣鸡腿堡1个+板烧鸡腿堡1个+中可乐3杯 原材料:面包\n\n\n# 输出格式:\n答:\n\n# 要求\n对于用户问到了知识库中未提及的信息不要编造,直接不回答'
+ # ######
+ # #####
+ #     title_text = np.array([query.encode('utf-8')], dtype=np.string_)
+ #     title_text = np.tile(title_text, (bs, 1))
+
+ #     data_batch = {}
+ #     data_batch['query'] = title_text
+
+ #     inputs = [
+ #         InferInput("query", data_batch['query'].shape,
+ #                    np_to_triton_dtype(data_batch['query'].dtype)),
+ #     ]
+
+ #     inputs[0].set_data_from_numpy(data_batch['query'])
+
+ #     return inputs
+ def prepare_params(query, img):
+     bs = 1
+ # query = "#角色:直播间文案写作专家。#任务:根据用户输入的商品信息写出用于推广营销的直播文案。#目标:引导用户购买直播间商品,提高销量。#要求:- 如果没有输入商品信息,则邀请输入。- 文案写作 - 按照“引出话题”、“产品讲解”、“购买方式”三个方面写作文案。 - 在“产品讲解”部分详细讲解产品细节: - 【重要】充分扩充内容,吸引用户购买。 - 详细描述商品的优惠活动信息(如用户输入没有,不要询问,直接补充)。 - 在“购买方式”部分: - 仅包括引导用户点击链接购买的内容。 - 不要包括快递物流承诺。 - 【重要】每个部分多于5句话。 - 使用口语化的文字。 - 表达方式有亲和力和感染力。 - 对用户们使用“宝子们”、“宝宝们”、“家人们”等有亲和力的称呼。 - 多使用自问自答的语句。- 当你准备好了回复“请输入商品描述和优惠信息”我们就开始。#输出:- 包含在""里。- 写出每一部分内容的标题,包括【引出话题】、【产品讲解】、【购买方式】。- 示例:""【引出话题】【产品讲解】【购买方式】""- 输入商品信息:%s" % prompt
+ # query = '请你扮演直播间直播卖货的主播角色,根据用户提出的问题,检索知识库中的信息,如果查找到相关的信息,使用口语化的文字,用非常简短的一句话回答用户的问题;如果查找不到相关的信息,直接回答没有,并引导用户关注本场直播商品。你的目的是激发用户的购买欲、提高直播商品销量。回答对用户们使用“宝子”、“宝宝”、“亲”等有亲和力的称呼。\n\n# 用户提出的问题:\n这里面有可乐吗\n\n# 知识库:\n商品名称:外送欢乐桶(券)兑换券\n直播链接序号:64号链接号链接\n商品现价:66元\n商品原价:142元\n折扣:5\n品牌名称:麦当劳麦乐送\n券包张数:1\n套餐详情:["麦辣鸡腿汉堡1个","板烧鸡腿堡1个","中可乐3杯","四拼小食桶A"]\n补充:四拼小食桶A(麦麦脆汁鸡(鸡腿)1块+中薯条1份+麦辣鸡翅2块+麦乐鸡5块)+麦辣鸡腿堡1个+板烧鸡腿堡1个+中可乐3杯 原材料:面包\n\n\n# 输出格式:\n答:\n\n# 要求\n对于用户问到了知识库中未提及的信息不要编造,直接不回答'
+     inputs = []
+     ######
+     img_info = np.array([img.encode('utf-8')], dtype=np.string_)
+     img_info = np.tile(img_info, (bs, 1))
+
+     inputs.append(InferInput("img", img_info.shape,
+                              np_to_triton_dtype(img_info.dtype)))
+
+     inputs[0].set_data_from_numpy(img_info)
+     #####
+     title_text = np.array([query.encode('utf-8')], dtype=np.string_)
+     title_text = np.tile(title_text, (bs, 1))
+
+     data_batch = {}
+     data_batch['query'] = title_text
+
+     inputs.append(InferInput("query", data_batch['query'].shape,
+                              np_to_triton_dtype(data_batch['query'].dtype)))
+
+     inputs[1].set_data_from_numpy(data_batch['query'])
+
+     return inputs
+
+ def make_a_try(inputs, outputs='response', model_name='ensemble_mllm', model_version=''):  # qwen 1
+     outputs_list = []
+     ori_outputs_list = outputs.strip().split(",")
+     for out_ele in ori_outputs_list:
+         outputs_list.append(out_ele.strip())
+     outputs = [InferRequestedOutput(x) for x in outputs_list]
+
+     response = my_client.predict(model_name=model_name, inputs=inputs, model_version=model_version, outputs=outputs)
+
+     rsp_info = {}
+     if outputs_list == []:
+         for out_name_ele in response._result.outputs:
+             outputs_list.append(out_name_ele.name)
+     for output_name in outputs_list:
+         res = response.as_numpy(output_name)
+         response = np.expand_dims(res, axis=0)
+         response = response[0].decode('utf-8')
+         rsp_info[output_name] = response
+     print("response:", rsp_info)
+     return rsp_info['response']
+
+
+ # def greet(prompt):
+ #     """Greet someone."""
+ #     # print(prompt)
+ #     print("prompt:", prompt)
+ #     inputs = prepare_params(prompt)
+ #     print(inputs)
+ #     result = make_a_try(inputs)
+
+ #     return result
+ def greet(prompt, img):
+     """Send the prompt and image URL to the model server and return its response."""
+     # print(prompt)
+     print("prompt:", prompt)
+     inputs = prepare_params(prompt, img)
+     print(inputs)
+     result = make_a_try(inputs)
+
+     return result
+
+ def clear_input():
+     return ""
+
+
+ if __name__ == "__main__":
+     param_info = {}
+     # param_info['appkey'] = "com.sankuai.automl.serving"
+     param_info['appkey'] = "com.sankuai.automl.streamvlm"
+
+     # param_info['remote_appkey'] = "com.sankuai.automl.chat3"
+     param_info['remote_appkey'] = "com.sankuai.automl.streamvlm"
+
+     param_info['model_name'] = "ensemble_mllm"
+     param_info['model_version'] = "1"
+     param_info['time_out'] = 60000
+     param_info['server_targets'] = []
+     param_info['outputs'] = 'response'
+
+     appkey, remote_appkey, model_name, model_version, time_out, server_targets = param_info['appkey'], param_info['remote_appkey'], param_info['model_name'], param_info['model_version'], param_info['time_out'], param_info['server_targets']
+
+     my_client = TritonServerClient(appkey=appkey, remote_appkey=remote_appkey, time_out=time_out, server_targets=server_targets)
+     # triton_client.async_stream_infer(
+     #     model_name=model_name,
+     #     inputs=inputs,
+     #     outputs=outputs,
+     #     request_id="{}".format(sequence_id),
+     #     sequence_id=sequence_id,
+     #     sequence_start=True,
+     #     sequence_end=True,
+     # )
+
+     img_url = f"https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
+
+     greet("nihao", img_url)
+     # greet("nihao")
+     print("描述这张图片")
try_hd.py ADDED
@@ -0,0 +1,137 @@
+ import gradio as gr
+ from TritonServerClient import TritonServerClient, InferInput, InferRequestedOutput
+ from TritonServerClient.utils import np_to_triton_dtype
+ from functools import wraps
+ import numpy as np
+
+
+ # def prepare_params(query):
+ #     bs = 1
+ # # query = "#角色:直播间文案写作专家。#任务:根据用户输入的商品信息写出用于推广营销的直播文案。#目标:引导用户购买直播间商品,提高销量。#要求:- 如果没有输入商品信息,则邀请输入。- 文案写作 - 按照“引出话题”、“产品讲解”、“购买方式”三个方面写作文案。 - 在“产品讲解”部分详细讲解产品细节: - 【重要】充分扩充内容,吸引用户购买。 - 详细描述商品的优惠活动信息(如用户输入没有,不要询问,直接补充)。 - 在“购买方式”部分: - 仅包括引导用户点击链接购买的内容。 - 不要包括快递物流承诺。 - 【重要】每个部分多于5句话。 - 使用口语化的文字。 - 表达方式有亲和力和感染力。 - 对用户们使用“宝子们”、“宝宝们”、“家人们”等有亲和力的称呼。 - 多使用自问自答的语句。- 当你准备好了回复“请输入商品描述和优惠信息”我们就开始。#输出:- 包含在""里。- 写出每一部分内容的标题,包括【引出话题】、【产品讲解】、【购买方式】。- 示例:""【引出话题】【产品讲解】【购买方式】""- 输入商品信息:%s" % prompt
+ # # query = '请你扮演直播间直播卖货的主播角色,根据用户提出的问题,检索知识库中的信息,如果查找到相关的信息,使用口语化的文字,用非常简短的一句话回答用户的问题;如果查找不到相关的信息,直接回答没有,并引导用户关注本场直播商品。你的目的是激发用户的购买欲、提高直播商品销量。回答对用户们使用“宝子”、“宝宝”、“亲”等有亲和力的称呼。\n\n# 用户提出的问题:\n这里面有可乐吗\n\n# 知识库:\n商品名称:外送欢乐桶(券)兑换券\n直播链接序号:64号链接号链接\n商品现价:66元\n商品原价:142元\n折扣:5\n品牌名称:麦当劳麦乐送\n券包张数:1\n套餐详情:["麦辣鸡腿汉堡1个","板烧鸡腿堡1个","中可乐3杯","四拼小食桶A"]\n补充:四拼小食桶A(麦麦脆汁鸡(鸡腿)1块+中薯条1份+麦辣鸡翅2块+麦乐鸡5块)+麦辣鸡腿堡1个+板烧鸡腿堡1个+中可乐3杯 原材料:面包\n\n\n# 输出格式:\n答:\n\n# 要求\n对于用户问到了知识库中未提及的信息不要编造,直接不回答'
+ # ######
+ # #####
+ #     title_text = np.array([query.encode('utf-8')], dtype=np.string_)
+ #     title_text = np.tile(title_text, (bs, 1))
+
+ #     data_batch = {}
+ #     data_batch['query'] = title_text
+
+ #     inputs = [
+ #         InferInput("query", data_batch['query'].shape,
+ #                    np_to_triton_dtype(data_batch['query'].dtype)),
+ #     ]
+
+ #     inputs[0].set_data_from_numpy(data_batch['query'])
+
+ #     return inputs
+ def prepare_params(query, img):
+     bs = 1
+ # query = "#角色:直播间文案写作专家。#任务:根据用户输入的商品信息写出用于推广营销的直播文案。#目标:引导用户购买直播间商品,提高销量。#要求:- 如果没有输入商品信息,则邀请输入。- 文案写作 - 按照“引出话题”、“产品讲解”、“购买方式”三个方面写作文案。 - 在“产品讲解”部分详细讲解产品细节: - 【重要】充分扩充内容,吸引用户购买。 - 详细描述商品的优惠活动信息(如用户输入没有,不要询问,直接补充)。 - 在“购买方式”部分: - 仅包括引导用户点击链接购买的内容。 - 不要包括快递物流承诺。 - 【重要】每个部分多于5句话。 - 使用口语化的文字。 - 表达方式有亲和力和感染力。 - 对用户们使用“宝子们”、“宝宝们”、“家人们”等有亲和力的称呼。 - 多使用自问自答的语句。- 当你准备好了回复“请输入商品描述和优惠信息”我们就开始。#输出:- 包含在""里。- 写出每一部分内容的标题,包括【引出话题】、【产品讲解】、【购买方式】。- 示例:""【引出话题】【产品讲解】【购买方式】""- 输入商品信息:%s" % prompt
+ # query = '请你扮演直播间直播卖货的主播角色,根据用户提出的问题,检索知识库中的信息,如果查找到相关的信息,使用口语化的文字,用非常简短的一句话回答用户的问题;如果查找不到相关的信息,直接回答没有,并引导用户关注本场直播商品。你的目的是激发用户的购买欲、提高直播商品销量。回答对用户们使用“宝子”、“宝宝”、“亲”等有亲和力的称呼。\n\n# 用户提出的问题:\n这里面有可乐吗\n\n# 知识库:\n商品名称:外送欢乐桶(券)兑换券\n直播链接序号:64号链接号链接\n商品现价:66元\n商品原价:142元\n折扣:5\n品牌名称:麦当劳麦乐送\n券包张数:1\n套餐详情:["麦辣鸡腿汉堡1个","板烧鸡腿堡1个","中可乐3杯","四拼小食桶A"]\n补充:四拼小食桶A(麦麦脆汁鸡(鸡腿)1块+中薯条1份+麦辣鸡翅2块+麦乐鸡5块)+麦辣鸡腿堡1个+板烧鸡腿堡1个+中可乐3杯 原材料:面包\n\n\n# 输出格式:\n答:\n\n# 要求\n对于用户问到了知识库中未提及的信息不要编造,直接不回答'
+     inputs = []
+     ######
+     img_info = np.array([img.encode('utf-8')], dtype=np.string_)
+     img_info = np.tile(img_info, (bs, 1))
+
+     inputs.append(InferInput("img", img_info.shape,
+                              np_to_triton_dtype(img_info.dtype)))
+
+     inputs[0].set_data_from_numpy(img_info)
+     #####
+     title_text = np.array([query.encode('utf-8')], dtype=np.string_)
+     title_text = np.tile(title_text, (bs, 1))
+
+     data_batch = {}
+     data_batch['query'] = title_text
+
+     inputs.append(InferInput("query", data_batch['query'].shape,
+                              np_to_triton_dtype(data_batch['query'].dtype)))
+
+     inputs[1].set_data_from_numpy(data_batch['query'])
+
+     return inputs
+
+ def make_a_try(inputs, outputs='response', model_name='ensemble_mllm', model_version=''):  # qwen 1
+     outputs_list = []
+     ori_outputs_list = outputs.strip().split(",")
+     for out_ele in ori_outputs_list:
+         outputs_list.append(out_ele.strip())
+     outputs = [InferRequestedOutput(x) for x in outputs_list]
+
+     response = my_client.predict(model_name=model_name, inputs=inputs, model_version=model_version, outputs=outputs)
+
+     rsp_info = {}
+     if outputs_list == []:
+         for out_name_ele in response._result.outputs:
+             outputs_list.append(out_name_ele.name)
+     for output_name in outputs_list:
+         res = response.as_numpy(output_name)
+         response = np.expand_dims(res, axis=0)
+         response = response[0].decode('utf-8')
+         rsp_info[output_name] = response
+     print("response:", rsp_info)
+     return rsp_info['response']
+
+
+ # def greet(prompt):
+ #     """Greet someone."""
+ #     # print(prompt)
+ #     print("prompt:", prompt)
+ #     inputs = prepare_params(prompt)
+ #     print(inputs)
+ #     result = make_a_try(inputs)
+
+ #     return result
+ def greet(prompt, img):
+     """Send the prompt and image URL to the model server and return its response."""
+     # print(prompt)
+     print("prompt:", prompt)
+     inputs = prepare_params(prompt, img)
+     print(inputs)
+     result = make_a_try(inputs)
+
+     return result
+
+ def clear_input():
+     return ""
+
+
+ if __name__ == "__main__":
+     param_info = {}
+     # param_info['appkey'] = "com.sankuai.automl.serving"
+     param_info['appkey'] = "com.sankuai.automl.streamvlm"
+
+     # param_info['remote_appkey'] = "com.sankuai.automl.chat3"
+     param_info['remote_appkey'] = "com.sankuai.automl.streamvlm"
+
+     param_info['model_name'] = "ensemble_mllm"
+     param_info['model_version'] = "1"
+     param_info['time_out'] = 60000
+     param_info['server_targets'] = []
+     param_info['outputs'] = 'response'
+
+     appkey, remote_appkey, model_name, model_version, time_out, server_targets = param_info['appkey'], param_info['remote_appkey'], param_info['model_name'], param_info['model_version'], param_info['time_out'], param_info['server_targets']
+
+     my_client = TritonServerClient(appkey=appkey, remote_appkey=remote_appkey, time_out=time_out, server_targets=server_targets)
+     # triton_client.async_stream_infer(
+     #     model_name=model_name,
+     #     inputs=inputs,
+     #     outputs=outputs,
+     #     request_id="{}".format(sequence_id),
+     #     sequence_id=sequence_id,
+     #     sequence_start=True,
+     #     sequence_end=True,
+     # )
+
+     img_url = f"https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
+
+     greet("nihao", img_url)
+     # greet("nihao")
+     print("描述这张图片")
try_hd_v2.py ADDED
@@ -0,0 +1,218 @@
+ import argparse
+ import queue
+ import sys
+ import uuid
+ from functools import partial
+
+ import numpy as np
+ import tritonclient.grpc as grpcclient
+ from tritonclient.utils import InferenceServerException
+ import gradio as gr
+ from functools import wraps
+
+ ####
+ from PIL import Image
+ import base64
+ import io
+ #####
+ from http.server import HTTPServer, SimpleHTTPRequestHandler
+ import socket
+ ####
+ import os
+ import uuid
+ ####
+
+ class UserData:
+     def __init__(self):
+         self._completed_requests = queue.Queue()
+
+ def callback(user_data, result, error):
+     if error:
+         user_data._completed_requests.put(error)
+     else:
+         user_data._completed_requests.put(result)
+
+ def make_a_try(img_url, text):
+     model_name = 'ensemble_mllm'
+     user_data = UserData()
+     sequence_id = 100
+     int_sequence_id0 = sequence_id
+     result_list = []
+     with grpcclient.InferenceServerClient(
+         url="10.199.14.151:8001", verbose=False
+     ) as triton_client:
+         try:
+             # Establish stream
+             triton_client.start_stream(
+                 callback=partial(callback, user_data),
+                 stream_timeout=None,
+             )
+             # Create the tensor for INPUT
+             inputs = []
+             img_url_bytes = img_url.encode("utf-8")
+             img_url_bytes = np.array(img_url_bytes, dtype=bytes)
+             img_url_bytes = img_url_bytes.reshape([1, -1])
+
+             inputs.append(grpcclient.InferInput('IMAGE_URL', img_url_bytes.shape, "BYTES"))
+             inputs[0].set_data_from_numpy(img_url_bytes)
+
+             text_bytes = text.encode("utf-8")
+             text_bytes = np.array(text_bytes, dtype=bytes)
+             text_bytes = text_bytes.reshape([1, -1])
+             # text_input = np.expand_dims(text_bytes, axis=0)
+             text_input = text_bytes
+
+             inputs.append(grpcclient.InferInput('TEXT', text_input.shape, "BYTES"))
+             inputs[1].set_data_from_numpy(text_input)
+
+             outputs = []
+             outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))
+             # Issue the asynchronous sequence inference.
+             triton_client.async_stream_infer(
+                 model_name=model_name,
+                 inputs=inputs,
+                 outputs=outputs,
+                 request_id="{}".format(sequence_id),
+                 sequence_id=sequence_id,
+                 sequence_start=True,
+                 sequence_end=True,
+             )
+         ###### hd
+         except InferenceServerException as error:
+             print(error)
+             # sys.exit(1)
+             # continue
+             return ""
+
+         # Retrieve results...
+         recv_count = 0
+
+         while True:
+             try:
+                 data_item = user_data._completed_requests.get(timeout=5)
+             except Exception as e:
+                 break
+             # data_item = user_data._completed_requests.get()
+             if type(data_item) == InferenceServerException:
+                 print('InferenceServerException: ', data_item)
+                 # sys.exit(1)
+                 return ""
+             this_id = data_item.get_response().id.split("_")[0]
+             if int(this_id) != int_sequence_id0:
+                 print("unexpected sequence id returned by the server: {}".format(this_id))
+                 # sys.exit(1)
+                 return ""
+             ####
+             # An empty chunk marks the end of the stream.
+             result = data_item.as_numpy("OUTPUT")
+             if len(result[0][0]) == 0:
+                 break
+             ####
+             result_list.append(data_item.as_numpy("OUTPUT"))
+
+             recv_count = recv_count + 1
+     result_str = ''.join([item[0][0].decode('utf-8') for item in result_list])
+     return result_str
+
+
+ def greet(image, text):
+     ### save img
+     static_path = f"/workdir/yanghandi/gradio_demo/static"
+     # Convert the image to a byte stream.
+     img_byte_arr = io.BytesIO()
+     try:
+         image.save(img_byte_arr, format='JPEG')
+     except Exception:
+         return ""
+     img_byte_arr = img_byte_arr.getvalue()
+
+     # Generate a unique filename for the image.
+     # filename = "image_" + str(os.getpid()) + ".jpg"  # uuid
+     unique_id = uuid.uuid4()
+     filename = f"image_{unique_id}.jpg"
+     filepath = os.path.join(static_path, filename)
+
+     # Write the byte stream to a file.
+     with open(filepath, 'wb') as f:
+         f.write(img_byte_arr)
+
+     img_url = f"http://10.99.5.48:8080/file=static/" + filename
+     # img_url = PIL_to_URL(img_url)
+     # img_url = "http://10.99.5.48:8080/file=static/0000.jpeg"
+     result = make_a_try(img_url, text)
+     # print(result)
+     return result
+
+ # def greet_example(image, text):
+ #     ### save img
+ #     # filename = image
+ #     # static_path = "/workdir/yanghandi/gradio_demo/static"
+ #     img_url = "http://10.99.5.48:8080/file=static/0000.jpeg"
+
+ #     # img_url = PIL_to_URL(img_url)
+ #     # img_url = "http://10.99.5.48:8080/file=static/0000.jpeg"
+ #     result = make_a_try(img_url, text)
+ #     # print(result)
+ #     return result
+
+ def clear_output():
+     return ""
+
+ def get_example():
+     return [
+         [f"/workdir/yanghandi/gradio_demo/static/0001.jpg", f"图中的人物是谁"]
+     ]
+
+ if __name__ == "__main__":
+
+     param_info = {}
+     # param_info['appkey'] = "com.sankuai.automl.serving"
+     param_info['appkey'] = "10.199.14.151:8001"
+
+     # param_info['remote_appkey'] = "com.sankuai.automl.chat3"
+     param_info['remote_appkey'] = "10.199.14.151:8001"
+     param_info['model_name'] = 'ensemble_mllm'
+     param_info['model_version'] = "1"
+     param_info['time_out'] = 60000
+     param_info['server_targets'] = []
+     param_info['outputs'] = 'response'
+
+     gr.set_static_paths(paths=["static/"])
+
+     with gr.Blocks(title='demo') as demo:
+         gr.Markdown("# 自研模型测试demo")
+         gr.Markdown("尝试使用该demo,上传图片并开始讨论它,或者尝试下面的例子")
+
+         with gr.Row():
+             with gr.Column():
+                 # imagebox = gr.Image(value="static/0000.jpeg", type="pil")
+                 imagebox = gr.Image(type="pil")
+                 promptbox = gr.Textbox(label="prompt")
+
+             with gr.Column():
+                 output = gr.Textbox(label="output")
+                 with gr.Row():
+                     submit = gr.Button("submit")
+                     clear = gr.Button("clear")
+
+         submit.click(fn=greet, inputs=[imagebox, promptbox], outputs=[output])
+         clear.click(fn=clear_output, inputs=[], outputs=[output])
+
+         gr.Markdown("# example")
+
+         gr.Examples(
+             examples=get_example(),
+             fn=greet,
+             inputs=[imagebox, promptbox],
+             outputs=[output],
+             cache_examples=True,
+         )
+
+     demo.launch(server_name="0.0.0.0", server_port=8080, debug=True, share=True)
+
+
+ # img_url = f"https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
+ # text = f"详细描述一下这张图片"
+ # greet(img_url, text)