SpaceLLaVA / docker /client.py
salma-remyx's picture
Add SpaceLlaVA Triton Server
0ea9ef5 verified
raw
history blame
No virus
2.42 kB
import argparse
import time
import base64
import numpy as np
import requests
import os
from urllib.parse import urlparse
from tritonclient.http import InferenceServerClient, InferInput, InferRequestedOutput
def download_image(image_url, timeout=30):
    """Download an image over HTTP(S) into the current working directory.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait on the server before giving up; forwarded
            to ``requests.get`` so a stalled server cannot hang the client
            forever.

    Returns:
        The local filename the image was written to. This is the basename
        of the URL path, or ``"downloaded_image"`` when the URL path has no
        basename (for example ``http://host/``).

    Raises:
        Exception: If the server answers with any status other than 200.
        Errors raised by ``requests.get`` itself (including a timeout) are
        not caught here and propagate to the caller.
    """
    # A URL such as "http://host/" has an empty basename; opening '' would
    # fail with an unrelated FileNotFoundError, so fall back to a fixed name.
    filename = os.path.basename(urlparse(image_url).path) or "downloaded_image"

    response = requests.get(image_url, timeout=timeout)
    if response.status_code != 200:
        raise Exception("Failed to download image")

    with open(filename, 'wb') as img_file:
        img_file.write(response.content)
    return filename
def image_to_base64_data_uri(image_input):
    """Read a file and return its contents as a base64 text string.

    Despite the function name, the result is the bare base64 payload; no
    ``data:`` URI prefix is added.

    Args:
        image_input: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes, as a ``str``.
    """
    with open(image_input, "rb") as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def setup_argparse():
    """Parse the client's command-line options.

    Both ``--image_path`` and ``--prompt`` are mandatory string options.

    Returns:
        The ``argparse.Namespace`` produced by parsing the process
        arguments, with ``image_path`` and ``prompt`` attributes.
    """
    cli = argparse.ArgumentParser(description="Client for Triton Inference Server")
    # (flag, help text) pairs; every option is a required string.
    required_options = (
        ("--image_path", "Path to the image or URL of the image to process"),
        ("--prompt", "Prompt to be used for the inference"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: send one image plus one prompt to the "spacellava"
    # model over Triton's HTTP API and print the answer with the elapsed time.
    args = setup_argparse()
    triton_client = InferenceServerClient(url="localhost:8000", verbose=False)

    # Remote images are fetched to a local file first; anything else is
    # treated as a local path and used as given.
    if args.image_path.startswith(('http://', 'https://')):
        image_path = download_image(args.image_path)
    else:
        image_path = args.image_path

    # Each input is a one-element object array of UTF-8 bytes, matching the
    # BYTES datatype and shape [1] declared on the InferInput objects below.
    image_data = image_to_base64_data_uri(image_path).encode('utf-8')
    image_data_np = np.array([image_data], dtype=object)
    prompt_np = np.array([args.prompt.encode('utf-8')], dtype=object)

    images_in = InferInput(name="IMAGES", shape=[1], datatype="BYTES")
    images_in.set_data_from_numpy(image_data_np, binary_data=True)
    prompt_in = InferInput(name="PROMPT", shape=[1], datatype="BYTES")
    prompt_in.set_data_from_numpy(prompt_np, binary_data=True)
    results_out = InferRequestedOutput(name="RESULTS", binary_data=False)

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()
    response = triton_client.infer(model_name="spacellava",
                                   model_version="1",
                                   inputs=[prompt_in, images_in],
                                   outputs=[results_out])
    # RESULTS was requested with binary_data=False, so the value is read
    # from the "data" list of the first output in the response.
    results = response.get_response()["outputs"][0]["data"][0]
    print("--- %s seconds ---" % (time.perf_counter() - start_time))
    print(results)