jnkr36 committed on
Commit
c1ac2f2
1 Parent(s): 9cd0422

Upload 7 files

Files changed (7)
  1. execution.py +371 -0
  2. extra_model_paths.yaml.example +23 -0
  3. folder_paths.py +69 -0
  4. main.py +162 -0
  5. nodes.py +1115 -0
  6. requirements.txt +11 -0
  7. server.py +294 -0
execution.py ADDED
@@ -0,0 +1,371 @@
+ import os
+ import sys
+ import copy
+ import json
+ import threading
+ import heapq
+ import traceback
+ import gc
+
+ import torch
+ import nodes
+
+ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
+     valid_inputs = class_def.INPUT_TYPES()
+     input_data_all = {}
+     for x in inputs:
+         input_data = inputs[x]
+         if isinstance(input_data, list):
+             input_unique_id = input_data[0]
+             output_index = input_data[1]
+             if input_unique_id not in outputs:
+                 return None
+             obj = outputs[input_unique_id][output_index]
+             input_data_all[x] = obj
+         else:
+             if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
+                 input_data_all[x] = input_data
+
+     if "hidden" in valid_inputs:
+         h = valid_inputs["hidden"]
+         for x in h:
+             if h[x] == "PROMPT":
+                 input_data_all[x] = prompt
+             if h[x] == "EXTRA_PNGINFO":
+                 if "extra_pnginfo" in extra_data:
+                     input_data_all[x] = extra_data['extra_pnginfo']
+             if h[x] == "UNIQUE_ID":
+                 input_data_all[x] = unique_id
+     return input_data_all
+
+ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
+     unique_id = current_item
+     inputs = prompt[unique_id]['inputs']
+     class_type = prompt[unique_id]['class_type']
+     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+     if unique_id in outputs:
+         return []
+
+     executed = []
+
+     for x in inputs:
+         input_data = inputs[x]
+
+         if isinstance(input_data, list):
+             input_unique_id = input_data[0]
+             output_index = input_data[1]
+             if input_unique_id not in outputs:
+                 executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
+
+     input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
+     if server.client_id is not None:
+         server.last_node_id = unique_id
+         server.send_sync("executing", { "node": unique_id }, server.client_id)
+     obj = class_def()
+
+     nodes.before_node_execution()
+     outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
+     if "ui" in outputs[unique_id]:
+         if server.client_id is not None:
+             server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
+         if "result" in outputs[unique_id]:
+             outputs[unique_id] = outputs[unique_id]["result"]
+     return executed + [unique_id]
+
+ def recursive_will_execute(prompt, outputs, current_item):
+     unique_id = current_item
+     inputs = prompt[unique_id]['inputs']
+     will_execute = []
+     if unique_id in outputs:
+         return []
+
+     for x in inputs:
+         input_data = inputs[x]
+         if isinstance(input_data, list):
+             input_unique_id = input_data[0]
+             output_index = input_data[1]
+             if input_unique_id not in outputs:
+                 will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
+
+     return will_execute + [unique_id]
+
+ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
+     unique_id = current_item
+     inputs = prompt[unique_id]['inputs']
+     class_type = prompt[unique_id]['class_type']
+     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+
+     is_changed_old = ''
+     is_changed = ''
+     if hasattr(class_def, 'IS_CHANGED'):
+         if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
+             is_changed_old = old_prompt[unique_id]['is_changed']
+         if 'is_changed' not in prompt[unique_id]:
+             input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
+             if input_data_all is not None:
+                 is_changed = class_def.IS_CHANGED(**input_data_all)
+                 prompt[unique_id]['is_changed'] = is_changed
+         else:
+             is_changed = prompt[unique_id]['is_changed']
+
+     if unique_id not in outputs:
+         return True
+
+     to_delete = False
+     if is_changed != is_changed_old:
+         to_delete = True
+     elif unique_id not in old_prompt:
+         to_delete = True
+     elif inputs == old_prompt[unique_id]['inputs']:
+         for x in inputs:
+             input_data = inputs[x]
+
+             if isinstance(input_data, list):
+                 input_unique_id = input_data[0]
+                 output_index = input_data[1]
+                 if input_unique_id in outputs:
+                     to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
+                 else:
+                     to_delete = True
+                 if to_delete:
+                     break
+     else:
+         to_delete = True
+
+     if to_delete:
+         d = outputs.pop(unique_id)
+         del d
+     return to_delete
+
+ class PromptExecutor:
+     def __init__(self, server):
+         self.outputs = {}
+         self.old_prompt = {}
+         self.server = server
+
+     def execute(self, prompt, extra_data={}):
+         nodes.interrupt_processing(False)
+
+         if "client_id" in extra_data:
+             self.server.client_id = extra_data["client_id"]
+         else:
+             self.server.client_id = None
+
+         with torch.inference_mode():
+             for x in prompt:
+                 recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
+
+             current_outputs = set(self.outputs.keys())
+             executed = []
+             try:
+                 to_execute = []
+                 for x in prompt:
+                     class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
+                     if hasattr(class_, 'OUTPUT_NODE'):
+                         to_execute += [(0, x)]
+
+                 while len(to_execute) > 0:
+                     #always execute the output that depends on the fewest unexecuted nodes first
+                     to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
+                     x = to_execute.pop(0)[-1]
+
+                     class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
+                     if hasattr(class_, 'OUTPUT_NODE'):
+                         if class_.OUTPUT_NODE == True:
+                             valid = False
+                             try:
+                                 m = validate_inputs(prompt, x)
+                                 valid = m[0]
+                             except:
+                                 valid = False
+                             if valid:
+                                 executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
+             except Exception as e:
+                 print(traceback.format_exc())
+                 to_delete = []
+                 for o in self.outputs:
+                     if o not in current_outputs:
+                         to_delete += [o]
+                         if o in self.old_prompt:
+                             d = self.old_prompt.pop(o)
+                             del d
+                 for o in to_delete:
+                     d = self.outputs.pop(o)
+                     del d
+             else:
+                 executed = set(executed)
+                 for x in executed:
+                     self.old_prompt[x] = copy.deepcopy(prompt[x])
+             finally:
+                 self.server.last_node_id = None
+                 if self.server.client_id is not None:
+                     self.server.send_sync("executing", { "node": None }, self.server.client_id)
+
+         gc.collect()
+         if torch.cuda.is_available():
+             if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+                 torch.cuda.empty_cache()
+                 torch.cuda.ipc_collect()
+
+
+ def validate_inputs(prompt, item):
+     unique_id = item
+     inputs = prompt[unique_id]['inputs']
+     class_type = prompt[unique_id]['class_type']
+     obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
+
+     class_inputs = obj_class.INPUT_TYPES()
+     required_inputs = class_inputs['required']
+     for x in required_inputs:
+         if x not in inputs:
+             return (False, "Required input is missing. {}, {}".format(class_type, x))
+         val = inputs[x]
+         info = required_inputs[x]
+         type_input = info[0]
+         if isinstance(val, list):
+             if len(val) != 2:
+                 return (False, "Bad Input. {}, {}".format(class_type, x))
+             o_id = val[0]
+             o_class_type = prompt[o_id]['class_type']
+             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
+             if r[val[1]] != type_input:
+                 return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
+             r = validate_inputs(prompt, o_id)
+             if r[0] == False:
+                 return r
+         else:
+             if type_input == "INT":
+                 val = int(val)
+                 inputs[x] = val
+             if type_input == "FLOAT":
+                 val = float(val)
+                 inputs[x] = val
+             if type_input == "STRING":
+                 val = str(val)
+                 inputs[x] = val
+
+             if len(info) > 1:
+                 if "min" in info[1] and val < info[1]["min"]:
+                     return (False, "Value smaller than min. {}, {}".format(class_type, x))
+                 if "max" in info[1] and val > info[1]["max"]:
+                     return (False, "Value bigger than max. {}, {}".format(class_type, x))
+
+             if isinstance(type_input, list):
+                 if val not in type_input:
+                     return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
+     return (True, "")
+
+ def validate_prompt(prompt):
+     outputs = set()
+     for x in prompt:
+         class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
+         if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
+             outputs.add(x)
+
+     if len(outputs) == 0:
+         return (False, "Prompt has no outputs")
+
+     good_outputs = set()
+     errors = []
+     for o in outputs:
+         valid = False
+         reason = ""
+         try:
+             m = validate_inputs(prompt, o)
+             valid = m[0]
+             reason = m[1]
+         except:
+             valid = False
+             reason = "Parsing error"
+
+         if valid == True:
+             good_outputs.add(o)
+         else:
+             print("Failed to validate prompt for output {} {}".format(o, reason))
+             print("output will be ignored")
+             errors += [(o, reason)]
+
+     if len(good_outputs) == 0:
+         errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
+         return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
+
+     return (True, "")
+
+
+ class PromptQueue:
+     def __init__(self, server):
+         self.server = server
+         self.mutex = threading.RLock()
+         self.not_empty = threading.Condition(self.mutex)
+         self.task_counter = 0
+         self.queue = []
+         self.currently_running = {}
+         self.history = {}
+         server.prompt_queue = self
+
+     def put(self, item):
+         with self.mutex:
+             heapq.heappush(self.queue, item)
+             self.server.queue_updated()
+             self.not_empty.notify()
+
+     def get(self):
+         with self.not_empty:
+             while len(self.queue) == 0:
+                 self.not_empty.wait()
+             item = heapq.heappop(self.queue)
+             i = self.task_counter
+             self.currently_running[i] = copy.deepcopy(item)
+             self.task_counter += 1
+             self.server.queue_updated()
+             return (item, i)
+
+     def task_done(self, item_id, outputs):
+         with self.mutex:
+             prompt = self.currently_running.pop(item_id)
+             self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
+             for o in outputs:
+                 if "ui" in outputs[o]:
+                     self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
+             self.server.queue_updated()
+
+     def get_current_queue(self):
+         with self.mutex:
+             out = []
+             for x in self.currently_running.values():
+                 out += [x]
+             return (out, copy.deepcopy(self.queue))
+
+     def get_tasks_remaining(self):
+         with self.mutex:
+             return len(self.queue) + len(self.currently_running)
+
+     def wipe_queue(self):
+         with self.mutex:
+             self.queue = []
+             self.server.queue_updated()
+
+     def delete_queue_item(self, function):
+         with self.mutex:
+             for x in range(len(self.queue)):
+                 if function(self.queue[x]):
+                     if len(self.queue) == 1:
+                         self.wipe_queue()
+                     else:
+                         self.queue.pop(x)
+                         heapq.heapify(self.queue)
+                     self.server.queue_updated()
+                     return True
+         return False
+
+     def get_history(self):
+         with self.mutex:
+             return copy.deepcopy(self.history)
+
+     def wipe_history(self):
+         with self.mutex:
+             self.history = {}
+
+     def delete_history_item(self, id_to_delete):
+         with self.mutex:
+             self.history.pop(id_to_delete, None)
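
For reference, a minimal sketch (illustrative, not part of the commit) of the prompt graph these functions consume: keys are node ids, each value names a class from NODE_CLASS_MAPPINGS, and an input is either a literal or a [source_node_id, output_index] link that validate_inputs() checks against the source node's RETURN_TYPES. The ids and checkpoint filename are placeholders, and a complete prompt would also need an OUTPUT_NODE such as SaveImage:

    example_prompt = {
        "1": {"class_type": "CheckpointLoaderSimple",        # RETURN_TYPES = ("MODEL", "CLIP", "VAE")
              "inputs": {"ckpt_name": "model.safetensors"}},  # placeholder filename
        "2": {"class_type": "CLIPTextEncode",
              "inputs": {"text": "a photo of a cat",
                         "clip": ["1", 1]}},                  # link: output 1 ("CLIP") of node "1"
    }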
extra_model_paths.yaml.example ADDED
@@ -0,0 +1,23 @@
+ #Rename this to extra_model_paths.yaml and ComfyUI will load it
+
+ #config for a1111 ui
+ #all you have to do is change the base_path to where yours is installed
+ a111:
+     base_path: path/to/stable-diffusion-webui/
+
+     checkpoints: models/Stable-diffusion
+     configs: models/Stable-diffusion
+     vae: models/VAE
+     loras: models/Lora
+     upscale_models: |
+         models/ESRGAN
+         models/SwinIR
+     embeddings: embeddings
+     controlnet: models/ControlNet
+
+ #other_ui:
+ #    base_path: path/to/ui
+ #    checkpoints: models/checkpoints
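
A note on the | block scalar above: yaml.safe_load returns it as a single newline-separated string, and load_extra_path_config() in main.py (below) splits it into one search path per line, so upscale_models registers both directories under base_path. A minimal sketch, assuming PyYAML is installed and the file has been renamed as instructed:

    import yaml
    with open("extra_model_paths.yaml") as f:
        conf = yaml.safe_load(f)["a111"]
    print(conf["upscale_models"].split("\n"))  # ['models/ESRGAN', 'models/SwinIR', '']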
folder_paths.py ADDED
@@ -0,0 +1,69 @@
+ import os
+
+ supported_ckpt_extensions = set(['.ckpt', '.pth'])
+ supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
+ try:
+     import safetensors.torch
+     supported_ckpt_extensions.add('.safetensors')
+     supported_pt_extensions.add('.safetensors')
+ except:
+     print("Could not import safetensors, safetensors support disabled.")
+
+
+ folder_names_and_paths = {}
+
+
+ models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+ folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
+ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
+
+ folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
+ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
+ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
+ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
+ folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
+ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
+
+ folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
+ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
+
+
+ def add_model_folder_path(folder_name, full_folder_path):
+     global folder_names_and_paths
+     if folder_name in folder_names_and_paths:
+         folder_names_and_paths[folder_name][0].append(full_folder_path)
+
+ def get_folder_paths(folder_name):
+     return folder_names_and_paths[folder_name][0][:]
+
+ def recursive_search(directory):
+     result = []
+     for root, subdir, file in os.walk(directory, followlinks=True):
+         for filepath in file:
+             #we os.path.join directory with a blank string to generate a path separator at the end.
+             result.append(os.path.join(root, filepath).replace(os.path.join(directory, ''), ''))
+     return result
+
+ def filter_files_extensions(files, extensions):
+     return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
+
+
+ def get_full_path(folder_name, filename):
+     global folder_names_and_paths
+     folders = folder_names_and_paths[folder_name]
+     for x in folders[0]:
+         full_path = os.path.join(x, filename)
+         if os.path.isfile(full_path):
+             return full_path
+
+
+ def get_filename_list(folder_name):
+     global folder_names_and_paths
+     output_list = set()
+     folders = folder_names_and_paths[folder_name]
+     for x in folders[0]:
+         output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
+     return sorted(list(output_list))
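
A quick usage sketch of this module's API (illustrative only; "loras" is a real folder name from the table above, while the extra directory and filename are placeholders):

    import folder_paths

    folder_paths.add_model_folder_path("loras", "/mnt/shared/loras")  # register an extra search directory
    print(folder_paths.get_filename_list("loras"))           # all supported files found across every "loras" dir
    print(folder_paths.get_full_path("loras", "style.pt"))   # first existing match, or None if absent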
main.py ADDED
@@ -0,0 +1,162 @@
+ import os
+ import sys
+ import shutil
+
+ import threading
+ import asyncio
+
+ if os.name == "nt":
+     import logging
+     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
+ if __name__ == "__main__":
+     if '--help' in sys.argv:
+         print()
+         print("Valid Command line Arguments:")
+         print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
+         print("\t--port 8188\t\t\tSet the listen port.")
+         print()
+         print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
+         print()
+         print()
+         print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
+         print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
+         print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
+         print("\t--disable-xformers\t\tdisables xformers")
+         print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
+         print()
+         print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
+         print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
+         print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
+         print("\t--novram\t\t\tWhen lowvram isn't enough.")
+         print()
+         print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
+         exit()
+
+     if '--dont-upcast-attention' in sys.argv:
+         print("disabling upcasting of attention")
+         os.environ['ATTN_PRECISION'] = "fp16"
+
+     try:
+         index = sys.argv.index('--cuda-device')
+         device = sys.argv[index + 1]
+         os.environ['CUDA_VISIBLE_DEVICES'] = device
+         print("Set cuda device to:", device)
+     except:
+         pass
+
+ from nodes import init_custom_nodes
+ import execution
+ import server
+ import folder_paths
+ import yaml
+
+ def prompt_worker(q, server):
+     e = execution.PromptExecutor(server)
+     while True:
+         item, item_id = q.get()
+         e.execute(item[-2], item[-1])
+         q.task_done(item_id, e.outputs)
+
+ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
+     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
+
+ def hijack_progress(server):
+     from tqdm.auto import tqdm
+     orig_func = getattr(tqdm, "update")
+     def wrapped_func(*args, **kwargs):
+         pbar = args[0]
+         v = orig_func(*args, **kwargs)
+         server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
+         return v
+     setattr(tqdm, "update", wrapped_func)
+
+ def cleanup_temp():
+     temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+     if os.path.exists(temp_dir):
+         shutil.rmtree(temp_dir, ignore_errors=True)
+
+ def load_extra_path_config(yaml_path):
+     with open(yaml_path, 'r') as stream:
+         config = yaml.safe_load(stream)
+     for c in config:
+         conf = config[c]
+         if conf is None:
+             continue
+         base_path = None
+         if "base_path" in conf:
+             base_path = conf.pop("base_path")
+         for x in conf:
+             for y in conf[x].split("\n"):
+                 if len(y) == 0:
+                     continue
+                 full_path = y
+                 if base_path is not None:
+                     full_path = os.path.join(base_path, full_path)
+                 print("Adding extra search path", x, full_path)
+                 folder_paths.add_model_folder_path(x, full_path)
+
+ if __name__ == "__main__":
+     cleanup_temp()
+
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+     server = server.PromptServer(loop)
+     q = execution.PromptQueue(server)
+
+     init_custom_nodes()
+     server.add_routes()
+     hijack_progress(server)
+
+     threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
+     try:
+         address = '0.0.0.0'
+         p_index = sys.argv.index('--listen')
+         try:
+             ip = sys.argv[p_index + 1]
+             if ip[:2] != '--':
+                 address = ip
+         except:
+             pass
+     except:
+         address = '127.0.0.1'
+
+     dont_print = False
+     if '--dont-print-server' in sys.argv:
+         dont_print = True
+
+     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
+     if os.path.isfile(extra_model_paths_config_path):
+         load_extra_path_config(extra_model_paths_config_path)
+
+     if '--extra-model-paths-config' in sys.argv:
+         indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
+         for i in indices:
+             load_extra_path_config(sys.argv[i])
+
+     port = 8188
+     try:
+         p_index = sys.argv.index('--port')
+         port = int(sys.argv[p_index + 1])
+     except:
+         pass
+
+     if '--quick-test-for-ci' in sys.argv:
+         exit(0)
+
+     call_on_start = None
+     if "--windows-standalone-build" in sys.argv:
+         def startup_server(address, port):
+             import webbrowser
+             webbrowser.open("http://{}:{}".format(address, port))
+         call_on_start = startup_server
+
+     if os.name == "nt":
+         try:
+             loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
+         except KeyboardInterrupt:
+             pass
+     else:
+         loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
+
+     cleanup_temp()
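
Given the argument handling above, a typical launch would be `python main.py --listen 0.0.0.0 --port 8188` (example values): passing `--listen` without an ip also binds 0.0.0.0, omitting it binds 127.0.0.1, and `--extra-model-paths-config file.yaml` can be repeated to register additional model directories.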
nodes.py ADDED
@@ -0,0 +1,1115 @@
1
+ import torch
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+ import hashlib
7
+ import copy
8
+ import traceback
9
+
10
+ from PIL import Image
11
+ from PIL.PngImagePlugin import PngInfo
12
+ import numpy as np
13
+
14
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
15
+
16
+
17
+ import comfy.samplers
18
+ import comfy.sd
19
+ import comfy.utils
20
+
21
+ import comfy.clip_vision
22
+
23
+ import model_management
24
+ import importlib
25
+
26
+ import folder_paths
27
+
28
+ def before_node_execution():
29
+ model_management.throw_exception_if_processing_interrupted()
30
+
31
+ def interrupt_processing(value=True):
32
+ model_management.interrupt_current_processing(value)
33
+
34
+ MAX_RESOLUTION=8192
35
+
36
+ class CLIPTextEncode:
37
+ @classmethod
38
+ def INPUT_TYPES(s):
39
+ return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
40
+ RETURN_TYPES = ("CONDITIONING",)
41
+ FUNCTION = "encode"
42
+
43
+ CATEGORY = "conditioning"
44
+
45
+ def encode(self, clip, text):
46
+ return ([[clip.encode(text), {}]], )
47
+
48
+ class ConditioningCombine:
49
+ @classmethod
50
+ def INPUT_TYPES(s):
51
+ return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
52
+ RETURN_TYPES = ("CONDITIONING",)
53
+ FUNCTION = "combine"
54
+
55
+ CATEGORY = "conditioning"
56
+
57
+ def combine(self, conditioning_1, conditioning_2):
58
+ return (conditioning_1 + conditioning_2, )
59
+
60
+ class ConditioningSetArea:
61
+ @classmethod
62
+ def INPUT_TYPES(s):
63
+ return {"required": {"conditioning": ("CONDITIONING", ),
64
+ "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
65
+ "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
66
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
67
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
68
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
69
+ }}
70
+ RETURN_TYPES = ("CONDITIONING",)
71
+ FUNCTION = "append"
72
+
73
+ CATEGORY = "conditioning"
74
+
75
+ def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
76
+ c = []
77
+ for t in conditioning:
78
+ n = [t[0], t[1].copy()]
79
+ n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
80
+ n[1]['strength'] = strength
81
+ n[1]['min_sigma'] = min_sigma
82
+ n[1]['max_sigma'] = max_sigma
83
+ c.append(n)
84
+ return (c, )
85
+
86
+ class VAEDecode:
87
+ def __init__(self, device="cpu"):
88
+ self.device = device
89
+
90
+ @classmethod
91
+ def INPUT_TYPES(s):
92
+ return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
93
+ RETURN_TYPES = ("IMAGE",)
94
+ FUNCTION = "decode"
95
+
96
+ CATEGORY = "latent"
97
+
98
+ def decode(self, vae, samples):
99
+ return (vae.decode(samples["samples"]), )
100
+
101
+ class VAEDecodeTiled:
102
+ def __init__(self, device="cpu"):
103
+ self.device = device
104
+
105
+ @classmethod
106
+ def INPUT_TYPES(s):
107
+ return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
108
+ RETURN_TYPES = ("IMAGE",)
109
+ FUNCTION = "decode"
110
+
111
+ CATEGORY = "_for_testing"
112
+
113
+ def decode(self, vae, samples):
114
+ return (vae.decode_tiled(samples["samples"]), )
115
+
116
+ class VAEEncode:
117
+ def __init__(self, device="cpu"):
118
+ self.device = device
119
+
120
+ @classmethod
121
+ def INPUT_TYPES(s):
122
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
123
+ RETURN_TYPES = ("LATENT",)
124
+ FUNCTION = "encode"
125
+
126
+ CATEGORY = "latent"
127
+
128
+ def encode(self, vae, pixels):
129
+ x = (pixels.shape[1] // 64) * 64
130
+ y = (pixels.shape[2] // 64) * 64
131
+ if pixels.shape[1] != x or pixels.shape[2] != y:
132
+ pixels = pixels[:,:x,:y,:]
133
+ t = vae.encode(pixels[:,:,:,:3])
134
+
135
+ return ({"samples":t}, )
136
+
137
+
138
+ class VAEEncodeTiled:
139
+ def __init__(self, device="cpu"):
140
+ self.device = device
141
+
142
+ @classmethod
143
+ def INPUT_TYPES(s):
144
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
145
+ RETURN_TYPES = ("LATENT",)
146
+ FUNCTION = "encode"
147
+
148
+ CATEGORY = "_for_testing"
149
+
150
+ def encode(self, vae, pixels):
151
+ x = (pixels.shape[1] // 64) * 64
152
+ y = (pixels.shape[2] // 64) * 64
153
+ if pixels.shape[1] != x or pixels.shape[2] != y:
154
+ pixels = pixels[:,:x,:y,:]
155
+ t = vae.encode_tiled(pixels[:,:,:,:3])
156
+
157
+ return ({"samples":t}, )
158
+ class VAEEncodeForInpaint:
159
+ def __init__(self, device="cpu"):
160
+ self.device = device
161
+
162
+ @classmethod
163
+ def INPUT_TYPES(s):
164
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
165
+ RETURN_TYPES = ("LATENT",)
166
+ FUNCTION = "encode"
167
+
168
+ CATEGORY = "latent/inpaint"
169
+
170
+ def encode(self, vae, pixels, mask):
171
+ x = (pixels.shape[1] // 64) * 64
172
+ y = (pixels.shape[2] // 64) * 64
173
+ mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]
174
+
175
+ pixels = pixels.clone()
176
+ if pixels.shape[1] != x or pixels.shape[2] != y:
177
+ pixels = pixels[:,:x,:y,:]
178
+ mask = mask[:x,:y]
179
+
180
+ #grow mask by a few pixels to keep things seamless in latent space
181
+ kernel_tensor = torch.ones((1, 1, 6, 6))
182
+ mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1)
183
+ m = (1.0 - mask.round())
184
+ for i in range(3):
185
+ pixels[:,:,:,i] -= 0.5
186
+ pixels[:,:,:,i] *= m
187
+ pixels[:,:,:,i] += 0.5
188
+ t = vae.encode(pixels)
189
+
190
+ return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
191
+
192
+ class CheckpointLoader:
193
+ @classmethod
194
+ def INPUT_TYPES(s):
195
+ return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
196
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
197
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
198
+ FUNCTION = "load_checkpoint"
199
+
200
+ CATEGORY = "loaders"
201
+
202
+ def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
203
+ config_path = folder_paths.get_full_path("configs", config_name)
204
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
205
+ return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
206
+
207
+ class CheckpointLoaderSimple:
208
+ @classmethod
209
+ def INPUT_TYPES(s):
210
+ return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
211
+ }}
212
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
213
+ FUNCTION = "load_checkpoint"
214
+
215
+ CATEGORY = "loaders"
216
+
217
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
218
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
219
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
220
+ return out
221
+
222
+ class unCLIPCheckpointLoader:
223
+ @classmethod
224
+ def INPUT_TYPES(s):
225
+ return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
226
+ }}
227
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
228
+ FUNCTION = "load_checkpoint"
229
+
230
+ CATEGORY = "_for_testing/unclip"
231
+
232
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
233
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
234
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
235
+ return out
236
+
237
+ class CLIPSetLastLayer:
238
+ @classmethod
239
+ def INPUT_TYPES(s):
240
+ return {"required": { "clip": ("CLIP", ),
241
+ "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
242
+ }}
243
+ RETURN_TYPES = ("CLIP",)
244
+ FUNCTION = "set_last_layer"
245
+
246
+ CATEGORY = "conditioning"
247
+
248
+ def set_last_layer(self, clip, stop_at_clip_layer):
249
+ clip = clip.clone()
250
+ clip.clip_layer(stop_at_clip_layer)
251
+ return (clip,)
252
+
253
+ class LoraLoader:
254
+ @classmethod
255
+ def INPUT_TYPES(s):
256
+ return {"required": { "model": ("MODEL",),
257
+ "clip": ("CLIP", ),
258
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
259
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
260
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
261
+ }}
262
+ RETURN_TYPES = ("MODEL", "CLIP")
263
+ FUNCTION = "load_lora"
264
+
265
+ CATEGORY = "loaders"
266
+
267
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
268
+ lora_path = folder_paths.get_full_path("loras", lora_name)
269
+ model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
270
+ return (model_lora, clip_lora)
271
+
272
+ class TomePatchModel:
273
+ @classmethod
274
+ def INPUT_TYPES(s):
275
+ return {"required": { "model": ("MODEL",),
276
+ "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
277
+ }}
278
+ RETURN_TYPES = ("MODEL",)
279
+ FUNCTION = "patch"
280
+
281
+ CATEGORY = "_for_testing"
282
+
283
+ def patch(self, model, ratio):
284
+ m = model.clone()
285
+ m.set_model_tomesd(ratio)
286
+ return (m, )
287
+
288
+ class VAELoader:
289
+ @classmethod
290
+ def INPUT_TYPES(s):
291
+ return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
292
+ RETURN_TYPES = ("VAE",)
293
+ FUNCTION = "load_vae"
294
+
295
+ CATEGORY = "loaders"
296
+
297
+ #TODO: scale factor?
298
+ def load_vae(self, vae_name):
299
+ vae_path = folder_paths.get_full_path("vae", vae_name)
300
+ vae = comfy.sd.VAE(ckpt_path=vae_path)
301
+ return (vae,)
302
+
303
+ class ControlNetLoader:
304
+ @classmethod
305
+ def INPUT_TYPES(s):
306
+ return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
307
+
308
+ RETURN_TYPES = ("CONTROL_NET",)
309
+ FUNCTION = "load_controlnet"
310
+
311
+ CATEGORY = "loaders"
312
+
313
+ def load_controlnet(self, control_net_name):
314
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
315
+ controlnet = comfy.sd.load_controlnet(controlnet_path)
316
+ return (controlnet,)
317
+
318
+ class DiffControlNetLoader:
319
+ @classmethod
320
+ def INPUT_TYPES(s):
321
+ return {"required": { "model": ("MODEL",),
322
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
323
+
324
+ RETURN_TYPES = ("CONTROL_NET",)
325
+ FUNCTION = "load_controlnet"
326
+
327
+ CATEGORY = "loaders"
328
+
329
+ def load_controlnet(self, model, control_net_name):
330
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
331
+ controlnet = comfy.sd.load_controlnet(controlnet_path, model)
332
+ return (controlnet,)
333
+
334
+
335
+ class ControlNetApply:
336
+ @classmethod
337
+ def INPUT_TYPES(s):
338
+ return {"required": {"conditioning": ("CONDITIONING", ),
339
+ "control_net": ("CONTROL_NET", ),
340
+ "image": ("IMAGE", ),
341
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
342
+ }}
343
+ RETURN_TYPES = ("CONDITIONING",)
344
+ FUNCTION = "apply_controlnet"
345
+
346
+ CATEGORY = "conditioning"
347
+
348
+ def apply_controlnet(self, conditioning, control_net, image, strength):
349
+ c = []
350
+ control_hint = image.movedim(-1,1)
351
+ print(control_hint.shape)
352
+ for t in conditioning:
353
+ n = [t[0], t[1].copy()]
354
+ c_net = control_net.copy().set_cond_hint(control_hint, strength)
355
+ if 'control' in t[1]:
356
+ c_net.set_previous_controlnet(t[1]['control'])
357
+ n[1]['control'] = c_net
358
+ c.append(n)
359
+ return (c, )
360
+
361
+ class CLIPLoader:
362
+ @classmethod
363
+ def INPUT_TYPES(s):
364
+ return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
365
+ }}
366
+ RETURN_TYPES = ("CLIP",)
367
+ FUNCTION = "load_clip"
368
+
369
+ CATEGORY = "loaders"
370
+
371
+ def load_clip(self, clip_name):
372
+ clip_path = folder_paths.get_full_path("clip", clip_name)
373
+ clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
374
+ return (clip,)
375
+
376
+ class CLIPVisionLoader:
377
+ @classmethod
378
+ def INPUT_TYPES(s):
379
+ return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
380
+ }}
381
+ RETURN_TYPES = ("CLIP_VISION",)
382
+ FUNCTION = "load_clip"
383
+
384
+ CATEGORY = "loaders"
385
+
386
+ def load_clip(self, clip_name):
387
+ clip_path = folder_paths.get_full_path("clip_vision", clip_name)
388
+ clip_vision = comfy.clip_vision.load(clip_path)
389
+ return (clip_vision,)
390
+
391
+ class CLIPVisionEncode:
392
+ @classmethod
393
+ def INPUT_TYPES(s):
394
+ return {"required": { "clip_vision": ("CLIP_VISION",),
395
+ "image": ("IMAGE",)
396
+ }}
397
+ RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
398
+ FUNCTION = "encode"
399
+
400
+ CATEGORY = "conditioning"
401
+
402
+ def encode(self, clip_vision, image):
403
+ output = clip_vision.encode_image(image)
404
+ return (output,)
405
+
406
+ class StyleModelLoader:
407
+ @classmethod
408
+ def INPUT_TYPES(s):
409
+ return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
410
+
411
+ RETURN_TYPES = ("STYLE_MODEL",)
412
+ FUNCTION = "load_style_model"
413
+
414
+ CATEGORY = "loaders"
415
+
416
+ def load_style_model(self, style_model_name):
417
+ style_model_path = folder_paths.get_full_path("style_models", style_model_name)
418
+ style_model = comfy.sd.load_style_model(style_model_path)
419
+ return (style_model,)
420
+
421
+
422
+ class StyleModelApply:
423
+ @classmethod
424
+ def INPUT_TYPES(s):
425
+ return {"required": {"conditioning": ("CONDITIONING", ),
426
+ "style_model": ("STYLE_MODEL", ),
427
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
428
+ }}
429
+ RETURN_TYPES = ("CONDITIONING",)
430
+ FUNCTION = "apply_stylemodel"
431
+
432
+ CATEGORY = "conditioning/style_model"
433
+
434
+ def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
435
+ cond = style_model.get_cond(clip_vision_output)
436
+ c = []
437
+ for t in conditioning:
438
+ n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
439
+ c.append(n)
440
+ return (c, )
441
+
442
+ class unCLIPConditioning:
443
+ @classmethod
444
+ def INPUT_TYPES(s):
445
+ return {"required": {"conditioning": ("CONDITIONING", ),
446
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
447
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
448
+ }}
449
+ RETURN_TYPES = ("CONDITIONING",)
450
+ FUNCTION = "apply_adm"
451
+
452
+ CATEGORY = "_for_testing/unclip"
453
+
454
+ def apply_adm(self, conditioning, clip_vision_output, strength):
455
+ c = []
456
+ for t in conditioning:
457
+ o = t[1].copy()
458
+ x = (clip_vision_output, strength)
459
+ if "adm" in o:
460
+ o["adm"] = o["adm"][:] + [x]
461
+ else:
462
+ o["adm"] = [x]
463
+ n = [t[0], o]
464
+ c.append(n)
465
+ return (c, )
466
+
467
+
468
+ class EmptyLatentImage:
469
+ def __init__(self, device="cpu"):
470
+ self.device = device
471
+
472
+ @classmethod
473
+ def INPUT_TYPES(s):
474
+ return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
475
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
476
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
477
+ RETURN_TYPES = ("LATENT",)
478
+ FUNCTION = "generate"
479
+
480
+ CATEGORY = "latent"
481
+
482
+ def generate(self, width, height, batch_size=1):
483
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
484
+ return ({"samples":latent}, )
485
+
486
+
487
+
488
+ class LatentUpscale:
489
+ upscale_methods = ["nearest-exact", "bilinear", "area"]
490
+ crop_methods = ["disabled", "center"]
491
+
492
+ @classmethod
493
+ def INPUT_TYPES(s):
494
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
495
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
496
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
497
+ "crop": (s.crop_methods,)}}
498
+ RETURN_TYPES = ("LATENT",)
499
+ FUNCTION = "upscale"
500
+
501
+ CATEGORY = "latent"
502
+
503
+ def upscale(self, samples, upscale_method, width, height, crop):
504
+ s = samples.copy()
505
+ s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
506
+ return (s,)
507
+
508
+ class LatentRotate:
509
+ @classmethod
510
+ def INPUT_TYPES(s):
511
+ return {"required": { "samples": ("LATENT",),
512
+ "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
513
+ }}
514
+ RETURN_TYPES = ("LATENT",)
515
+ FUNCTION = "rotate"
516
+
517
+ CATEGORY = "latent/transform"
518
+
519
+ def rotate(self, samples, rotation):
520
+ s = samples.copy()
521
+ rotate_by = 0
522
+ if rotation.startswith("90"):
523
+ rotate_by = 1
524
+ elif rotation.startswith("180"):
525
+ rotate_by = 2
526
+ elif rotation.startswith("270"):
527
+ rotate_by = 3
528
+
529
+ s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
530
+ return (s,)
531
+
532
+ class LatentFlip:
533
+ @classmethod
534
+ def INPUT_TYPES(s):
535
+ return {"required": { "samples": ("LATENT",),
536
+ "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
537
+ }}
538
+ RETURN_TYPES = ("LATENT",)
539
+ FUNCTION = "flip"
540
+
541
+ CATEGORY = "latent/transform"
542
+
543
+ def flip(self, samples, flip_method):
544
+ s = samples.copy()
545
+ if flip_method.startswith("x"):
546
+ s["samples"] = torch.flip(samples["samples"], dims=[2])
547
+ elif flip_method.startswith("y"):
548
+ s["samples"] = torch.flip(samples["samples"], dims=[3])
549
+
550
+ return (s,)
551
+
552
+ class LatentComposite:
553
+ @classmethod
554
+ def INPUT_TYPES(s):
555
+ return {"required": { "samples_to": ("LATENT",),
556
+ "samples_from": ("LATENT",),
557
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
558
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
559
+ "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
560
+ }}
561
+ RETURN_TYPES = ("LATENT",)
562
+ FUNCTION = "composite"
563
+
564
+ CATEGORY = "latent"
565
+
566
+ def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
567
+ x = x // 8
568
+ y = y // 8
569
+ feather = feather // 8
570
+ samples_out = samples_to.copy()
571
+ s = samples_to["samples"].clone()
572
+ samples_to = samples_to["samples"]
573
+ samples_from = samples_from["samples"]
574
+ if feather == 0:
575
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
576
+ else:
577
+ samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
578
+ mask = torch.ones_like(samples_from)
579
+ for t in range(feather):
580
+ if y != 0:
581
+ mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
582
+
583
+ if y + samples_from.shape[2] < samples_to.shape[2]:
584
+ mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
585
+ if x != 0:
586
+ mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
587
+ if x + samples_from.shape[3] < samples_to.shape[3]:
588
+ mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
589
+ rev_mask = torch.ones_like(mask) - mask
590
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
591
+ samples_out["samples"] = s
592
+ return (samples_out,)
593
+
594
+ class LatentCrop:
595
+ @classmethod
596
+ def INPUT_TYPES(s):
597
+ return {"required": { "samples": ("LATENT",),
598
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
599
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
600
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
601
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
602
+ }}
603
+ RETURN_TYPES = ("LATENT",)
604
+ FUNCTION = "crop"
605
+
606
+ CATEGORY = "latent/transform"
607
+
608
+ def crop(self, samples, width, height, x, y):
609
+ s = samples.copy()
610
+ samples = samples['samples']
611
+ x = x // 8
612
+ y = y // 8
613
+
614
+ #enfonce minimum size of 64
615
+ if x > (samples.shape[3] - 8):
616
+ x = samples.shape[3] - 8
617
+ if y > (samples.shape[2] - 8):
618
+ y = samples.shape[2] - 8
619
+
620
+ new_height = height // 8
621
+ new_width = width // 8
622
+ to_x = new_width + x
623
+ to_y = new_height + y
624
+ def enforce_image_dim(d, to_d, max_d):
625
+ if to_d > max_d:
626
+ leftover = (to_d - max_d) % 8
627
+ to_d = max_d
628
+ d -= leftover
629
+ return (d, to_d)
630
+
631
+ #make sure size is always multiple of 64
632
+ x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
633
+ y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
634
+ s['samples'] = samples[:,:,y:to_y, x:to_x]
635
+ return (s,)
636
+
637
+ class SetLatentNoiseMask:
638
+ @classmethod
639
+ def INPUT_TYPES(s):
640
+ return {"required": { "samples": ("LATENT",),
641
+ "mask": ("MASK",),
642
+ }}
643
+ RETURN_TYPES = ("LATENT",)
644
+ FUNCTION = "set_mask"
645
+
646
+ CATEGORY = "latent/inpaint"
647
+
648
+ def set_mask(self, samples, mask):
649
+ s = samples.copy()
650
+ s["noise_mask"] = mask
651
+ return (s,)
652
+
653
+
654
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
655
+ latent_image = latent["samples"]
656
+ noise_mask = None
657
+ device = model_management.get_torch_device()
658
+
659
+ if disable_noise:
660
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
661
+ else:
662
+ noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
663
+
664
+ if "noise_mask" in latent:
665
+ noise_mask = latent['noise_mask']
666
+ noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
667
+ noise_mask = noise_mask.round()
668
+ noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
669
+ noise_mask = torch.cat([noise_mask] * noise.shape[0])
670
+ noise_mask = noise_mask.to(device)
671
+
672
+ real_model = None
673
+ model_management.load_model_gpu(model)
674
+ real_model = model.model
675
+
676
+ noise = noise.to(device)
677
+ latent_image = latent_image.to(device)
678
+
679
+ positive_copy = []
680
+ negative_copy = []
681
+
682
+ control_nets = []
683
+ for p in positive:
684
+ t = p[0]
685
+ if t.shape[0] < noise.shape[0]:
686
+ t = torch.cat([t] * noise.shape[0])
687
+ t = t.to(device)
688
+ if 'control' in p[1]:
689
+ control_nets += [p[1]['control']]
690
+ positive_copy += [[t] + p[1:]]
691
+ for n in negative:
692
+ t = n[0]
693
+ if t.shape[0] < noise.shape[0]:
694
+ t = torch.cat([t] * noise.shape[0])
695
+ t = t.to(device)
696
+ if 'control' in n[1]:
697
+ control_nets += [n[1]['control']]
698
+ negative_copy += [[t] + n[1:]]
699
+
700
+ control_net_models = []
701
+ for x in control_nets:
702
+ control_net_models += x.get_control_models()
703
+ model_management.load_controlnet_gpu(control_net_models)
704
+
705
+ if sampler_name in comfy.samplers.KSampler.SAMPLERS:
706
+ sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
707
+ else:
708
+ #other samplers
709
+ pass
710
+
711
+ samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
712
+ samples = samples.cpu()
713
+ for c in control_nets:
714
+ c.cleanup()
715
+
716
+ out = latent.copy()
717
+ out["samples"] = samples
718
+ return (out, )
719
+
720
+ class KSampler:
721
+ @classmethod
722
+ def INPUT_TYPES(s):
723
+ return {"required":
724
+ {"model": ("MODEL",),
725
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
726
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
727
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
728
+ "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
729
+ "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
730
+ "positive": ("CONDITIONING", ),
731
+ "negative": ("CONDITIONING", ),
732
+ "latent_image": ("LATENT", ),
733
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
734
+ }}
735
+
736
+ RETURN_TYPES = ("LATENT",)
737
+ FUNCTION = "sample"
738
+
739
+ CATEGORY = "sampling"
740
+
741
+ def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
742
+ return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
743
+
744
+ class KSamplerAdvanced:
745
+ @classmethod
746
+ def INPUT_TYPES(s):
747
+ return {"required":
748
+ {"model": ("MODEL",),
749
+ "add_noise": (["enable", "disable"], ),
750
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
751
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
752
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
753
+ "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
754
+ "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
755
+ "positive": ("CONDITIONING", ),
756
+ "negative": ("CONDITIONING", ),
757
+ "latent_image": ("LATENT", ),
758
+ "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
759
+ "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
760
+ "return_with_leftover_noise": (["disable", "enable"], ),
761
+ }}
762
+
763
+ RETURN_TYPES = ("LATENT",)
764
+ FUNCTION = "sample"
765
+
766
+ CATEGORY = "sampling"
767
+
768
+ def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
769
+ force_full_denoise = True
770
+ if return_with_leftover_noise == "enable":
771
+ force_full_denoise = False
772
+ disable_noise = False
773
+ if add_noise == "disable":
774
+ disable_noise = True
775
+ return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
776
+
777
+ class SaveImage:
778
+ def __init__(self):
779
+ self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
780
+ self.type = "output"
781
+
782
+ @classmethod
783
+ def INPUT_TYPES(s):
784
+ return {"required":
785
+ {"images": ("IMAGE", ),
786
+ "filename_prefix": ("STRING", {"default": "ComfyUI"})},
787
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
788
+ }
789
+
790
+ RETURN_TYPES = ()
791
+ FUNCTION = "save_images"
792
+
793
+ OUTPUT_NODE = True
794
+
795
+ CATEGORY = "image"
796
+
797
+ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
798
+ def map_filename(filename):
799
+ prefix_len = len(os.path.basename(filename_prefix))
800
+ prefix = filename[:prefix_len + 1]
801
+ try:
802
+ digits = int(filename[prefix_len + 1:].split('_')[0])
803
+ except:
804
+ digits = 0
805
+ return (digits, prefix)
806
+
807
+ def compute_vars(input):
808
+ input = input.replace("%width%", str(images[0].shape[1]))
809
+ input = input.replace("%height%", str(images[0].shape[0]))
810
+ return input
811
+
812
+ filename_prefix = compute_vars(filename_prefix)
813
+
814
+ subfolder = os.path.dirname(os.path.normpath(filename_prefix))
815
+ filename = os.path.basename(os.path.normpath(filename_prefix))
816
+
817
+ full_output_folder = os.path.join(self.output_dir, subfolder)
818
+
819
+ if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
820
+ print("Saving image outside the output folder is not allowed.")
821
+ return {}
822
+
823
+ try:
824
+ counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
825
+ except ValueError:
826
+ counter = 1
827
+ except FileNotFoundError:
828
+ os.makedirs(full_output_folder, exist_ok=True)
829
+ counter = 1
830
+
831
+ if not os.path.exists(self.output_dir):
832
+ os.makedirs(self.output_dir)
833
+
834
+ results = list()
835
+ for image in images:
836
+ i = 255. * image.cpu().numpy()
837
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
838
+ metadata = PngInfo()
839
+ if prompt is not None:
840
+ metadata.add_text("prompt", json.dumps(prompt))
841
+ if extra_pnginfo is not None:
842
+ for x in extra_pnginfo:
843
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
844
+
845
+ file = f"{filename}_{counter:05}_.png"
846
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
847
+ results.append({
848
+ "filename": file,
849
+ "subfolder": subfolder,
850
+ "type": self.type
851
+ });
852
+ counter += 1
853
+
854
+ return { "ui": { "images": results } }
855
+
856
+ class PreviewImage(SaveImage):
857
+ def __init__(self):
858
+ self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
859
+ self.type = "temp"
860
+
861
+ @classmethod
862
+ def INPUT_TYPES(s):
863
+ return {"required":
864
+ {"images": ("IMAGE", ), },
865
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
866
+ }
867
+
868
+ class LoadImage:
869
+ input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
870
+ @classmethod
871
+ def INPUT_TYPES(s):
872
+ if not os.path.exists(s.input_dir):
873
+ os.makedirs(s.input_dir)
874
+ return {"required":
875
+ {"image": (sorted(os.listdir(s.input_dir)), )},
876
+ }
877
+
878
+ CATEGORY = "image"
879
+
880
+ RETURN_TYPES = ("IMAGE", "MASK")
881
+ FUNCTION = "load_image"
882
+ def load_image(self, image):
883
+ image_path = os.path.join(self.input_dir, image)
884
+ i = Image.open(image_path)
885
+ image = i.convert("RGB")
886
+ image = np.array(image).astype(np.float32) / 255.0
887
+ image = torch.from_numpy(image)[None,]
888
+ if 'A' in i.getbands():
889
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
890
+ mask = 1. - torch.from_numpy(mask)
891
+ else:
892
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
893
+ return (image, mask)
894
+
895
+ @classmethod
896
+ def IS_CHANGED(s, image):
897
+ image_path = os.path.join(s.input_dir, image)
898
+ m = hashlib.sha256()
899
+ with open(image_path, 'rb') as f:
900
+ m.update(f.read())
901
+ return m.digest().hex()
902
+
903
+ class LoadImageMask:
+     input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+     @classmethod
+     def INPUT_TYPES(s):
+         return {"required":
+                     {"image": (sorted(os.listdir(s.input_dir)), ),
+                      "channel": (["alpha", "red", "green", "blue"], ),}
+                 }
+
+     CATEGORY = "image"
+
+     RETURN_TYPES = ("MASK",)
+     FUNCTION = "load_image"
+     def load_image(self, image, channel):
+         image_path = os.path.join(self.input_dir, image)
+         i = Image.open(image_path)
+         mask = None
+         c = channel[0].upper()
+         if c in i.getbands():
+             mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
+             mask = torch.from_numpy(mask)
+             if c == 'A':
+                 mask = 1. - mask
+         else:
+             mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
+         return (mask,)
+
+     @classmethod
+     def IS_CHANGED(s, image, channel):
+         image_path = os.path.join(s.input_dir, image)
+         m = hashlib.sha256()
+         with open(image_path, 'rb') as f:
+             m.update(f.read())
+         return m.digest().hex()
+
938
+ class ImageScale:
+     upscale_methods = ["nearest-exact", "bilinear", "area"]
+     crop_methods = ["disabled", "center"]
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
+                               "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
+                               "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
+                               "crop": (s.crop_methods,)}}
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "upscale"
+
+     CATEGORY = "image/upscaling"
+
+     def upscale(self, image, upscale_method, width, height, crop):
+         samples = image.movedim(-1,1)
+         s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
+         s = s.movedim(1,-1)
+         return (s,)
+
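+ # Layout note for upscale() above: IMAGE tensors flow through the graph as
+ # [batch, height, width, channels]; movedim(-1, 1) converts to the [B, C, H, W]
+ # layout comfy.utils.common_upscale expects, and the second movedim converts back.
+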
959
+ class ImageInvert:
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {"required": { "image": ("IMAGE",)}}
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "invert"
+
+     CATEGORY = "image"
+
+     def invert(self, image):
+         s = 1.0 - image
+         return (s,)
+
+
975
+ class ImagePadForOutpaint:
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {
+             "required": {
+                 "image": ("IMAGE",),
+                 "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+                 "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+                 "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+                 "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+                 "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+             }
+         }
+
+     RETURN_TYPES = ("IMAGE", "MASK")
+     FUNCTION = "expand_image"
+
+     CATEGORY = "image"
+
+     def expand_image(self, image, left, top, right, bottom, feathering):
+         d1, d2, d3, d4 = image.size()
+
+         new_image = torch.zeros(
+             (d1, d2 + top + bottom, d3 + left + right, d4),
+             dtype=torch.float32,
+         )
+         new_image[:, top:top + d2, left:left + d3, :] = image
+
+         mask = torch.ones(
+             (d2 + top + bottom, d3 + left + right),
+             dtype=torch.float32,
+         )
+
+         t = torch.zeros(
+             (d2, d3),
+             dtype=torch.float32
+         )
+
+         if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
+
+             for i in range(d2):
+                 for j in range(d3):
+                     dt = i if top != 0 else d2
+                     db = d2 - i if bottom != 0 else d2
+
+                     dl = j if left != 0 else d3
+                     dr = d3 - j if right != 0 else d3
+
+                     d = min(dt, db, dl, dr)
+
+                     if d >= feathering:
+                         continue
+
+                     v = (feathering - d) / feathering
+
+                     t[i, j] = v * v
+
+         mask[top:top + d2, left:left + d3] = t
+
+         return (new_image, mask)
+
+
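+ # Feathering above is a quadratic falloff: for a pixel at distance d from the
+ # nearest padded edge, the mask value is ((feathering - d) / feathering) ** 2,
+ # ramping from 1.0 at the seam down to 0.0 once d reaches `feathering`.
+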
1038
+ NODE_CLASS_MAPPINGS = {
+     "KSampler": KSampler,
+     "CheckpointLoader": CheckpointLoader,
+     "CheckpointLoaderSimple": CheckpointLoaderSimple,
+     "CLIPTextEncode": CLIPTextEncode,
+     "CLIPSetLastLayer": CLIPSetLastLayer,
+     "VAEDecode": VAEDecode,
+     "VAEEncode": VAEEncode,
+     "VAEEncodeForInpaint": VAEEncodeForInpaint,
+     "VAELoader": VAELoader,
+     "EmptyLatentImage": EmptyLatentImage,
+     "LatentUpscale": LatentUpscale,
+     "SaveImage": SaveImage,
+     "PreviewImage": PreviewImage,
+     "LoadImage": LoadImage,
+     "LoadImageMask": LoadImageMask,
+     "ImageScale": ImageScale,
+     "ImageInvert": ImageInvert,
+     "ImagePadForOutpaint": ImagePadForOutpaint,
+     "ConditioningCombine": ConditioningCombine,
+     "ConditioningSetArea": ConditioningSetArea,
+     "KSamplerAdvanced": KSamplerAdvanced,
+     "SetLatentNoiseMask": SetLatentNoiseMask,
+     "LatentComposite": LatentComposite,
+     "LatentRotate": LatentRotate,
+     "LatentFlip": LatentFlip,
+     "LatentCrop": LatentCrop,
+     "LoraLoader": LoraLoader,
+     "CLIPLoader": CLIPLoader,
+     "CLIPVisionEncode": CLIPVisionEncode,
+     "StyleModelApply": StyleModelApply,
+     "unCLIPConditioning": unCLIPConditioning,
+     "ControlNetApply": ControlNetApply,
+     "ControlNetLoader": ControlNetLoader,
+     "DiffControlNetLoader": DiffControlNetLoader,
+     "StyleModelLoader": StyleModelLoader,
+     "CLIPVisionLoader": CLIPVisionLoader,
+     "VAEDecodeTiled": VAEDecodeTiled,
+     "VAEEncodeTiled": VAEEncodeTiled,
+     "TomePatchModel": TomePatchModel,
+     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
+ }
+
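+ # A minimal sketch of what load_custom_node() below looks for in a module under
+ # custom_nodes/: the only hard requirement is a NODE_CLASS_MAPPINGS dict.
+ # ("ExampleInvert" is a hypothetical name used purely for illustration.)
+ #
+ #     class ExampleInvert:
+ #         @classmethod
+ #         def INPUT_TYPES(s):
+ #             return {"required": {"image": ("IMAGE",)}}
+ #         RETURN_TYPES = ("IMAGE",)
+ #         FUNCTION = "invert"
+ #         CATEGORY = "image"
+ #
+ #         def invert(self, image):
+ #             return (1.0 - image,)
+ #
+ #     NODE_CLASS_MAPPINGS = {"ExampleInvert": ExampleInvert}
+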
1081
+ def load_custom_node(module_path):
+     module_name = os.path.basename(module_path)
+     if os.path.isfile(module_path):
+         sp = os.path.splitext(module_path)
+         module_name = sp[0]
+     try:
+         if os.path.isfile(module_path):
+             module_spec = importlib.util.spec_from_file_location(module_name, module_path)
+         else:
+             module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
+         module = importlib.util.module_from_spec(module_spec)
+         sys.modules[module_name] = module
+         module_spec.loader.exec_module(module)
+         if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
+             NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
+         else:
+             print(f"Skipping custom node module {module_path}: it has no NODE_CLASS_MAPPINGS.")
+     except Exception as e:
+         print(traceback.format_exc())
+         print(f"Cannot import {module_path} module for custom nodes:", e)
+
+ def load_custom_nodes():
+     CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
+     possible_modules = os.listdir(CUSTOM_NODE_PATH)
+     if "__pycache__" in possible_modules:
+         possible_modules.remove("__pycache__")
+
+     for possible_module in possible_modules:
+         module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
+         if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
+         load_custom_node(module_path)
+
+ def init_custom_nodes():
+     load_custom_nodes()
+     load_custom_node(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "nodes_upscale_model.py"))
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ torchdiffeq
+ torchsde
+ einops
+ open-clip-torch
+ transformers>=4.25.1
+ safetensors
+ pytorch_lightning
+ aiohttp
+ accelerate
+ pyyaml
server.py ADDED
@@ -0,0 +1,294 @@
+ import os
+ import sys
+ import asyncio
+ import nodes
+ import folder_paths
+ import execution
+ import uuid
+ import json
+ import glob
+ try:
+     import aiohttp
+     from aiohttp import web
+ except ImportError:
+     print("Module 'aiohttp' not installed. Please install it via:")
+     print("pip install aiohttp")
+     print("or")
+     print("pip install -r requirements.txt")
+     sys.exit()
+
+ import mimetypes
+
+
23
+ @web.middleware
+ async def cache_control(request: web.Request, handler):
+     response: web.Response = await handler(request)
+     if request.path.endswith('.js') or request.path.endswith('.css'):
+         response.headers.setdefault('Cache-Control', 'no-cache')
+     return response
+
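+ # 'no-cache' does not disable caching outright; it forces the browser to
+ # revalidate .js/.css with the server, so frontend changes are picked up
+ # without a hard refresh.
+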
30
+ class PromptServer():
+     def __init__(self, loop):
+         PromptServer.instance = self
+
+         mimetypes.init()
+         mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
+         self.prompt_queue = None
+         self.loop = loop
+         self.messages = asyncio.Queue()
+         self.number = 0
+         self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
+         self.sockets = dict()
+         self.web_root = os.path.join(os.path.dirname(
+             os.path.realpath(__file__)), "web")
+         routes = web.RouteTableDef()
+         self.routes = routes
+         self.last_node_id = None
+         self.client_id = None
+
+         @routes.get('/ws')
+         async def websocket_handler(request):
+             ws = web.WebSocketResponse()
+             await ws.prepare(request)
+             sid = request.rel_url.query.get('clientId', '')
+             if sid:
+                 # Reusing existing session, remove old
+                 self.sockets.pop(sid, None)
+             else:
+                 sid = uuid.uuid4().hex
+
+             self.sockets[sid] = ws
+
+             try:
+                 # Send initial state to the new client
+                 await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
+                 # On reconnect if we are the currently executing client send the current node
+                 if self.client_id == sid and self.last_node_id is not None:
+                     await self.send("executing", { "node": self.last_node_id }, sid)
+
+                 async for msg in ws:
+                     if msg.type == aiohttp.WSMsgType.ERROR:
+                         print('ws connection closed with exception %s' % ws.exception())
+             finally:
+                 self.sockets.pop(sid, None)
+             return ws
+
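+         # Messages pushed over this socket all share the envelope built in
+         # send() below: {"type": "<event>", "data": {...}} -- e.g.
+         # {"type": "executing", "data": {"node": "3"}} while a node runs
+         # (the node id "3" is illustrative).
+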
76
+ @routes.get("/")
77
+ async def get_root(request):
78
+ return web.FileResponse(os.path.join(self.web_root, "index.html"))
79
+
80
+ @routes.get("/embeddings")
81
+ def get_embeddings(self):
82
+ embeddings = folder_paths.get_filename_list("embeddings")
83
+ return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings)))
84
+
85
+ @routes.get("/extensions")
86
+ async def get_extensions(request):
87
+ files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
88
+ return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
89
+
90
+ @routes.post("/upload/image")
91
+ async def upload_image(request):
92
+ upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
93
+
94
+ if not os.path.exists(upload_dir):
95
+ os.makedirs(upload_dir)
96
+
97
+ post = await request.post()
98
+ image = post.get("image")
99
+
100
+ if image and image.file:
101
+ filename = image.filename
102
+ if not filename:
103
+ return web.Response(status=400)
104
+
105
+ split = os.path.splitext(filename)
106
+ i = 1
107
+ while os.path.exists(os.path.join(upload_dir, filename)):
108
+ filename = f"{split[0]} ({i}){split[1]}"
109
+ i += 1
110
+
111
+ filepath = os.path.join(upload_dir, filename)
112
+
113
+ with open(filepath, "wb") as f:
114
+ f.write(image.file.read())
115
+
116
+ return web.json_response({"name" : filename})
117
+ else:
118
+ return web.Response(status=400)
119
+
120
+
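+         # Collision handling above mirrors desktop conventions: uploading
+         # "photo.png" twice stores the second copy as "photo (1).png",
+         # then "photo (2).png", and so on.
+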
121
+ @routes.get("/view")
122
+ async def view_image(request):
123
+ if "filename" in request.rel_url.query:
124
+ type = request.rel_url.query.get("type", "output")
125
+ if type not in ["output", "input", "temp"]:
126
+ return web.Response(status=400)
127
+
128
+ output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
129
+ if "subfolder" in request.rel_url.query:
130
+ full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
131
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
132
+ return web.Response(status=403)
133
+ output_dir = full_output_dir
134
+
135
+ filename = request.rel_url.query["filename"]
136
+ filename = os.path.basename(filename)
137
+ file = os.path.join(output_dir, filename)
138
+
139
+ if os.path.isfile(file):
140
+ return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
141
+
142
+ return web.Response(status=404)
143
+
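+         # Two traversal guards above: os.path.commonpath() rejects a
+         # "subfolder" that escapes the base directory, and os.path.basename()
+         # strips any directory components from "filename".
+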
144
+ @routes.get("/prompt")
145
+ async def get_prompt(request):
146
+ return web.json_response(self.get_queue_info())
147
+
148
+ @routes.get("/object_info")
149
+ async def get_object_info(request):
150
+ out = {}
151
+ for x in nodes.NODE_CLASS_MAPPINGS:
152
+ obj_class = nodes.NODE_CLASS_MAPPINGS[x]
153
+ info = {}
154
+ info['input'] = obj_class.INPUT_TYPES()
155
+ info['output'] = obj_class.RETURN_TYPES
156
+ info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
157
+ info['name'] = x #TODO
158
+ info['description'] = ''
159
+ info['category'] = 'sd'
160
+ if hasattr(obj_class, 'CATEGORY'):
161
+ info['category'] = obj_class.CATEGORY
162
+ out[x] = info
163
+ return web.json_response(out)
164
+
165
+ @routes.get("/history")
166
+ async def get_history(request):
167
+ return web.json_response(self.prompt_queue.get_history())
168
+
169
+ @routes.get("/queue")
170
+ async def get_queue(request):
171
+ queue_info = {}
172
+ current_queue = self.prompt_queue.get_current_queue()
173
+ queue_info['queue_running'] = current_queue[0]
174
+ queue_info['queue_pending'] = current_queue[1]
175
+ return web.json_response(queue_info)
176
+
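+         # Sketch of one /object_info entry, derived from the loop above
+         # (tuples serialize as JSON arrays; ImageInvert used as the example):
+         #   "ImageInvert": {"input": {"required": {"image": ["IMAGE"]}},
+         #                   "output": ["IMAGE"], "output_name": ["IMAGE"],
+         #                   "name": "ImageInvert", "description": "",
+         #                   "category": "image"}
+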
177
+ @routes.post("/prompt")
178
+ async def post_prompt(request):
179
+ print("got prompt")
180
+ resp_code = 200
181
+ out_string = ""
182
+ json_data = await request.json()
183
+
184
+ if "number" in json_data:
185
+ number = float(json_data['number'])
186
+ else:
187
+ number = self.number
188
+ if "front" in json_data:
189
+ if json_data['front']:
190
+ number = -number
191
+
192
+ self.number += 1
193
+
194
+ if "prompt" in json_data:
195
+ prompt = json_data["prompt"]
196
+ valid = execution.validate_prompt(prompt)
197
+ extra_data = {}
198
+ if "extra_data" in json_data:
199
+ extra_data = json_data["extra_data"]
200
+
201
+ if "client_id" in json_data:
202
+ extra_data["client_id"] = json_data["client_id"]
203
+ if valid[0]:
204
+ self.prompt_queue.put((number, id(prompt), prompt, extra_data))
205
+ else:
206
+ resp_code = 400
207
+ out_string = valid[1]
208
+ print("invalid prompt:", valid[1])
209
+
210
+ return web.Response(body=out_string, status=resp_code)
211
+
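+         # Example POST /prompt body (a sketch; node graph abridged):
+         #   {"prompt": {"3": {"class_type": "KSampler", "inputs": {...}}, ...},
+         #    "client_id": "<websocket clientId>", "front": true}
+         # "front" negates the priority number so the job jumps the pending queue.
+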
212
+ @routes.post("/queue")
213
+ async def post_queue(request):
214
+ json_data = await request.json()
215
+ if "clear" in json_data:
216
+ if json_data["clear"]:
217
+ self.prompt_queue.wipe_queue()
218
+ if "delete" in json_data:
219
+ to_delete = json_data['delete']
220
+ for id_to_delete in to_delete:
221
+ delete_func = lambda a: a[1] == int(id_to_delete)
222
+ self.prompt_queue.delete_queue_item(delete_func)
223
+
224
+ return web.Response(status=200)
225
+
226
+ @routes.post("/interrupt")
227
+ async def post_interrupt(request):
228
+ nodes.interrupt_processing()
229
+ return web.Response(status=200)
230
+
231
+ @routes.post("/history")
232
+ async def post_history(request):
233
+ json_data = await request.json()
234
+ if "clear" in json_data:
235
+ if json_data["clear"]:
236
+ self.prompt_queue.wipe_history()
237
+ if "delete" in json_data:
238
+ to_delete = json_data['delete']
239
+ for id_to_delete in to_delete:
240
+ self.prompt_queue.delete_history_item(id_to_delete)
241
+
242
+ return web.Response(status=200)
243
+
244
+     def add_routes(self):
+         self.app.add_routes(self.routes)
+         self.app.add_routes([
+             web.static('/', self.web_root),
+         ])
+
+     def get_queue_info(self):
+         prompt_info = {}
+         exec_info = {}
+         exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
+         prompt_info['exec_info'] = exec_info
+         return prompt_info
+
+     async def send(self, event, data, sid=None):
+         message = {"type": event, "data": data}
+
+         if not isinstance(message, str):
+             message = json.dumps(message)
+
+         if sid is None:
+             for ws in self.sockets.values():
+                 await ws.send_str(message)
+         elif sid in self.sockets:
+             await self.sockets[sid].send_str(message)
+
+     def send_sync(self, event, data, sid=None):
+         self.loop.call_soon_threadsafe(
+             self.messages.put_nowait, (event, data, sid))
+
+     def queue_updated(self):
+         self.send_sync("status", { "status": self.get_queue_info() })
+
+     async def publish_loop(self):
+         while True:
+             msg = await self.messages.get()
+             await self.send(*msg)
+
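+     # send_sync() is the bridge out of the executor thread: it hops onto the
+     # event loop via call_soon_threadsafe, enqueues the message, and
+     # publish_loop() drains the queue and fans each message out through send().
+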
281
+     async def start(self, address, port, verbose=True, call_on_start=None):
+         runner = web.AppRunner(self.app)
+         await runner.setup()
+         site = web.TCPSite(runner, address, port)
+         await site.start()
+
+         if address == '':
+             address = '0.0.0.0'
+         if verbose:
+             print("Starting server\n")
+             print("To see the GUI go to: http://{}:{}".format(address, port))
+         if call_on_start is not None:
+             call_on_start(address, port)
+