Spaces:
Build error
Build error
#!/usr/bin/env python | |
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Redistribution and use in source and binary forms, with or without | |
# modification, are permitted provided that the following conditions | |
# are met: | |
# * Redistributions of source code must retain the above copyright | |
# notice, this list of conditions and the following disclaimer. | |
# * Redistributions in binary form must reproduce the above copyright | |
# notice, this list of conditions and the following disclaimer in the | |
# documentation and/or other materials provided with the distribution. | |
# * Neither the name of NVIDIA CORPORATION nor the names of its | |
# contributors may be used to endorse or promote products derived | |
# from this software without specific prior written permission. | |
# | |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | |
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | |
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | |
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | |
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | |
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | |
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
import argparse | |
import sys | |
import numpy as np | |
import tritonclient.grpc as grpcclient | |
def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this Triton gRPC client.

    Covers connection, TLS, timeout and compression options, plus
    ``--model-name``, ``--image-url`` and ``--text`` for the request contents.
    The defaults of those last three equal the values the script used to
    hard-code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        # NOTE(review): default is a specific private host, not localhost;
        # the help text reports the real default via %(default)s.
        default="10.95.163.43:8001",
        help="Inference server URL. Default is %(default)s.",
    )
    parser.add_argument(
        "-s",
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL encrypted channel to the server",
    )
    parser.add_argument(
        "-t",
        "--client-timeout",
        type=float,
        required=False,
        default=None,
        help="Client timeout in seconds. Default is None.",
    )
    parser.add_argument(
        "-r",
        "--root-certificates",
        type=str,
        required=False,
        default=None,
        help="File holding PEM-encoded root certificates. Default is None.",
    )
    parser.add_argument(
        "-p",
        "--private-key",
        type=str,
        required=False,
        default=None,
        help="File holding PEM-encoded private key. Default is None.",
    )
    parser.add_argument(
        "-x",
        "--certificate-chain",
        type=str,
        required=False,
        default=None,
        help="File holding PEM-encoded certificate chain. Default is None.",
    )
    parser.add_argument(
        "-C",
        "--grpc-compression-algorithm",
        type=str,
        required=False,
        default=None,
        help="The compression algorithm to be used when sending request to server. Default is None.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        required=False,
        default="ensemble_mllm",
        help="Name of the model to run. Default is %(default)s.",
    )
    parser.add_argument(
        "--image-url",
        type=str,
        required=False,
        default="https://s3plus.sankuai.com/automl-pkgs/0000.jpeg",
        help="Image URL sent as the IMAGE_URL input. Default is %(default)s.",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=False,
        # Default prompt means "Describe this image in detail."
        default="详细描述一下这张图片",
        help="Prompt text sent as the TEXT input. Default is %(default)s.",
    )
    return parser


def to_bytes_tensor(value: str) -> np.ndarray:
    """Encode *value* as UTF-8 and return it as a shape-(1, 1) numpy bytes array.

    Produces the same array as
    ``np.array(value.encode("utf-8"), dtype=bytes).reshape([1, -1])``.
    That is the construction this script feeds to its "BYTES" inputs.
    """
    # NOTE(review): assumes the model expects [1, 1] string tensors (leading
    # dim presumably the batch) — confirm against the ensemble's model config.
    return np.array([[value.encode("utf-8")]], dtype=bytes)


def create_client(flags: argparse.Namespace) -> "grpcclient.InferenceServerClient":
    """Create a gRPC Triton client configured from parsed CLI *flags*.

    Prints the error to stderr and exits with status 1 if construction fails.
    """
    try:
        return grpcclient.InferenceServerClient(
            url=flags.url,
            # Honor -v / -r / -p / -x rather than hard-coding them.
            verbose=flags.verbose,
            ssl=flags.ssl,
            root_certificates=flags.root_certificates,
            private_key=flags.private_key,
            certificate_chain=flags.certificate_chain,
        )
    except Exception as e:  # top-level boundary: report any setup failure and stop
        print("channel creation failed: " + str(e), file=sys.stderr)
        # Non-zero status so callers/CI see the failure (bare sys.exit() is 0).
        sys.exit(1)


def main() -> None:
    """Send one IMAGE_URL + TEXT request to the model and print its OUTPUT text.

    Exits with status 1 in three cases:
    - the client cannot be created;
    - the inference statistics do not report exactly one model entry;
    - the response carries no ``OUTPUT`` tensor.
    """
    flags = build_parser().parse_args()
    triton_client = create_client(flags)
    try:
        inputs = []
        for input_name, tensor in (
            ("IMAGE_URL", to_bytes_tensor(flags.image_url)),
            ("TEXT", to_bytes_tensor(flags.text)),
        ):
            infer_input = grpcclient.InferInput(input_name, tensor.shape, "BYTES")
            infer_input.set_data_from_numpy(tensor)
            inputs.append(infer_input)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT")]

        results = triton_client.infer(
            model_name=flags.model_name,
            inputs=inputs,
            outputs=outputs,
            # Honor -t / -C; both default to None.
            client_timeout=flags.client_timeout,
            compression_algorithm=flags.grpc_compression_algorithm,
        )

        statistics = triton_client.get_inference_statistics(model_name=flags.model_name)
        print(statistics)
        if len(statistics.model_stats) != 1:
            print("FAILED: Inference Statistics")
            sys.exit(1)

        output_data = results.as_numpy("OUTPUT")
        if output_data is None:
            # as_numpy returns None when the named output is absent.
            print("FAILED: no OUTPUT tensor in response")
            sys.exit(1)
        # Assumes OUTPUT is at least 2-D and holds UTF-8 bytes — TODO confirm.
        result_str = output_data[0][0].decode("utf-8")
        print("OUTPUT: " + result_str)
        print("PASS: infer")
    finally:
        triton_client.close()


if __name__ == "__main__":
    main()