# gradio_demo / try_demo_demo.py
# hd0013's picture
# Upload folder using huggingface_hub
# 7f119fd verified
#!/usr/bin/env python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import sys
import numpy as np
import tritonclient.grpc as grpcclient
if __name__ == "__main__":
    # Minimal Triton gRPC client demo: sends an image URL and a text prompt
    # to the "ensemble_mllm" model and prints the decoded string response.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        # Was a hard-coded internal IP (10.95.163.43:8001), contradicting the
        # help text below; restored to the documented default.
        default="localhost:8001",
        help="Inference server URL. Default is localhost:8001.",
    )
    parser.add_argument(
        "-s",
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL encrypted channel to the server",
    )
    parser.add_argument(
        "-t",
        "--client-timeout",
        type=float,
        required=False,
        default=None,
        help="Client timeout in seconds. Default is None.",
    )
    parser.add_argument(
        "-r",
        "--root-certificates",
        type=str,
        required=False,
        default=None,
        help="File holding PEM-encoded root certificates. Default is None.",
    )
    parser.add_argument(
        "-p",
        "--private-key",
        type=str,
        required=False,
        default=None,
        help="File holding PEM-encoded private key. Default is None.",
    )
    parser.add_argument(
        "-x",
        "--certificate-chain",
        type=str,
        required=False,
        default=None,
        help="File holding PEM-encoded certificate chain. Default is None.",
    )
    parser.add_argument(
        "-C",
        "--grpc-compression-algorithm",
        type=str,
        required=False,
        default=None,
        help="The compression algorithm to be used when sending request to server. Default is None.",
    )
    FLAGS = parser.parse_args()

    try:
        # Honor every parsed CLI flag: the original hard-coded verbose=True
        # and silently dropped the SSL certificate options.
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain,
        )
    except Exception as e:
        print("channel creation failed: " + str(e))
        # Exit non-zero so scripted callers can detect the failure
        # (the original bare sys.exit() exited with status 0).
        sys.exit(1)

    model_name = "ensemble_mllm"
    img_url = "https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
    text = "详细描述一下这张图片"  # prompt: "describe this picture in detail"

    def _bytes_tensor(name, value):
        # Wrap a UTF-8 string as a [1, 1] BYTES input tensor; both model
        # inputs need identical packaging, so build them in one place.
        data = np.array(value.encode("utf-8"), dtype=bytes).reshape([1, -1])
        tensor = grpcclient.InferInput(name, data.shape, "BYTES")
        tensor.set_data_from_numpy(data)
        return tensor

    inputs = [
        _bytes_tensor("IMAGE_URL", img_url),
        _bytes_tensor("TEXT", text),
    ]
    outputs = [grpcclient.InferRequestedOutput("OUTPUT")]

    # Single synchronous inference. Timeout and compression now come from the
    # CLI flags instead of being pinned to None.
    results = triton_client.infer(
        model_name=model_name,
        inputs=inputs,
        outputs=outputs,
        client_timeout=FLAGS.client_timeout,
        compression_algorithm=FLAGS.grpc_compression_algorithm,
    )

    statistics = triton_client.get_inference_statistics(model_name=model_name)
    print(statistics)
    if len(statistics.model_stats) != 1:
        print("FAILED: Inference Statistics")
        sys.exit(1)

    # OUTPUT is expected to be a [1, 1] BYTES tensor holding the UTF-8
    # model response — TODO confirm against the ensemble's config.
    output_data = results.as_numpy("OUTPUT")
    result_str = output_data[0][0].decode('utf-8')
    print("OUTPUT: " + result_str)
    print("PASS: infer")