#!/usr/bin/env python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import queue
import sys
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

FLAGS = None
class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient would provide the results of an
# inference as grpcclient.InferResult in result. For a successful
# inference, error will be None; otherwise it will be an object of
# tritonclient.utils.InferenceServerException holding the error details.
def callback(user_data, result, error):
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)
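
# Note: the callback is invoked from the client's stream thread, not the
# main thread. queue.Queue is thread-safe, so the queue in UserData acts as
# the hand-off point between the gRPC stream and the consuming loop below.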
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="10.199.14.151:8001",
        help="Inference server URL and its gRPC port. Default is 10.199.14.151:8001.",
    )
    parser.add_argument(
        "-t",
        "--stream-timeout",
        type=float,
        required=False,
        default=None,
        help="Stream timeout in seconds. Default is None.",
    )
    FLAGS = parser.parse_args()
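
    # Example invocation (script filename assumed; adjust to your own):
    #   python3 ensemble_mllm_stream_client.py -u localhost:8001 -t 60 -v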
    # The "ensemble_mllm" ensemble takes an image URL (or local path) and a
    # text prompt, and streams the generated description back token by token
    # on the "OUTPUT" tensor.
    model_name = "ensemble_mllm"
    model_version = ""
    batch_size = 1

    # img_url = "https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
    img_url = "/workdir/yanghandi/gradio_demo/static/0000.jpeg"
    # img_url = "https://s3plus.sankuai.com/automl-pkgs/0003.jpeg"
    text = "详细描述一下这张图片"  # "Describe this image in detail."
    sequence_id = 100
    int_sequence_id0 = sequence_id
    result_list = []
    user_data = UserData()
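
    # Receiving multiple responses for a single request, as below, assumes
    # the server-side model (or an ensemble member) is configured with
    # Triton's decoupled transaction policy; a non-decoupled model returns
    # exactly one response per request.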
    # It is advisable to use the client object within a with..as clause
    # when sending streaming requests. This ensures the client is closed
    # when the block inside with exits.
    with grpcclient.InferenceServerClient(
        url=FLAGS.url, verbose=FLAGS.verbose
    ) as triton_client:
        try:
            # Establish stream
            triton_client.start_stream(
                callback=partial(callback, user_data),
                stream_timeout=FLAGS.stream_timeout,
            )
            # Create the input tensors. Triton BYTES tensors carry raw
            # UTF-8 encoded strings; both inputs here have shape [1, 1].
            inputs = []
            img_url_bytes = np.array(img_url.encode("utf-8"), dtype=bytes)
            img_url_bytes = img_url_bytes.reshape([1, -1])
            inputs.append(
                grpcclient.InferInput("IMAGE_URL", img_url_bytes.shape, "BYTES")
            )
            inputs[0].set_data_from_numpy(img_url_bytes)

            text_bytes = np.array(text.encode("utf-8"), dtype=bytes)
            text_bytes = text_bytes.reshape([1, -1])
            inputs.append(grpcclient.InferInput("TEXT", text_bytes.shape, "BYTES"))
            inputs[1].set_data_from_numpy(text_bytes)

            outputs = []
            outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))
            # Issue the asynchronous sequence inference. sequence_start and
            # sequence_end are both True, so this is a one-request sequence.
            triton_client.async_stream_infer(
                model_name=model_name,
                inputs=inputs,
                outputs=outputs,
                request_id="{}".format(sequence_id),
                sequence_id=sequence_id,
                sequence_start=True,
                sequence_end=True,
            )
        except InferenceServerException as error:
            print(error)
            sys.exit(1)
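
        # From here on, responses arrive asynchronously through the callback;
        # the main thread blocks on the queue and consumes them in arrival
        # order until the model signals the end of the stream.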
        # Retrieve results...
        recv_count = 0
        while True:
            data_item = user_data._completed_requests.get()
            if isinstance(data_item, InferenceServerException):
                print("InferenceServerException: ", data_item)
                sys.exit(1)
            this_id = data_item.get_response().id.split("_")[0]
            if int(this_id) != int_sequence_id0:
                print("unexpected sequence id returned by the server: {}".format(this_id))
                sys.exit(1)
            result = data_item.as_numpy("OUTPUT")
            # An empty token marks the end of the generated stream.
            if len(result[0][0]) == 0:
                break
            result_list.append(result)
            recv_count += 1
            # Print the partial output accumulated so far.
            result_str = "".join([item[0][0].decode("utf-8") for item in result_list])
            print(f"{len(result_list)}: {result_str}")

    # Print the fully assembled output and a final status line.
    result_str = "".join([item[0][0].decode("utf-8") for item in result_list])
    print("Final output:", result_str)
    print("PASS: Sequence")