surena26 commited on
Commit
7595c28
·
verified ·
1 Parent(s): 3d692ec

Upload ComfyUI/main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ComfyUI/main.py +258 -0
ComfyUI/main.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.options
2
+ comfy.options.enable_args_parsing()
3
+
4
+ import os
5
+ import importlib.util
6
+ import folder_paths
7
+ import time
8
+
9
+ def execute_prestartup_script():
10
+ def execute_script(script_path):
11
+ module_name = os.path.splitext(script_path)[0]
12
+ try:
13
+ spec = importlib.util.spec_from_file_location(module_name, script_path)
14
+ module = importlib.util.module_from_spec(spec)
15
+ spec.loader.exec_module(module)
16
+ return True
17
+ except Exception as e:
18
+ print(f"Failed to execute startup-script: {script_path} / {e}")
19
+ return False
20
+
21
+ node_paths = folder_paths.get_folder_paths("custom_nodes")
22
+ for custom_node_path in node_paths:
23
+ possible_modules = os.listdir(custom_node_path)
24
+ node_prestartup_times = []
25
+
26
+ for possible_module in possible_modules:
27
+ module_path = os.path.join(custom_node_path, possible_module)
28
+ if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
29
+ continue
30
+
31
+ script_path = os.path.join(module_path, "prestartup_script.py")
32
+ if os.path.exists(script_path):
33
+ time_before = time.perf_counter()
34
+ success = execute_script(script_path)
35
+ node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
36
+ if len(node_prestartup_times) > 0:
37
+ print("\nPrestartup times for custom nodes:")
38
+ for n in sorted(node_prestartup_times):
39
+ if n[2]:
40
+ import_message = ""
41
+ else:
42
+ import_message = " (PRESTARTUP FAILED)"
43
+ print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
44
+ print()
45
+
46
+ execute_prestartup_script()
47
+
48
+
49
+ # Main code
50
+ import asyncio
51
+ import itertools
52
+ import shutil
53
+ import threading
54
+ import gc
55
+
56
+ from comfy.cli_args import args
57
+ import logging
58
+
59
+ if os.name == "nt":
60
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
61
+
62
+ if __name__ == "__main__":
63
+ if args.cuda_device is not None:
64
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
65
+ logging.info("Set cuda device to: {}".format(args.cuda_device))
66
+
67
+ if args.deterministic:
68
+ if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
69
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
70
+
71
+ import cuda_malloc
72
+
73
+ import comfy.utils
74
+ import yaml
75
+
76
+ import execution
77
+ import server
78
+ from server import BinaryEventTypes
79
+ from nodes import init_custom_nodes
80
+ import comfy.model_management
81
+
82
+ def cuda_malloc_warning():
83
+ device = comfy.model_management.get_torch_device()
84
+ device_name = comfy.model_management.get_torch_device_name(device)
85
+ cuda_malloc_warning = False
86
+ if "cudaMallocAsync" in device_name:
87
+ for b in cuda_malloc.blacklist:
88
+ if b in device_name:
89
+ cuda_malloc_warning = True
90
+ if cuda_malloc_warning:
91
+ logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
92
+
93
+ def prompt_worker(q, server):
94
+ e = execution.PromptExecutor(server)
95
+ last_gc_collect = 0
96
+ need_gc = False
97
+ gc_collect_interval = 10.0
98
+
99
+ while True:
100
+ timeout = 1000.0
101
+ if need_gc:
102
+ timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
103
+
104
+ queue_item = q.get(timeout=timeout)
105
+ if queue_item is not None:
106
+ item, item_id = queue_item
107
+ execution_start_time = time.perf_counter()
108
+ prompt_id = item[1]
109
+ server.last_prompt_id = prompt_id
110
+
111
+ e.execute(item[2], prompt_id, item[3], item[4])
112
+ need_gc = True
113
+ q.task_done(item_id,
114
+ e.outputs_ui,
115
+ status=execution.PromptQueue.ExecutionStatus(
116
+ status_str='success' if e.success else 'error',
117
+ completed=e.success,
118
+ messages=e.status_messages))
119
+ if server.client_id is not None:
120
+ server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
121
+
122
+ current_time = time.perf_counter()
123
+ execution_time = current_time - execution_start_time
124
+ logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
125
+
126
+ flags = q.get_flags()
127
+ free_memory = flags.get("free_memory", False)
128
+
129
+ if flags.get("unload_models", free_memory):
130
+ comfy.model_management.unload_all_models()
131
+ need_gc = True
132
+ last_gc_collect = 0
133
+
134
+ if free_memory:
135
+ e.reset()
136
+ need_gc = True
137
+ last_gc_collect = 0
138
+
139
+ if need_gc:
140
+ current_time = time.perf_counter()
141
+ if (current_time - last_gc_collect) > gc_collect_interval:
142
+ comfy.model_management.cleanup_models()
143
+ gc.collect()
144
+ comfy.model_management.soft_empty_cache()
145
+ last_gc_collect = current_time
146
+ need_gc = False
147
+
148
+ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
149
+ await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
150
+
151
+
152
+ def hijack_progress(server):
153
+ def hook(value, total, preview_image):
154
+ comfy.model_management.throw_exception_if_processing_interrupted()
155
+ progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
156
+
157
+ server.send_sync("progress", progress, server.client_id)
158
+ if preview_image is not None:
159
+ server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
160
+ comfy.utils.set_progress_bar_global_hook(hook)
161
+
162
+
163
+ def cleanup_temp():
164
+ temp_dir = folder_paths.get_temp_directory()
165
+ if os.path.exists(temp_dir):
166
+ shutil.rmtree(temp_dir, ignore_errors=True)
167
+
168
+
169
+ def load_extra_path_config(yaml_path):
170
+ with open(yaml_path, 'r') as stream:
171
+ config = yaml.safe_load(stream)
172
+ for c in config:
173
+ conf = config[c]
174
+ if conf is None:
175
+ continue
176
+ base_path = None
177
+ if "base_path" in conf:
178
+ base_path = conf.pop("base_path")
179
+ for x in conf:
180
+ for y in conf[x].split("\n"):
181
+ if len(y) == 0:
182
+ continue
183
+ full_path = y
184
+ if base_path is not None:
185
+ full_path = os.path.join(base_path, full_path)
186
+ logging.info("Adding extra search path {} {}".format(x, full_path))
187
+ folder_paths.add_model_folder_path(x, full_path)
188
+
189
+
190
+ if __name__ == "__main__":
191
+ if args.temp_directory:
192
+ temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
193
+ logging.info(f"Setting temp directory to: {temp_dir}")
194
+ folder_paths.set_temp_directory(temp_dir)
195
+ cleanup_temp()
196
+
197
+ if args.windows_standalone_build:
198
+ try:
199
+ import new_updater
200
+ new_updater.update_windows_updater()
201
+ except:
202
+ pass
203
+
204
+ loop = asyncio.new_event_loop()
205
+ asyncio.set_event_loop(loop)
206
+ server = server.PromptServer(loop)
207
+ q = execution.PromptQueue(server)
208
+
209
+ extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
210
+ if os.path.isfile(extra_model_paths_config_path):
211
+ load_extra_path_config(extra_model_paths_config_path)
212
+
213
+ if args.extra_model_paths_config:
214
+ for config_path in itertools.chain(*args.extra_model_paths_config):
215
+ load_extra_path_config(config_path)
216
+
217
+ init_custom_nodes()
218
+
219
+ cuda_malloc_warning()
220
+
221
+ server.add_routes()
222
+ hijack_progress(server)
223
+
224
+ threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
225
+
226
+ if args.output_directory:
227
+ output_dir = os.path.abspath(args.output_directory)
228
+ logging.info(f"Setting output directory to: {output_dir}")
229
+ folder_paths.set_output_directory(output_dir)
230
+
231
+ #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
232
+ folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
233
+ folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
234
+ folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
235
+
236
+ if args.input_directory:
237
+ input_dir = os.path.abspath(args.input_directory)
238
+ logging.info(f"Setting input directory to: {input_dir}")
239
+ folder_paths.set_input_directory(input_dir)
240
+
241
+ if args.quick_test_for_ci:
242
+ exit(0)
243
+
244
+ call_on_start = None
245
+ if args.auto_launch:
246
+ def startup_server(scheme, address, port):
247
+ import webbrowser
248
+ if os.name == 'nt' and address == '0.0.0.0':
249
+ address = '127.0.0.1'
250
+ webbrowser.open(f"{scheme}://{address}:{port}")
251
+ call_on_start = startup_server
252
+
253
+ try:
254
+ loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
255
+ except KeyboardInterrupt:
256
+ logging.info("\nStopped server")
257
+
258
+ cleanup_temp()