|
import tritonclient.http as http_client
|
|
from tritonclient.utils import *
|
|
import numpy as np
|
|
|
|
# --- Connection configuration -------------------------------------------
# Toggle TLS for the Triton HTTP connection.
ENABLE_SSL = False
# host:port of the Triton Inference Server HTTP endpoint.
ENDPOINT_URL = 'localhost:8000'
# Auth header sent with every request; replace the placeholder with a real key.
HTTP_HEADERS = {"Authorization": "Bearer __PASTE_KEY_HERE__"}
# Build the Triton HTTP client once; only the SSL-specific kwargs differ
# between the two modes, so collect kwargs instead of duplicating the
# InferenceServerClient(...) construction in both branches.
_client_kwargs = {"url": ENDPOINT_URL, "verbose": False}
if ENABLE_SSL:
    import gevent.ssl

    _client_kwargs.update(
        ssl=True,
        # NOTE(review): relies on a private helper of gevent's ssl module to
        # supply the default TLS context factory — confirm against the
        # installed gevent/tritonclient versions.
        ssl_context_factory=gevent.ssl._create_default_https_context,
    )
triton_http_client = http_client.InferenceServerClient(**_client_kwargs)

# Liveness probe: prints True once the server is up and accepting requests.
print("Is server ready - {}".format(triton_http_client.is_server_ready(headers=HTTP_HEADERS)))
def get_string_tensor(string_values, tensor_name):
    """Wrap nested string values into a Triton ``InferInput`` tensor.

    Args:
        string_values: (Nested) list of strings forming the tensor contents.
        tensor_name: Name of the model input this tensor binds to.

    Returns:
        An ``http_client.InferInput`` populated with the string data.
    """
    arr = np.array(string_values, dtype="object")
    tensor = http_client.InferInput(tensor_name, arr.shape, np_to_triton_dtype(arr.dtype))
    tensor.set_data_from_numpy(arr)
    return tensor
def get_translation_input_for_triton(texts: list, src_lang: str, tgt_lang: str):
    """Build the three string tensors the NMT model expects for one batch.

    Args:
        texts: Sentences to translate; each becomes one batch row.
        src_lang: Source language code, repeated for every row.
        tgt_lang: Target language code, repeated for every row.

    Returns:
        List of ``InferInput`` tensors: INPUT_TEXT, INPUT_LANGUAGE_ID,
        OUTPUT_LANGUAGE_ID — each shaped (len(texts), 1).
    """
    batch_size = len(texts)
    text_column = [[sentence] for sentence in texts]
    src_column = [[src_lang]] * batch_size
    tgt_column = [[tgt_lang]] * batch_size
    return [
        get_string_tensor(text_column, "INPUT_TEXT"),
        get_string_tensor(src_column, "INPUT_LANGUAGE_ID"),
        get_string_tensor(tgt_column, "OUTPUT_LANGUAGE_ID"),
    ]
# --- Run one en→hi translation request against the "nmt" model ----------
input_sentences = ["Hello world, I am Ram and I am from Ayodhya.", "How are you Ravan bro?"]

inputs = get_translation_input_for_triton(input_sentences, "en", "hi")
# Ask the server to return only the translated-text output tensor.
output0 = http_client.InferRequestedOutput("OUTPUT_TEXT")

response = triton_http_client.infer(
    "nmt",
    model_version='1',
    inputs=inputs,
    outputs=[output0],
    headers=HTTP_HEADERS,
)
# Decode the OUTPUT_TEXT tensor and show each source/translation pair.
# Each row is a single-element list of UTF-8 bytes, hence row[0].decode().
output_batch = response.as_numpy('OUTPUT_TEXT').tolist()

for source_text, translated_row in zip(input_sentences, output_batch):
    print()
    print(source_text)
    print(translated_row[0].decode("utf-8"))