surena26 commited on
Commit
6f7e8eb
·
verified ·
1 Parent(s): d9fb09c

Upload folder using huggingface_hub

Browse files
ComfyUI/tests/README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Automated Testing
2
+
3
+ ## Running tests locally
4
+
5
+ Additional requirements for running tests:
6
+ ```
7
+ pip install pytest
8
+ pip install websocket-client==1.6.1
9
+ opencv-python==4.6.0.66
10
+ scikit-image==0.21.0
11
+ ```
12
+ Run inference tests:
13
+ ```
14
+ pytest tests/inference
15
+ ```
16
+
17
+ ## Quality regression test
18
+ Compares images in 2 directories to ensure they are the same
19
+
20
+ 1) Run an inference test to save a directory of "ground truth" images
21
+ ```
22
+ pytest tests/inference --output_dir tests/inference/baseline
23
+ ```
24
+ 2) Make code edits
25
+
26
+ 3) Run inference and quality comparison tests
27
+ ```
28
+ pytest
29
+ ```
ComfyUI/tests/__init__.py ADDED
File without changes
ComfyUI/tests/compare/conftest.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+
4
+ # Command line arguments for pytest
5
+ def pytest_addoption(parser):
6
+ parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
7
+ parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
8
+ parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
9
+ parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
10
+
11
+ # This initializes args at the beginning of the test session
12
+ @pytest.fixture(scope="session", autouse=True)
13
+ def args_pytest(pytestconfig):
14
+ args = {}
15
+ args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
16
+ args['test_dir'] = pytestconfig.getoption('test_dir')
17
+ args['metrics_file'] = pytestconfig.getoption('metrics_file')
18
+ args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
19
+
20
+ # Initialize metrics file
21
+ with open(args['metrics_file'], 'a') as f:
22
+ # if file is empty, write header
23
+ if os.stat(args['metrics_file']).st_size == 0:
24
+ f.write("| date | run | file | status | value | \n")
25
+ f.write("| --- | --- | --- | --- | --- | \n")
26
+
27
+ return args
28
+
29
+
30
+ def gather_file_basenames(directory: str):
31
+ files = []
32
+ for file in os.listdir(directory):
33
+ if file.endswith(".png"):
34
+ files.append(file)
35
+ return files
36
+
37
+ # Creates the list of baseline file names to use as a fixture
38
+ def pytest_generate_tests(metafunc):
39
+ if "baseline_fname" in metafunc.fixturenames:
40
+ baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
41
+ metafunc.parametrize("baseline_fname", baseline_fnames)
ComfyUI/tests/compare/test_quality.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ import pytest
6
+ from pytest import fixture
7
+ from typing import Tuple, List
8
+
9
+ from cv2 import imread, cvtColor, COLOR_BGR2RGB
10
+ from skimage.metrics import structural_similarity as ssim
11
+
12
+
13
+ """
14
+ This test suite compares images in 2 directories by file name
15
+ The directories are specified by the command line arguments --baseline_dir and --test_dir
16
+
17
+ """
18
+ # ssim: Structural Similarity Index
19
+ # Returns a tuple of (ssim, diff_image)
20
+ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
21
+ score, diff = ssim(img0, img1, channel_axis=-1, full=True)
22
+ # rescale the difference image to 0-255 range
23
+ diff = (diff * 255).astype("uint8")
24
+ return score, diff
25
+
26
+ # Metrics must return a tuple of (score, diff_image)
27
+ METRICS = {"ssim": ssim_score}
28
+ METRICS_PASS_THRESHOLD = {"ssim": 0.95}
29
+
30
+
31
+ class TestCompareImageMetrics:
32
+ @fixture(scope="class")
33
+ def test_file_names(self, args_pytest):
34
+ test_dir = args_pytest['test_dir']
35
+ fnames = self.gather_file_basenames(test_dir)
36
+ yield fnames
37
+ del fnames
38
+
39
+ @fixture(scope="class", autouse=True)
40
+ def teardown(self, args_pytest):
41
+ yield
42
+ # Runs after all tests are complete
43
+ # Aggregate output files into a grid of images
44
+ baseline_dir = args_pytest['baseline_dir']
45
+ test_dir = args_pytest['test_dir']
46
+ img_output_dir = args_pytest['img_output_dir']
47
+ metrics_file = args_pytest['metrics_file']
48
+
49
+ grid_dir = os.path.join(img_output_dir, "grid")
50
+ os.makedirs(grid_dir, exist_ok=True)
51
+
52
+ for metric_dir in METRICS.keys():
53
+ metric_path = os.path.join(img_output_dir, metric_dir)
54
+ for file in os.listdir(metric_path):
55
+ if file.endswith(".png"):
56
+ score = self.lookup_score_from_fname(file, metrics_file)
57
+ image_file_list = []
58
+ image_file_list.append([
59
+ os.path.join(baseline_dir, file),
60
+ os.path.join(test_dir, file),
61
+ os.path.join(metric_path, file)
62
+ ])
63
+ # Create grid
64
+ image_list = [[Image.open(file) for file in files] for files in image_file_list]
65
+ grid = self.image_grid(image_list)
66
+ grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
67
+
68
+ # Tests run for each baseline file name
69
+ @fixture()
70
+ def fname(self, baseline_fname):
71
+ yield baseline_fname
72
+ del baseline_fname
73
+
74
+ def test_directories_not_empty(self, args_pytest):
75
+ baseline_dir = args_pytest['baseline_dir']
76
+ test_dir = args_pytest['test_dir']
77
+ assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
78
+ assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
79
+
80
+ def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
81
+ # Check that all files in baseline_dir have a file in test_dir with matching metadata
82
+ baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
83
+ file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
84
+ file_match = self.find_file_match(baseline_file_path, file_paths)
85
+ assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
86
+
87
+ # For a baseline image file, finds the corresponding file name in test_dir and
88
+ # compares the images using the metrics in METRICS
89
+ @pytest.mark.parametrize("metric", METRICS.keys())
90
+ def test_pipeline_compare(
91
+ self,
92
+ args_pytest,
93
+ fname,
94
+ test_file_names,
95
+ metric,
96
+ ):
97
+ baseline_dir = args_pytest['baseline_dir']
98
+ test_dir = args_pytest['test_dir']
99
+ metrics_output_file = args_pytest['metrics_file']
100
+ img_output_dir = args_pytest['img_output_dir']
101
+
102
+ baseline_file_path = os.path.join(baseline_dir, fname)
103
+
104
+ # Find file match
105
+ file_paths = [os.path.join(test_dir, f) for f in test_file_names]
106
+ test_file = self.find_file_match(baseline_file_path, file_paths)
107
+
108
+ # Run metrics
109
+ sample_baseline = self.read_img(baseline_file_path)
110
+ sample_secondary = self.read_img(test_file)
111
+
112
+ score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
113
+ metric_status = score > METRICS_PASS_THRESHOLD[metric]
114
+
115
+ # Save metric values
116
+ with open(metrics_output_file, 'a') as f:
117
+ run_info = os.path.splitext(fname)[0]
118
+ metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
119
+ date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
+ f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
121
+
122
+ # Save metric image
123
+ metric_img_dir = os.path.join(img_output_dir, metric)
124
+ os.makedirs(metric_img_dir, exist_ok=True)
125
+ output_filename = f'{fname}'
126
+ Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
127
+
128
+ assert score > METRICS_PASS_THRESHOLD[metric]
129
+
130
+ def read_img(self, filename: str) -> np.ndarray:
131
+ cvImg = imread(filename)
132
+ cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
133
+ return cvImg
134
+
135
+ def image_grid(self, img_list: list[list[Image.Image]]):
136
+ # imgs is a 2D list of images
137
+ # Assumes the input images are a rectangular grid of equal sized images
138
+ rows = len(img_list)
139
+ cols = len(img_list[0])
140
+
141
+ w, h = img_list[0][0].size
142
+ grid = Image.new('RGB', size=(cols*w, rows*h))
143
+
144
+ for i, row in enumerate(img_list):
145
+ for j, img in enumerate(row):
146
+ grid.paste(img, box=(j*w, i*h))
147
+ return grid
148
+
149
+ def lookup_score_from_fname(self,
150
+ fname: str,
151
+ metrics_output_file: str
152
+ ) -> float:
153
+ fname_basestr = os.path.splitext(fname)[0]
154
+ with open(metrics_output_file, 'r') as f:
155
+ for line in f:
156
+ if fname_basestr in line:
157
+ score = float(line.split('|')[5])
158
+ return score
159
+ raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
160
+
161
+ def gather_file_basenames(self, directory: str):
162
+ files = []
163
+ for file in os.listdir(directory):
164
+ if file.endswith(".png"):
165
+ files.append(file)
166
+ return files
167
+
168
+ def read_file_prompt(self, fname:str) -> str:
169
+ # Read prompt from image file metadata
170
+ img = Image.open(fname)
171
+ img.load()
172
+ return img.info['prompt']
173
+
174
+ def find_file_match(self, baseline_file: str, file_paths: List[str]):
175
+ # Find a file in file_paths with matching metadata to baseline_file
176
+ baseline_prompt = self.read_file_prompt(baseline_file)
177
+
178
+ # Do not match empty prompts
179
+ if baseline_prompt is None or baseline_prompt == "":
180
+ return None
181
+
182
+ # Find file match
183
+ # Reorder test_file_names so that the file with matching name is first
184
+ # This is an optimization because matching file names are more likely
185
+ # to have matching metadata if they were generated with the same script
186
+ basename = os.path.basename(baseline_file)
187
+ file_path_basenames = [os.path.basename(f) for f in file_paths]
188
+ if basename in file_path_basenames:
189
+ match_index = file_path_basenames.index(basename)
190
+ file_paths.insert(0, file_paths.pop(match_index))
191
+
192
+ for f in file_paths:
193
+ test_file_prompt = self.read_file_prompt(f)
194
+ if baseline_prompt == test_file_prompt:
195
+ return f
ComfyUI/tests/conftest.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+
4
+ # Command line arguments for pytest
5
+ def pytest_addoption(parser):
6
+ parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
7
+ parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
8
+ parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
9
+
10
+ # This initializes args at the beginning of the test session
11
+ @pytest.fixture(scope="session", autouse=True)
12
+ def args_pytest(pytestconfig):
13
+ args = {}
14
+ args['output_dir'] = pytestconfig.getoption('output_dir')
15
+ args['listen'] = pytestconfig.getoption('listen')
16
+ args['port'] = pytestconfig.getoption('port')
17
+
18
+ os.makedirs(args['output_dir'], exist_ok=True)
19
+
20
+ return args
21
+
22
+ def pytest_collection_modifyitems(items):
23
+ # Modifies items so tests run in the correct order
24
+
25
+ LAST_TESTS = ['test_quality']
26
+
27
+ # Move the last items to the end
28
+ last_items = []
29
+ for test_name in LAST_TESTS:
30
+ for item in items.copy():
31
+ print(item.module.__name__, item)
32
+ if item.module.__name__ == test_name:
33
+ last_items.append(item)
34
+ items.remove(item)
35
+
36
+ items.extend(last_items)
ComfyUI/tests/inference/__init__.py ADDED
File without changes
ComfyUI/tests/inference/graphs/default_graph_sdxl1_0.json ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "4": {
3
+ "inputs": {
4
+ "ckpt_name": "sd_xl_base_1.0.safetensors"
5
+ },
6
+ "class_type": "CheckpointLoaderSimple"
7
+ },
8
+ "5": {
9
+ "inputs": {
10
+ "width": 1024,
11
+ "height": 1024,
12
+ "batch_size": 1
13
+ },
14
+ "class_type": "EmptyLatentImage"
15
+ },
16
+ "6": {
17
+ "inputs": {
18
+ "text": "a photo of a cat",
19
+ "clip": [
20
+ "4",
21
+ 1
22
+ ]
23
+ },
24
+ "class_type": "CLIPTextEncode"
25
+ },
26
+ "10": {
27
+ "inputs": {
28
+ "add_noise": "enable",
29
+ "noise_seed": 42,
30
+ "steps": 20,
31
+ "cfg": 7.5,
32
+ "sampler_name": "euler",
33
+ "scheduler": "normal",
34
+ "start_at_step": 0,
35
+ "end_at_step": 32,
36
+ "return_with_leftover_noise": "enable",
37
+ "model": [
38
+ "4",
39
+ 0
40
+ ],
41
+ "positive": [
42
+ "6",
43
+ 0
44
+ ],
45
+ "negative": [
46
+ "15",
47
+ 0
48
+ ],
49
+ "latent_image": [
50
+ "5",
51
+ 0
52
+ ]
53
+ },
54
+ "class_type": "KSamplerAdvanced"
55
+ },
56
+ "12": {
57
+ "inputs": {
58
+ "samples": [
59
+ "14",
60
+ 0
61
+ ],
62
+ "vae": [
63
+ "4",
64
+ 2
65
+ ]
66
+ },
67
+ "class_type": "VAEDecode"
68
+ },
69
+ "13": {
70
+ "inputs": {
71
+ "filename_prefix": "test_inference",
72
+ "images": [
73
+ "12",
74
+ 0
75
+ ]
76
+ },
77
+ "class_type": "SaveImage"
78
+ },
79
+ "14": {
80
+ "inputs": {
81
+ "add_noise": "disable",
82
+ "noise_seed": 42,
83
+ "steps": 20,
84
+ "cfg": 7.5,
85
+ "sampler_name": "euler",
86
+ "scheduler": "normal",
87
+ "start_at_step": 32,
88
+ "end_at_step": 10000,
89
+ "return_with_leftover_noise": "disable",
90
+ "model": [
91
+ "16",
92
+ 0
93
+ ],
94
+ "positive": [
95
+ "17",
96
+ 0
97
+ ],
98
+ "negative": [
99
+ "20",
100
+ 0
101
+ ],
102
+ "latent_image": [
103
+ "10",
104
+ 0
105
+ ]
106
+ },
107
+ "class_type": "KSamplerAdvanced"
108
+ },
109
+ "15": {
110
+ "inputs": {
111
+ "conditioning": [
112
+ "6",
113
+ 0
114
+ ]
115
+ },
116
+ "class_type": "ConditioningZeroOut"
117
+ },
118
+ "16": {
119
+ "inputs": {
120
+ "ckpt_name": "sd_xl_refiner_1.0.safetensors"
121
+ },
122
+ "class_type": "CheckpointLoaderSimple"
123
+ },
124
+ "17": {
125
+ "inputs": {
126
+ "text": "a photo of a cat",
127
+ "clip": [
128
+ "16",
129
+ 1
130
+ ]
131
+ },
132
+ "class_type": "CLIPTextEncode"
133
+ },
134
+ "20": {
135
+ "inputs": {
136
+ "text": "",
137
+ "clip": [
138
+ "16",
139
+ 1
140
+ ]
141
+ },
142
+ "class_type": "CLIPTextEncode"
143
+ }
144
+ }
ComfyUI/tests/inference/test_inference.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from io import BytesIO
3
+ from urllib import request
4
+ import numpy
5
+ import os
6
+ from PIL import Image
7
+ import pytest
8
+ from pytest import fixture
9
+ import time
10
+ import torch
11
+ from typing import Union
12
+ import json
13
+ import subprocess
14
+ import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
15
+ import uuid
16
+ import urllib.request
17
+ import urllib.parse
18
+
19
+
20
+ from comfy.samplers import KSampler
21
+
22
+ """
23
+ These tests generate and save images through a range of parameters
24
+ """
25
+
26
+ class ComfyGraph:
27
+ def __init__(self,
28
+ graph: dict,
29
+ sampler_nodes: list[str],
30
+ ):
31
+ self.graph = graph
32
+ self.sampler_nodes = sampler_nodes
33
+
34
+ def set_prompt(self, prompt, negative_prompt=None):
35
+ # Sets the prompt for the sampler nodes (eg. base and refiner)
36
+ for node in self.sampler_nodes:
37
+ prompt_node = self.graph[node]['inputs']['positive'][0]
38
+ self.graph[prompt_node]['inputs']['text'] = prompt
39
+ if negative_prompt:
40
+ negative_prompt_node = self.graph[node]['inputs']['negative'][0]
41
+ self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
42
+
43
+ def set_sampler_name(self, sampler_name:str, ):
44
+ # sets the sampler name for the sampler nodes (eg. base and refiner)
45
+ for node in self.sampler_nodes:
46
+ self.graph[node]['inputs']['sampler_name'] = sampler_name
47
+
48
+ def set_scheduler(self, scheduler:str):
49
+ # sets the sampler name for the sampler nodes (eg. base and refiner)
50
+ for node in self.sampler_nodes:
51
+ self.graph[node]['inputs']['scheduler'] = scheduler
52
+
53
+ def set_filename_prefix(self, prefix:str):
54
+ # sets the filename prefix for the save nodes
55
+ for node in self.graph:
56
+ if self.graph[node]['class_type'] == 'SaveImage':
57
+ self.graph[node]['inputs']['filename_prefix'] = prefix
58
+
59
+
60
+ class ComfyClient:
61
+ # From examples/websockets_api_example.py
62
+
63
+ def connect(self,
64
+ listen:str = '127.0.0.1',
65
+ port:Union[str,int] = 8188,
66
+ client_id: str = str(uuid.uuid4())
67
+ ):
68
+ self.client_id = client_id
69
+ self.server_address = f"{listen}:{port}"
70
+ ws = websocket.WebSocket()
71
+ ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
72
+ self.ws = ws
73
+
74
+ def queue_prompt(self, prompt):
75
+ p = {"prompt": prompt, "client_id": self.client_id}
76
+ data = json.dumps(p).encode('utf-8')
77
+ req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
78
+ return json.loads(urllib.request.urlopen(req).read())
79
+
80
+ def get_image(self, filename, subfolder, folder_type):
81
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
82
+ url_values = urllib.parse.urlencode(data)
83
+ with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
84
+ return response.read()
85
+
86
+ def get_history(self, prompt_id):
87
+ with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
88
+ return json.loads(response.read())
89
+
90
+ def get_images(self, graph, save=True):
91
+ prompt = graph
92
+ if not save:
93
+ # Replace save nodes with preview nodes
94
+ prompt_str = json.dumps(prompt)
95
+ prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
96
+ prompt = json.loads(prompt_str)
97
+
98
+ prompt_id = self.queue_prompt(prompt)['prompt_id']
99
+ output_images = {}
100
+ while True:
101
+ out = self.ws.recv()
102
+ if isinstance(out, str):
103
+ message = json.loads(out)
104
+ if message['type'] == 'executing':
105
+ data = message['data']
106
+ if data['node'] is None and data['prompt_id'] == prompt_id:
107
+ break #Execution is done
108
+ else:
109
+ continue #previews are binary data
110
+
111
+ history = self.get_history(prompt_id)[prompt_id]
112
+ for o in history['outputs']:
113
+ for node_id in history['outputs']:
114
+ node_output = history['outputs'][node_id]
115
+ if 'images' in node_output:
116
+ images_output = []
117
+ for image in node_output['images']:
118
+ image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
119
+ images_output.append(image_data)
120
+ output_images[node_id] = images_output
121
+
122
+ return output_images
123
+
124
+ #
125
+ # Initialize graphs
126
+ #
127
+ default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
128
+ with open(default_graph_file, 'r') as file:
129
+ default_graph = json.loads(file.read())
130
+ DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
131
+ DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
132
+
133
+ #
134
+ # Loop through these variables
135
+ #
136
+ comfy_graph_list = [DEFAULT_COMFY_GRAPH]
137
+ comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
138
+ prompt_list = [
139
+ 'a painting of a cat',
140
+ ]
141
+
142
+ sampler_list = KSampler.SAMPLERS
143
+ scheduler_list = KSampler.SCHEDULERS
144
+
145
+ @pytest.mark.inference
146
+ @pytest.mark.parametrize("sampler", sampler_list)
147
+ @pytest.mark.parametrize("scheduler", scheduler_list)
148
+ @pytest.mark.parametrize("prompt", prompt_list)
149
+ class TestInference:
150
+ #
151
+ # Initialize server and client
152
+ #
153
+ @fixture(scope="class", autouse=True)
154
+ def _server(self, args_pytest):
155
+ # Start server
156
+ p = subprocess.Popen([
157
+ 'python','main.py',
158
+ '--output-directory', args_pytest["output_dir"],
159
+ '--listen', args_pytest["listen"],
160
+ '--port', str(args_pytest["port"]),
161
+ ])
162
+ yield
163
+ p.kill()
164
+ torch.cuda.empty_cache()
165
+
166
+ def start_client(self, listen:str, port:int):
167
+ # Start client
168
+ comfy_client = ComfyClient()
169
+ # Connect to server (with retries)
170
+ n_tries = 5
171
+ for i in range(n_tries):
172
+ time.sleep(4)
173
+ try:
174
+ comfy_client.connect(listen=listen, port=port)
175
+ except ConnectionRefusedError as e:
176
+ print(e)
177
+ print(f"({i+1}/{n_tries}) Retrying...")
178
+ else:
179
+ break
180
+ return comfy_client
181
+
182
+ #
183
+ # Client and graph fixtures with server warmup
184
+ #
185
+ # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
186
+ # The "graph" is the default graph
187
+ @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
188
+ def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
189
+ comfy_graph = request.param
190
+
191
+ # Start client
192
+ comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
193
+
194
+ # Warm up pipeline
195
+ comfy_client.get_images(graph=comfy_graph.graph, save=False)
196
+
197
+ yield comfy_client, comfy_graph
198
+ del comfy_client
199
+ del comfy_graph
200
+ torch.cuda.empty_cache()
201
+
202
+ @fixture
203
+ def client(self, _client_graph):
204
+ client = _client_graph[0]
205
+ yield client
206
+
207
+ @fixture
208
+ def comfy_graph(self, _client_graph):
209
+ # avoid mutating the graph
210
+ graph = deepcopy(_client_graph[1])
211
+ yield graph
212
+
213
+ def test_comfy(
214
+ self,
215
+ client,
216
+ comfy_graph,
217
+ sampler,
218
+ scheduler,
219
+ prompt,
220
+ request
221
+ ):
222
+ test_info = request.node.name
223
+ comfy_graph.set_filename_prefix(test_info)
224
+ # Settings for comfy graph
225
+ comfy_graph.set_sampler_name(sampler)
226
+ comfy_graph.set_scheduler(scheduler)
227
+ comfy_graph.set_prompt(prompt)
228
+
229
+ # Generate
230
+ images = client.get_images(comfy_graph.graph)
231
+
232
+ assert len(images) != 0, "No images generated"
233
+ # assert all images are not blank
234
+ for images_output in images.values():
235
+ for image_data in images_output:
236
+ pil_image = Image.open(BytesIO(image_data))
237
+ assert numpy.array(pil_image).any() != 0, "Image is blank"
238
+
239
+