import argparse
import time
import base64
import numpy as np
import requests
import os
from urllib.parse import urlparse
from tritonclient.http import InferenceServerClient, InferInput, InferRequestedOutput
def download_image(image_url):
    """Download *image_url* into the current directory and return the local filename.

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        The filename the image was saved under.

    Raises:
        Exception: if the server does not answer with HTTP 200.
    """
    parsed_url = urlparse(image_url)
    # URLs whose path ends in "/" have an empty basename; fall back to a fixed name
    # so we never try to open '' for writing.
    filename = os.path.basename(parsed_url.path) or "downloaded_image"
    # A timeout prevents the client from hanging forever on an unresponsive host.
    response = requests.get(image_url, timeout=30)
    if response.status_code == 200:
        with open(filename, 'wb') as img_file:
            img_file.write(response.content)
        return filename
    else:
        # Include the status code and URL so failures are diagnosable.
        raise Exception(
            "Failed to download image: HTTP %s for %s"
            % (response.status_code, image_url)
        )
def image_to_base64_data_uri(image_input):
    """Read the file at *image_input* and return its bytes as a base64 string.

    NOTE(review): despite the name, this returns bare base64 text, not a full
    ``data:`` URI — the Triton model side presumably expects it that way; confirm
    before "fixing" the name/return value.
    """
    with open(image_input, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def setup_argparse():
    """Build the command-line parser and return the parsed arguments.

    Both flags are mandatory string options.
    """
    cli = argparse.ArgumentParser(description="Client for Triton Inference Server")
    # Table-driven registration: (flag, help text) pairs, all required strings.
    for flag, help_text in (
        ("--image_path", "Path to the image or URL of the image to process"),
        ("--prompt", "Prompt to be used for the inference"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
args = setup_argparse()
triton_client = InferenceServerClient(url="localhost:8000", verbose=False)
if args.image_path.startswith('http://') or args.image_path.startswith('https://'):
image_path = download_image(args.image_path)
else:
image_path = args.image_path
image_data = image_to_base64_data_uri(image_path).encode('utf-8')
image_data_np = np.array([image_data], dtype=object)
prompt_np = np.array([args.prompt.encode('utf-8')], dtype=object)
images_in = InferInput(name="IMAGES", shape=[1], datatype="BYTES")
images_in.set_data_from_numpy(image_data_np, binary_data=True)
prompt_in = InferInput(name="PROMPT", shape=[1], datatype="BYTES")
prompt_in.set_data_from_numpy(prompt_np, binary_data=True)
results_out = InferRequestedOutput(name="RESULTS", binary_data=False)
start_time = time.time()
response = triton_client.infer(model_name="spacellava",
model_version="1",
inputs=[prompt_in, images_in],
outputs=[results_out])
results = response.get_response()["outputs"][0]["data"][0]
print("--- %s seconds ---" % (time.time() - start_time))
print(results)