Spaces:
Runtime error
Runtime error
LivePortrait
/
stf
/stf-api-alternative
/pytriton
/tests
/functional
/L0_example_identity_python
/test.py
#!/usr/bin/env python3 | |
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Test of identity_python example""" | |
import argparse | |
import logging | |
import re | |
import signal | |
import sys | |
import time | |
from tests.utils import ( | |
DEFAULT_LOG_FORMAT, | |
ScriptThread, | |
get_current_container_version, | |
search_warning_on_too_verbose_log_level, | |
verify_docker_image_in_readme_same_as_tested, | |
) | |
# Logger named after the last dotted component of this module's package,
# falling back to "main" when the module has no package.
LOGGER = logging.getLogger((__package__ or "main").rpartition(".")[2])

# Static description of this test:
#   image_name - container image template; main() fills in TEST_CONTAINER_VERSION
#                and checks the example README references the same image.
#   platforms  - platform names this test is declared for.
METADATA = dict(
    image_name="nvcr.io/nvidia/pytorch:{TEST_CONTAINER_VERSION}-py3",
    platforms=["amd64", "arm64"],
)
def verify_client_output(client_output):
    """Check that the client printed each input array back unchanged as output.

    For each index ``n`` in (1, 2), the text following ``INPUT_<n>: `` and
    ``OUTPUT_<n>: `` is taken from the first line containing each label. The
    input text must be non-empty and equal to the output text.

    Args:
        client_output: captured text output of the client script.

    Raises:
        ValueError: when an ``INPUT_<n>`` value is missing or empty, or differs
            from the corresponding ``OUTPUT_<n>`` value (checked in order 1, 2).
    """

    def _value_after(label):
        # Rest of the first line containing "<label>: " (``.`` stops at a newline),
        # or None when the label does not occur at all.
        found = re.search(label + r": (.*)", client_output, re.MULTILINE)
        return found.group(1) if found else None

    for idx in (1, 2):
        sent = _value_after(f"INPUT_{idx}")
        received = _value_after(f"OUTPUT_{idx}")
        if sent and sent == received:
            continue
        raise ValueError(f"input{idx}_array: {sent} differs from output{idx}_array: {received}")

    LOGGER.info("Input and output arrays matches")
def _install_example():
    """Run the example's install script; raise RuntimeError on a non-zero return code."""
    install_cmd = ["bash", "examples/identity_python/install.sh"]
    with ScriptThread(install_cmd, name="install") as install_thread:
        install_thread.join()
    if install_thread.returncode != 0:
        raise RuntimeError(f"Install thread returned {install_thread.returncode}")


def _run_server_and_client(timeout_s):
    """Start the example server and client, wait for the client, then SIGINT the server.

    Args:
        timeout_s: maximum number of seconds to wait for the client to finish.

    Returns:
        Tuple ``(server_thread, client_thread, timed_out)``. ``timed_out`` is True
        when the wait loop stopped because ``timeout_s`` elapsed while both
        scripts were still alive.
    """
    start_time = time.time()
    elapsed_s = 0
    # Poll in slices of at most 1 s so a dying server is noticed promptly.
    wait_time_s = min(timeout_s, 1)

    server_cmd = ["python", "examples/identity_python/server.py"]
    client_cmd = ["python", "examples/identity_python/client.py"]

    with ScriptThread(server_cmd, name="server") as server_thread:
        with ScriptThread(client_cmd, name="client") as client_thread:
            while server_thread.is_alive() and client_thread.is_alive() and elapsed_s < timeout_s:
                client_thread.join(timeout=wait_time_s)
                elapsed_s = time.time() - start_time
            # Decide on a timeout right after the wait loop, while is_alive() still
            # reflects why the loop stopped. Checking it only after the return-code
            # checks never fires: those raise first whenever a script is still
            # running or failed.
            timed_out = elapsed_s >= timeout_s and client_thread.is_alive() and server_thread.is_alive()

        LOGGER.info("Interrupting server script process")
        if server_thread.process:
            server_thread.process.send_signal(signal.SIGINT)

    return server_thread, client_thread, timed_out


def main():
    """Run the identity_python example end-to-end and verify its output.

    Steps: check the README references the tested container image, run the
    example's install script, start server and client, then verify that the
    client echoed its inputs and that the server log passes the verbosity check.

    Exits the process with status -2 when the client/server run exceeds
    ``--timeout-s`` seconds.

    Raises:
        RuntimeError: if the install script or client returns non-zero, or the
            server returns anything other than 0 or -2.
        ValueError: if the client output does not echo its inputs unchanged.
        AssertionError: if ``search_warning_on_too_verbose_log_level`` reports a
            match in the server output.
    """
    parser = argparse.ArgumentParser(description="Test of identity_python example")
    parser.add_argument("--timeout-s", required=False, default=300, type=float, help="Timeout for test")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG, format=DEFAULT_LOG_FORMAT)
    logging.captureWarnings(True)

    docker_image_with_name = METADATA["image_name"].format(TEST_CONTAINER_VERSION=get_current_container_version())
    verify_docker_image_in_readme_same_as_tested("examples/identity_python/README.md", docker_image_with_name)

    _install_example()

    server_thread, client_thread, timed_out = _run_server_and_client(args.timeout_s)

    # Report a timeout before the return-code checks, which would otherwise mask it.
    if timed_out:
        LOGGER.error("Timeout occurred (timeout_s=%s)", args.timeout_s)
        sys.exit(-2)

    if client_thread.returncode != 0:
        raise RuntimeError(f"Client returned {client_thread.returncode}")
    if server_thread.returncode not in (0, -2):  # -2 is returned when process finished after receiving SIGINT signal
        raise RuntimeError(f"Server returned {server_thread.returncode}")

    verify_client_output(client_thread.output)
    assert not search_warning_on_too_verbose_log_level(server_thread.output)


if __name__ == "__main__":
    main()