Text Generation
Transformers
Safetensors
GGUF
llava
remyx
Inference Endpoints
File size: 2,415 Bytes
0ea9ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import time
import base64
import numpy as np
import requests
import os
from urllib.parse import urlparse
from tritonclient.http import InferenceServerClient, InferInput, InferRequestedOutput

def download_image(image_url):
    """Download the image at *image_url* into the current directory.

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        The local filename the image was saved under (basename of the
        URL path, or a fallback name when the path has no basename).

    Raises:
        Exception: if the server does not answer with HTTP 200.
        requests.exceptions.RequestException: on connection errors/timeouts.
    """
    parsed_url = urlparse(image_url)
    filename = os.path.basename(parsed_url.path)
    # A URL ending in '/' has an empty basename; fall back to a fixed name
    # so open() below does not fail on an empty path.
    if not filename:
        filename = "downloaded_image"
    # Bounded timeout so a stalled server cannot hang the client forever.
    response = requests.get(image_url, timeout=30)
    if response.status_code == 200:
        with open(filename, 'wb') as img_file:
            img_file.write(response.content)
        return filename
    else:
        raise Exception(
            f"Failed to download image (HTTP {response.status_code})")

def image_to_base64_data_uri(image_input):
    """Read the file at *image_input* and return its contents encoded
    as a base64 ASCII string.

    NOTE(review): despite the name, this returns the bare base64 payload,
    not a full ``data:`` URI — downstream code depends on that.
    """
    with open(image_input, "rb") as img_file:
        raw_bytes = img_file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")

def setup_argparse(argv=None):
    """Build the CLI parser and parse arguments.

    Args:
        argv: optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]`` — identical to the
            previous behavior. Passing an explicit list makes the
            function unit-testable.

    Returns:
        argparse.Namespace with ``image_path`` and ``prompt`` attributes.
    """
    parser = argparse.ArgumentParser(description="Client for Triton Inference Server")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the image or URL of the image to process")
    parser.add_argument("--prompt", type=str, required=True, help="Prompt to be used for the inference")
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = setup_argparse()

    # Triton HTTP endpoint; assumes a server on the local default port —
    # NOTE(review): consider making host/port configurable via CLI.
    triton_client = InferenceServerClient(url="localhost:8000", verbose=False)

    # Fetch remote images to a local file first; local paths pass through.
    # startswith accepts a tuple of prefixes — one call instead of an `or` chain.
    if args.image_path.startswith(("http://", "https://")):
        image_path = download_image(args.image_path)
    else:
        image_path = args.image_path

    # Triton BYTES inputs are carried as 1-D object arrays of encoded bytes.
    image_data = image_to_base64_data_uri(image_path).encode('utf-8')
    image_data_np = np.array([image_data], dtype=object)
    prompt_np = np.array([args.prompt.encode('utf-8')], dtype=object)

    images_in = InferInput(name="IMAGES", shape=[1], datatype="BYTES")
    images_in.set_data_from_numpy(image_data_np, binary_data=True)
    prompt_in = InferInput(name="PROMPT", shape=[1], datatype="BYTES")
    prompt_in.set_data_from_numpy(prompt_np, binary_data=True)

    results_out = InferRequestedOutput(name="RESULTS", binary_data=False)

    # Time the full round trip to the server for a simple latency readout.
    start_time = time.time()
    response = triton_client.infer(model_name="spacellava",
                                   model_version="1",
                                   inputs=[prompt_in, images_in],
                                   outputs=[results_out])

    # First (and only) element of the RESULTS output tensor.
    results = response.get_response()["outputs"][0]["data"][0]
    print("--- %s seconds ---" % (time.time() - start_time))
    print(results)