# extensions/microsoftexcel-controlnet/tests/annotator_tests/openpose_tests/openpose_e2e_test.py
import unittest | |
import cv2 | |
import numpy as np | |
from typing import Dict | |
import importlib | |
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils') | |
utils.setup_test_env() | |
from annotator.openpose import OpenposeDetector | |
class TestOpenposeDetector(unittest.TestCase):
    """End-to-end regression tests for ``OpenposeDetector``.

    Each test runs the detector on an input image from ``image_path`` and
    compares the returned canvas against a stored expectation image.
    """

    # Directory holding the input images and expectation PNGs. The path is
    # relative, so it is resolved against the current working directory
    # (presumably the extension root when the suite is run; verify).
    image_path = './tests/images'

    # Per-channel absolute difference above which a pixel is painted white in
    # the saved diff image. It only shapes the debugging artifact; pass/fail is
    # decided by ``np.allclose`` in ``expect_same_image``.
    diff_threshold = 30

    def setUp(self) -> None:
        # Every test gets a fresh detector with its model loaded.
        self.detector = OpenposeDetector()
        self.detector.load_model()

    def tearDown(self) -> None:
        self.detector.unload_model()

    @staticmethod
    def _diff_image_path(expected_image: str) -> str:
        """Return the path where the diff image for *expected_image* is saved.

        ``_diff`` is inserted before the final extension only. This means the
        diff can never overwrite the expectation file itself, which the old
        ``str.replace('.png', ...)`` approach did for non-``.png`` paths.
        Extension-less paths get ``.png`` so that ``cv2.imwrite`` can choose
        an encoder.
        """
        import os

        root, ext = os.path.splitext(expected_image)
        return f'{root}_diff{ext or ".png"}'

    def expect_same_image(self, img1, img2, diff_img_path: str):
        """Assert that *img1* and *img2* are numerically equal.

        On mismatch, a black/white image is written to *diff_img_path* before
        the assertion fails. White marks pixels whose per-channel difference
        exceeds ``diff_threshold``.
        """
        self.assertIsNotNone(img1, 'actual image is None')
        self.assertIsNotNone(img2, 'expected image is None')
        # cv2.absdiff raises on mismatched shapes, so report it as a failure.
        self.assertEqual(img1.shape, img2.shape, 'image shapes differ')

        similar = np.allclose(img1, img2, rtol=1e-05, atol=1e-08)
        if not similar:
            # Only build the diff artifact when there is something to inspect.
            diff = cv2.absdiff(img1, img2)
            diff_highlighted = np.where(diff > self.diff_threshold, 255, 0).astype(np.uint8)
            cv2.imwrite(diff_img_path, diff_highlighted)
        self.assertTrue(similar, f'images differ; highlighted diff written to {diff_img_path}')

    def template(self, test_image: str, expected_image: str, detector_config: Dict, overwrite_expectation: bool = False):
        """Run the detector on *test_image* and compare with *expected_image*.

        Args:
            test_image: path of the input image read with ``cv2.imread``.
            expected_image: path of the expected canvas. Store it as PNG so
                that lossy compression cannot introduce spurious differences.
            detector_config: keyword arguments forwarded to the detector call.
            overwrite_expectation: if True, write the current canvas to
                *expected_image* instead of comparing. This is only for
                regenerating expectations; nothing is asserted in that mode.
        """
        ori_img = cv2.imread(test_image)
        # cv2.imread signals a missing or unreadable file by returning None.
        self.assertIsNotNone(ori_img, f'could not read test image {test_image}')
        canvas = self.detector(ori_img, **detector_config)

        if overwrite_expectation:
            cv2.imwrite(expected_image, canvas)
            return

        expected_canvas = cv2.imread(expected_image)
        self.assertIsNotNone(expected_canvas, f'could not read expected image {expected_image}')
        self.expect_same_image(
            canvas,
            expected_canvas,
            diff_img_path=self._diff_image_path(expected_image),
        )

    def test_body(self):
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/ski.jpg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_ski_output.png',
            detector_config=dict(),
            overwrite_expectation=False,
        )

    def test_hand(self):
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/woman.jpeg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_woman_hand_output.png',
            detector_config=dict(
                include_body=False,
                include_face=False,
                include_hand=True,
            ),
            overwrite_expectation=False,
        )

    def test_face(self):
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/woman.jpeg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_woman_face_output.png',
            detector_config=dict(
                include_body=False,
                include_face=True,
                include_hand=False,
            ),
            overwrite_expectation=False,
        )

    def test_all(self):
        self.template(
            test_image=f'{TestOpenposeDetector.image_path}/woman.jpeg',
            expected_image=f'{TestOpenposeDetector.image_path}/expected_woman_all_output.png',
            detector_config=dict(
                include_body=True,
                include_face=True,
                include_hand=True,
            ),
            overwrite_expectation=False,
        )
if __name__ == "__main__":
    # Executing the file directly runs every TestCase defined in this module.
    unittest.main(module=__name__)