import unittest
import importlib

# The extension's package directory is named "sd-webui-controlnet"; hyphens are
# not valid in a Python identifier, so a plain ``import`` statement cannot name
# it and the module has to be loaded by its dotted string path instead.
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')

# Must run before the ``scripts`` import below.
# NOTE(review): presumably this prepares the import path / environment that
# ``scripts.external_code`` relies on -- confirm against tests/utils.
utils.setup_test_env()

from scripts import external_code
class TestGetAllUnitsFrom(unittest.TestCase):
    """Tests for ``external_code.get_all_units_from``."""

    def setUp(self):
        """Build one unit as keyword arguments and as a ``ControlNetUnit``."""
        # Keyword arguments describing a single unit; kept on the instance so
        # the same values back both the dict form and the object form below.
        self.control_unit = {
            "module": "none",
            "model": utils.get_model(),
            "image": utils.readImage("test/test_files/img2img_basic.png"),
            "resize_mode": 1,
            "low_vram": False,
            "processor_res": 64,
            "control_mode": external_code.ControlMode.BALANCED.value,
        }
        self.object_unit = external_code.ControlNetUnit(**self.control_unit)

    def test_empty_converts(self):
        """An empty argument list yields an empty list of units."""
        script_args = []
        units = external_code.get_all_units_from(script_args)
        self.assertListEqual(units, [])

    def test_object_forwards(self):
        """A ``ControlNetUnit`` passed in comes back as an equal unit."""
        script_args = [self.object_unit]
        units = external_code.get_all_units_from(script_args)
        self.assertListEqual(units, [self.object_unit])
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()