| 
							 | 
						import comfy.options | 
					
					
						
						| 
							 | 
# Turn on command-line argument parsing at import time.
# NOTE(review): presumably this must happen before `comfy.cli_args` is
# imported further down in this file (it provides `args`) — confirm.
comfy.options.enable_args_parsing()
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						import importlib.util | 
					
					
						
						| 
							 | 
						import folder_paths | 
					
					
						
						| 
							 | 
						import time | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def execute_prestartup_script(): | 
					
					
						
						| 
							 | 
						    def execute_script(script_path): | 
					
					
						
						| 
							 | 
						        module_name = os.path.splitext(script_path)[0] | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						            spec = importlib.util.spec_from_file_location(module_name, script_path) | 
					
					
						
						| 
							 | 
						            module = importlib.util.module_from_spec(spec) | 
					
					
						
						| 
							 | 
						            spec.loader.exec_module(module) | 
					
					
						
						| 
							 | 
						            return True | 
					
					
						
						| 
							 | 
						        except Exception as e: | 
					
					
						
						| 
							 | 
						            print(f"Failed to execute startup-script: {script_path} / {e}") | 
					
					
						
						| 
							 | 
						        return False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    node_paths = folder_paths.get_folder_paths("custom_nodes") | 
					
					
						
						| 
							 | 
						    for custom_node_path in node_paths: | 
					
					
						
						| 
							 | 
						        possible_modules = os.listdir(custom_node_path) | 
					
					
						
						| 
							 | 
						        node_prestartup_times = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        for possible_module in possible_modules: | 
					
					
						
						| 
							 | 
						            module_path = os.path.join(custom_node_path, possible_module) | 
					
					
						
						| 
							 | 
						            if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": | 
					
					
						
						| 
							 | 
						                continue | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            script_path = os.path.join(module_path, "prestartup_script.py") | 
					
					
						
						| 
							 | 
						            if os.path.exists(script_path): | 
					
					
						
						| 
							 | 
						                time_before = time.perf_counter() | 
					
					
						
						| 
							 | 
						                success = execute_script(script_path) | 
					
					
						
						| 
							 | 
						                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) | 
					
					
						
						| 
							 | 
						    if len(node_prestartup_times) > 0: | 
					
					
						
						| 
							 | 
						        print("\nPrestartup times for custom nodes:") | 
					
					
						
						| 
							 | 
						        for n in sorted(node_prestartup_times): | 
					
					
						
						| 
							 | 
						            if n[2]: | 
					
					
						
						| 
							 | 
						                import_message = "" | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                import_message = " (PRESTARTUP FAILED)" | 
					
					
						
						| 
							 | 
						            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) | 
					
					
						
						| 
							 | 
						        print() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
# Run the custom-node prestartup scripts at import time, before the remaining
# imports below are executed.
execute_prestartup_script()
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						import asyncio | 
					
					
						
						| 
							 | 
						import itertools | 
					
					
						
						| 
							 | 
						import shutil | 
					
					
						
						| 
							 | 
						import threading | 
					
					
						
						| 
							 | 
						import gc | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from comfy.cli_args import args | 
					
					
						
						| 
							 | 
						import logging | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if os.name == "nt": | 
					
					
						
						| 
							 | 
						    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    if args.cuda_device is not None: | 
					
					
						
						| 
							 | 
						        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) | 
					
					
						
						| 
							 | 
						        logging.info("Set cuda device to: {}".format(args.cuda_device)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.deterministic: | 
					
					
						
						| 
							 | 
						        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: | 
					
					
						
						| 
							 | 
						            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    import cuda_malloc | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import comfy.utils | 
					
					
						
						| 
							 | 
						import yaml | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import execution | 
					
					
						
						| 
							 | 
						import server | 
					
					
						
						| 
							 | 
						from server import BinaryEventTypes | 
					
					
						
						| 
							 | 
						from nodes import init_custom_nodes | 
					
					
						
						| 
							 | 
						import comfy.model_management | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def cuda_malloc_warning(): | 
					
					
						
						| 
							 | 
						    device = comfy.model_management.get_torch_device() | 
					
					
						
						| 
							 | 
						    device_name = comfy.model_management.get_torch_device_name(device) | 
					
					
						
						| 
							 | 
						    cuda_malloc_warning = False | 
					
					
						
						| 
							 | 
						    if "cudaMallocAsync" in device_name: | 
					
					
						
						| 
							 | 
						        for b in cuda_malloc.blacklist: | 
					
					
						
						| 
							 | 
						            if b in device_name: | 
					
					
						
						| 
							 | 
						                cuda_malloc_warning = True | 
					
					
						
						| 
							 | 
						        if cuda_malloc_warning: | 
					
					
						
						| 
							 | 
						            logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def prompt_worker(q, server): | 
					
					
						
						| 
							 | 
						    e = execution.PromptExecutor(server) | 
					
					
						
						| 
							 | 
						    last_gc_collect = 0 | 
					
					
						
						| 
							 | 
						    need_gc = False | 
					
					
						
						| 
							 | 
						    gc_collect_interval = 10.0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    while True: | 
					
					
						
						| 
							 | 
						        timeout = 1000.0 | 
					
					
						
						| 
							 | 
						        if need_gc: | 
					
					
						
						| 
							 | 
						            timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        queue_item = q.get(timeout=timeout) | 
					
					
						
						| 
							 | 
						        if queue_item is not None: | 
					
					
						
						| 
							 | 
						            item, item_id = queue_item | 
					
					
						
						| 
							 | 
						            execution_start_time = time.perf_counter() | 
					
					
						
						| 
							 | 
						            prompt_id = item[1] | 
					
					
						
						| 
							 | 
						            server.last_prompt_id = prompt_id | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            e.execute(item[2], prompt_id, item[3], item[4]) | 
					
					
						
						| 
							 | 
						            need_gc = True | 
					
					
						
						| 
							 | 
						            q.task_done(item_id, | 
					
					
						
						| 
							 | 
						                        e.outputs_ui, | 
					
					
						
						| 
							 | 
						                        status=execution.PromptQueue.ExecutionStatus( | 
					
					
						
						| 
							 | 
						                            status_str='success' if e.success else 'error', | 
					
					
						
						| 
							 | 
						                            completed=e.success, | 
					
					
						
						| 
							 | 
						                            messages=e.status_messages)) | 
					
					
						
						| 
							 | 
						            if server.client_id is not None: | 
					
					
						
						| 
							 | 
						                server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            current_time = time.perf_counter() | 
					
					
						
						| 
							 | 
						            execution_time = current_time - execution_start_time | 
					
					
						
						| 
							 | 
						            logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        flags = q.get_flags() | 
					
					
						
						| 
							 | 
						        free_memory = flags.get("free_memory", False) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if flags.get("unload_models", free_memory): | 
					
					
						
						| 
							 | 
						            comfy.model_management.unload_all_models() | 
					
					
						
						| 
							 | 
						            need_gc = True | 
					
					
						
						| 
							 | 
						            last_gc_collect = 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if free_memory: | 
					
					
						
						| 
							 | 
						            e.reset() | 
					
					
						
						| 
							 | 
						            need_gc = True | 
					
					
						
						| 
							 | 
						            last_gc_collect = 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if need_gc: | 
					
					
						
						| 
							 | 
						            current_time = time.perf_counter() | 
					
					
						
						| 
							 | 
						            if (current_time - last_gc_collect) > gc_collect_interval: | 
					
					
						
						| 
							 | 
						                gc.collect() | 
					
					
						
						| 
							 | 
						                comfy.model_management.soft_empty_cache() | 
					
					
						
						| 
							 | 
						                last_gc_collect = current_time | 
					
					
						
						| 
							 | 
						                need_gc = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						async def run(server, address='', port=8188, verbose=True, call_on_start=None): | 
					
					
						
						| 
							 | 
						    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def hijack_progress(server): | 
					
					
						
						| 
							 | 
						    def hook(value, total, preview_image): | 
					
					
						
						| 
							 | 
						        comfy.model_management.throw_exception_if_processing_interrupted() | 
					
					
						
						| 
							 | 
						        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        server.send_sync("progress", progress, server.client_id) | 
					
					
						
						| 
							 | 
						        if preview_image is not None: | 
					
					
						
						| 
							 | 
						            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) | 
					
					
						
						| 
							 | 
						    comfy.utils.set_progress_bar_global_hook(hook) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def cleanup_temp(): | 
					
					
						
						| 
							 | 
						    temp_dir = folder_paths.get_temp_directory() | 
					
					
						
						| 
							 | 
						    if os.path.exists(temp_dir): | 
					
					
						
						| 
							 | 
						        shutil.rmtree(temp_dir, ignore_errors=True) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def load_extra_path_config(yaml_path): | 
					
					
						
						| 
							 | 
						    with open(yaml_path, 'r') as stream: | 
					
					
						
						| 
							 | 
						        config = yaml.safe_load(stream) | 
					
					
						
						| 
							 | 
						    for c in config: | 
					
					
						
						| 
							 | 
						        conf = config[c] | 
					
					
						
						| 
							 | 
						        if conf is None: | 
					
					
						
						| 
							 | 
						            continue | 
					
					
						
						| 
							 | 
						        base_path = None | 
					
					
						
						| 
							 | 
						        if "base_path" in conf: | 
					
					
						
						| 
							 | 
						            base_path = conf.pop("base_path") | 
					
					
						
						| 
							 | 
						        for x in conf: | 
					
					
						
						| 
							 | 
						            for y in conf[x].split("\n"): | 
					
					
						
						| 
							 | 
						                if len(y) == 0: | 
					
					
						
						| 
							 | 
						                    continue | 
					
					
						
						| 
							 | 
						                full_path = y | 
					
					
						
						| 
							 | 
						                if base_path is not None: | 
					
					
						
						| 
							 | 
						                    full_path = os.path.join(base_path, full_path) | 
					
					
						
						| 
							 | 
						                logging.info("Adding extra search path {} {}".format(x, full_path)) | 
					
					
						
						| 
							 | 
						                folder_paths.add_model_folder_path(x, full_path) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    if args.temp_directory: | 
					
					
						
						| 
							 | 
						        temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") | 
					
					
						
						| 
							 | 
						        logging.info(f"Setting temp directory to: {temp_dir}") | 
					
					
						
						| 
							 | 
						        folder_paths.set_temp_directory(temp_dir) | 
					
					
						
						| 
							 | 
						    cleanup_temp() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.windows_standalone_build: | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						            import new_updater | 
					
					
						
						| 
							 | 
						            new_updater.update_windows_updater() | 
					
					
						
						| 
							 | 
						        except: | 
					
					
						
						| 
							 | 
						            pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    loop = asyncio.new_event_loop() | 
					
					
						
						| 
							 | 
						    asyncio.set_event_loop(loop) | 
					
					
						
						| 
							 | 
						    server = server.PromptServer(loop) | 
					
					
						
						| 
							 | 
						    q = execution.PromptQueue(server) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") | 
					
					
						
						| 
							 | 
						    if os.path.isfile(extra_model_paths_config_path): | 
					
					
						
						| 
							 | 
						        load_extra_path_config(extra_model_paths_config_path) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.extra_model_paths_config: | 
					
					
						
						| 
							 | 
						        for config_path in itertools.chain(*args.extra_model_paths_config): | 
					
					
						
						| 
							 | 
						            load_extra_path_config(config_path) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    init_custom_nodes() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    cuda_malloc_warning() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    server.add_routes() | 
					
					
						
						| 
							 | 
						    hijack_progress(server) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.output_directory: | 
					
					
						
						| 
							 | 
						        output_dir = os.path.abspath(args.output_directory) | 
					
					
						
						| 
							 | 
						        logging.info(f"Setting output directory to: {output_dir}") | 
					
					
						
						| 
							 | 
						        folder_paths.set_output_directory(output_dir) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) | 
					
					
						
						| 
							 | 
						    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) | 
					
					
						
						| 
							 | 
						    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.input_directory: | 
					
					
						
						| 
							 | 
						        input_dir = os.path.abspath(args.input_directory) | 
					
					
						
						| 
							 | 
						        logging.info(f"Setting input directory to: {input_dir}") | 
					
					
						
						| 
							 | 
						        folder_paths.set_input_directory(input_dir) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.quick_test_for_ci: | 
					
					
						
						| 
							 | 
						        exit(0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    call_on_start = None | 
					
					
						
						| 
							 | 
						    if args.auto_launch: | 
					
					
						
						| 
							 | 
						        def startup_server(address, port): | 
					
					
						
						| 
							 | 
						            import webbrowser | 
					
					
						
						| 
							 | 
						            if os.name == 'nt' and address == '0.0.0.0': | 
					
					
						
						| 
							 | 
						                address = '127.0.0.1' | 
					
					
						
						| 
							 | 
						            webbrowser.open(f"http://{address}:{port}") | 
					
					
						
						| 
							 | 
						        call_on_start = startup_server | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) | 
					
					
						
						| 
							 | 
						    except KeyboardInterrupt: | 
					
					
						
						| 
							 | 
						        logging.info("\nStopped server") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    cleanup_temp() | 
					
					
						
						| 
							 | 
						
 |