# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared test utilities: helpers for faking Triton gRPC/HTTP client responses and for patching clients."""
import json
import typing

import numpy as np
import tritonclient.grpc
import tritonclient.http
import tritonclient.utils

from pytriton.model_config.generator import ModelConfigGenerator
from pytriton.model_config.triton_model_config import TritonModelConfig


def verify_equalness_of_dicts_with_ndarray(a_dict, b_dict):
    """Assert that two dicts are equal, comparing ndarray values element-wise and other values with ==."""
    assert a_dict.keys() == b_dict.keys(), f"{a_dict} != {b_dict}"
    for output_name in a_dict:
        assert isinstance(
            a_dict[output_name], type(b_dict[output_name])
        ), f"type(a[{output_name}])={type(a_dict[output_name])} != type(b[{output_name}])={type(b_dict[output_name])}"
        if isinstance(a_dict[output_name], np.ndarray):
            assert a_dict[output_name].dtype == b_dict[output_name].dtype
            assert a_dict[output_name].shape == b_dict[output_name].shape
            if np.issubdtype(a_dict[output_name].dtype, np.number):
                # numeric arrays are compared with a floating-point tolerance
                assert np.allclose(b_dict[output_name], a_dict[output_name])
            else:
                # non-numeric arrays (e.g. bytes/str) must match exactly
                assert np.array_equal(b_dict[output_name], a_dict[output_name])
        else:
            assert a_dict[output_name] == b_dict[output_name]
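

# Hedged usage sketch (added for illustration, not part of the original module):
# the helper passes silently on equal dicts and raises AssertionError otherwise.
# The dict contents below are made up.
#
#     a = {"scores": np.array([0.1, 0.9]), "label": "cat"}
#     b = {"scores": np.array([0.1, 0.9]), "label": "cat"}
#     verify_equalness_of_dicts_with_ndarray(a, b)  # no exception raised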


def wrap_to_grpc_infer_result(
    model_config: TritonModelConfig, request_id: str, outputs_dict: typing.Dict[str, np.ndarray]
):
    """Wrap numpy outputs into a tritonclient.grpc.InferResult, as the gRPC client would return it."""
    raw_output_contents = [output_data.tobytes() for output_data in outputs_dict.values()]
    return tritonclient.grpc.InferResult(
        tritonclient.grpc.service_pb2.ModelInferResponse(
            model_name=model_config.model_name,
            model_version=str(model_config.model_version),
            id=request_id,
            outputs=[
                tritonclient.grpc.service_pb2.ModelInferResponse.InferOutputTensor(
                    name=output_name,
                    datatype=tritonclient.utils.np_to_triton_dtype(output_data.dtype),
                    shape=output_data.shape,
                )
                for output_name, output_data in outputs_dict.items()
            ],
            raw_output_contents=raw_output_contents,
        )
    )
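

# Hedged usage sketch (illustrative; `config` stands for a hypothetical
# TritonModelConfig describing a model with one "OUTPUT_1" tensor):
#
#     result = wrap_to_grpc_infer_result(config, "req-0", {"OUTPUT_1": np.ones((2, 3), dtype=np.float32)})
#     assert result.as_numpy("OUTPUT_1").shape == (2, 3)
#     assert result.get_response().id == "req-0"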


def wrap_to_http_infer_result(
    model_config: TritonModelConfig, request_id: str, outputs_dict: typing.Dict[str, np.ndarray]
):
    """Wrap numpy outputs into a tritonclient.http.InferResult, as the HTTP client would return it."""
    raw_output_contents = [output_data.tobytes() for output_data in outputs_dict.values()]
    buffer = b"".join(raw_output_contents)
    # The HTTP binary protocol is a JSON header followed by the concatenated raw tensor payloads
    content = {
        "outputs": [
            {
                "name": name,
                "datatype": tritonclient.utils.np_to_triton_dtype(output_data.dtype),
                "shape": list(output_data.shape),
                "parameters": {"binary_data_size": len(output_data.tobytes())},
            }
            for name, output_data in outputs_dict.items()
        ]
    }
    header = json.dumps(content).encode("utf-8")
    response_body = header + buffer
    # header length tells the client where the JSON ends and the binary data begins
    return tritonclient.http.InferResult.from_response_body(response_body, False, len(header))
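

# Hedged usage sketch (illustrative; `config` is a hypothetical
# TritonModelConfig): the HTTP result exposes the same `as_numpy` accessor.
#
#     result = wrap_to_http_infer_result(config, "req-0", {"OUTPUT_1": np.zeros((4,), dtype=np.float32)})
#     assert result.as_numpy("OUTPUT_1").dtype == np.float32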


def extract_array_from_grpc_infer_input(input_: tritonclient.grpc.InferInput):
    """Recover the numpy array stored in a gRPC InferInput (reads the private raw buffer)."""
    np_array = np.frombuffer(input_._raw_content, dtype=tritonclient.utils.triton_to_np_dtype(input_.datatype()))
    np_array = np_array.reshape(input_.shape())
    return np_array


def extract_array_from_http_infer_input(input_: tritonclient.http.InferInput):
    """Recover the numpy array stored in an HTTP InferInput (reads the private raw buffer)."""
    np_array = np.frombuffer(input_._raw_data, dtype=tritonclient.utils.triton_to_np_dtype(input_.datatype()))
    np_array = np_array.reshape(input_.shape())
    return np_array
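

# Hedged round-trip sketch (illustrative only): build an InferInput the way a
# client would, then recover the array with the extractor above.
#
#     data = np.arange(6, dtype=np.int32).reshape(2, 3)
#     infer_input = tritonclient.http.InferInput("INPUT_1", list(data.shape), "INT32")
#     infer_input.set_data_from_numpy(data, binary_data=True)  # fills the private raw buffer
#     assert np.array_equal(extract_array_from_http_infer_input(infer_input), data)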


def patch_grpc_client__server_up_and_ready(mocker):
    """Patch the gRPC client so liveness/readiness probes report a healthy server."""
    mocker.patch.object(tritonclient.grpc.InferenceServerClient, "is_server_ready").return_value = True
    mocker.patch.object(tritonclient.grpc.InferenceServerClient, "is_server_live").return_value = True


def patch_http_client__server_up_and_ready(mocker):
    """Patch the HTTP client so liveness/readiness probes report a healthy server."""
    mocker.patch.object(tritonclient.http.InferenceServerClient, "is_server_ready").return_value = True
    mocker.patch.object(tritonclient.http.InferenceServerClient, "is_server_live").return_value = True
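

# Hedged usage sketch: in a pytest test, `mocker` comes from the pytest-mock
# plugin. After patching, client health probes succeed without a live server.
# The test body below is hypothetical.
#
#     def test_client_sees_healthy_server(mocker):
#         patch_http_client__server_up_and_ready(mocker)
#         client = tritonclient.http.InferenceServerClient("localhost:8000")
#         assert client.is_server_live() and client.is_server_ready()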


def patch_grpc_client__model_up_and_ready(mocker, model_config: TritonModelConfig):
    """Patch the gRPC client so the given model appears loaded, ready, and queryable."""
    from google.protobuf import json_format  # pytype: disable=pyi-error
    from tritonclient.grpc import model_config_pb2, service_pb2  # pytype: disable=pyi-error

    mock_get_repo_index = mocker.patch.object(tritonclient.grpc.InferenceServerClient, "get_model_repository_index")
    mock_get_repo_index.return_value = service_pb2.RepositoryIndexResponse(
        models=[
            service_pb2.RepositoryIndexResponse.ModelIndex(
                name=model_config.model_name, version="1", state="READY", reason=""
            ),
        ]
    )
    mocker.patch.object(tritonclient.grpc.InferenceServerClient, "is_model_ready").return_value = True

    # Render the model config the same way the server would return it: protobuf -> JSON dict
    model_config_dict = ModelConfigGenerator(model_config).get_config()
    model_config_protobuf = json_format.ParseDict(model_config_dict, model_config_pb2.ModelConfig())
    response = service_pb2.ModelConfigResponse(config=model_config_protobuf)
    response_dict = json.loads(json_format.MessageToJson(response, preserving_proto_field_name=True))
    mock_get_model_config = mocker.patch.object(tritonclient.grpc.InferenceServerClient, "get_model_config")
    mock_get_model_config.return_value = response_dict
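

# Hedged sketch (hypothetical `config`): after patching, `get_model_config`
# returns the JSON-rendered ModelConfigResponse dict, so a test can inspect it:
#
#     patch_grpc_client__model_up_and_ready(mocker, config)
#     client = tritonclient.grpc.InferenceServerClient("localhost:8001")
#     response = client.get_model_config(config.model_name)
#     assert response["config"]["name"] == config.model_name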


def patch_http_client__model_up_and_ready(mocker, model_config: TritonModelConfig):
    """Patch the HTTP client so the given model appears loaded, ready, and queryable."""
    mock_get_repo_index = mocker.patch.object(tritonclient.http.InferenceServerClient, "get_model_repository_index")
    mock_get_repo_index.return_value = [
        {"name": model_config.model_name, "version": "1", "state": "READY", "reason": ""}
    ]
    mocker.patch.object(tritonclient.http.InferenceServerClient, "is_model_ready").return_value = True
    # The HTTP client returns the model config as a plain dict, so no protobuf rendering is needed
    model_config_dict = ModelConfigGenerator(model_config).get_config()
    mock_get_model_config = mocker.patch.object(tritonclient.http.InferenceServerClient, "get_model_config")
    mock_get_model_config.return_value = model_config_dict
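

# Hedged end-to-end sketch (hypothetical `config` and test body): a test can
# combine the server- and model-level patches so client code sees a fully
# ready model without starting Triton.
#
#     def test_model_is_ready(mocker):
#         patch_http_client__server_up_and_ready(mocker)
#         patch_http_client__model_up_and_ready(mocker, config)
#         client = tritonclient.http.InferenceServerClient("localhost:8000")
#         assert client.is_model_ready(config.model_name)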