Spaces:
Runtime error
Runtime error
| import io | |
| import os | |
| import sys | |
| import glob | |
| import signal | |
| import asyncio | |
| import logging | |
| import importlib | |
| import contextlib | |
| from threading import Thread | |
| import modules.loader | |
| import torch # pylint: disable=wrong-import-order | |
| from modules import timer, errors, paths # pylint: disable=unused-import | |
| from installer import log, git_commit, custom_excepthook | |
| import ldm.modules.encoders.modules # pylint: disable=W0611,C0411,E0401 | |
| from modules import shared, extensions, gr_tempdir, modelloader # pylint: disable=ungrouped-imports | |
| from modules import extra_networks, ui_extra_networks # pylint: disable=ungrouped-imports | |
| from modules.paths import create_paths | |
| from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=W0611,C0411,C0412 | |
| import modules.devices | |
| import modules.sd_samplers | |
| import modules.lowvram | |
| import modules.scripts | |
| import modules.sd_models | |
| import modules.sd_vae | |
| import modules.progress | |
| import modules.ui | |
| import modules.txt2img | |
| import modules.img2img | |
| import modules.upscaler | |
| import modules.textual_inversion.textual_inversion | |
| import modules.hypernetworks.hypernetwork | |
| import modules.script_callbacks | |
| from modules.api.middleware import setup_middleware | |
| from modules.shared import cmd_opts, opts | |
# Route uncaught exceptions through the installer's custom hook.
sys.excepthook = custom_excepthook
# URL of the local Gradio server; assigned (via `global`) by start_ui() after launch and read by webui().
local_url = None
state = shared.state
backend = shared.backend
# Only record library import time the first time this module is initialized (flag is set at the end of this block).
if not modules.loader.initialized:
    timer.startup.record("libraries")
# Bind address: explicit cmd_opts.server_name wins, otherwise "0.0.0.0" when cmd_opts.listen is set,
# otherwise None (left to Gradio's own default).
if cmd_opts.server_name:
    server_name = cmd_opts.server_name
else:
    server_name = "0.0.0.0" if cmd_opts.listen else None
# Keyword arguments for the FastAPI app: passed as app_kwargs to Gradio in start_ui() and
# used directly to construct FastAPI in api_only(). Docs endpoints exist only when cmd_opts.docs is set.
fastapi_args = {
    "version": f'0.0.{git_commit}',
    "title": "SD.Next",
    "description": "SD.Next",
    "docs_url": "/docs" if cmd_opts.docs else None,
    "redoc_url": "/redocs" if cmd_opts.docs else None,
    "swagger_ui_parameters": {
        "displayOperationId": True,
        "showCommonExtensions": True,
        "deepLinking": False,
    }
}
# Imported late, presumably for its import-time side effects (its time is recorded as "ldm") — confirm.
import modules.sd_hijack
timer.startup.record("ldm")
modules.loader.initialized = True
def check_rollback_vae():
    """Disable ``shared.cmd_opts.rollback_vae`` when the runtime cannot support it.

    Requirements, checked in order (the first unmet one logs an error and clears
    the option):
      1. CUDA is available.
      2. Torch version is 2.1 or newer.
      3. The CUDA device's compute-capability major version is not in 1..7.

    Does nothing when the option is not enabled.
    """
    if not shared.cmd_opts.rollback_vae:
        return
    import re
    # Compare major.minor numerically so 2.2, 2.10, 3.x etc. are accepted;
    # handles suffixes such as "2.1.0+cu121" or "2.3.0a0".
    version = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
    torch_supported = version is not None and (int(version.group(1)), int(version.group(2))) >= (2, 1)
    if not torch.cuda.is_available():
        log.error("Rollback VAE functionality requires compatible GPU")
        shared.cmd_opts.rollback_vae = False
    elif not torch_supported:
        log.error("Rollback VAE functionality requires Torch 2.1 or higher")
        shared.cmd_opts.rollback_vae = False
    elif 0 < torch.cuda.get_device_capability()[0] < 8:
        log.error('Rollback VAE functionality device capabilities not met')
        shared.cmd_opts.rollback_vae = False
def initialize():
    """Run the startup sequence shared by the UI and API-only modes.

    Stages, each recorded in ``timer.startup``:
      - rollback-VAE capability check, sampler and VAE lists, extension list
      - model cleanup/setup, CodeFormer and GFPGAN face-restoration setup
      - extension scripts, upscalers, option change handlers
      - textual-inversion templates, hypernetworks, prompt styles, extra networks
      - TLS validation (bad paths disable TLS instead of reaching launch())

    Finally installs a SIGINT handler that removes ``*.lock`` files in the
    current working directory and exits.
    """
    log.debug('Initializing')
    check_rollback_vae()

    modules.sd_samplers.list_samplers()
    timer.startup.record("samplers")

    modules.sd_vae.refresh_vae_list()
    timer.startup.record("vae")

    extensions.list_extensions()
    timer.startup.record("extensions")

    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    timer.startup.record("models")

    import modules.postprocess.codeformer_model as codeformer
    codeformer.setup_model(opts.codeformer_models_path)
    # also register under modules.codeformer_model, presumably for code that still imports that path
    sys.modules["modules.codeformer_model"] = codeformer
    import modules.postprocess.gfpgan_model as gfpgan
    gfpgan.setup_model(opts.gfpgan_models_path)
    timer.startup.record("face-restore")

    log.debug('Load extensions')
    t_timer, t_total = modules.scripts.load_scripts()
    timer.startup.record("extensions")
    timer.startup.records["extensions"] = t_total # scripts can reset the time
    log.debug(f'Extensions init time: {t_timer.summary()}')

    modelloader.load_upscalers()
    timer.startup.record("upscalers")

    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", gr_tempdir.on_tmpdir_changed)
    timer.startup.record("onchange")

    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    shared.reload_hypernetworks()
    shared.prompt_styles.reload()
    ui_extra_networks.initialize()
    ui_extra_networks.register_pages()
    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    timer.startup.record("networks")

    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None:
        tls_valid = True
        try:
            if not os.path.exists(cmd_opts.tls_keyfile):
                log.error("Invalid path to TLS keyfile given")
                tls_valid = False
            if not os.path.exists(cmd_opts.tls_certfile):
                log.error(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
                tls_valid = False
        except TypeError: # paths that are not path-like at all
            tls_valid = False
        if tls_valid:
            log.info("Running with TLS")
        else:
            # Fall back to plain HTTP rather than handing launch() files that cannot be opened.
            cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
            log.error("TLS setup invalid, running webui without TLS")
    timer.startup.record("tls")

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(_sig, _frame):
        log.info('Exiting')
        # best-effort cleanup: each lock file is removed independently, so one
        # failure does not leave the remaining lock files behind
        for lock_file in glob.glob("*.lock"):
            with contextlib.suppress(OSError):
                os.remove(lock_file)
        sys.exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
def load_model():
    """Optionally preload the main and refiner models, then register model-reload handlers.

    Preloading is skipped (with a debug log) when ``opts.sd_checkpoint_autoload`` is off,
    or when ``shared.cmd_opts.ckpt`` is set to something other than 'none'. Otherwise
    ``shared.sd_model`` and ``shared.sd_refiner`` are each read in their own thread.
    The onchange handlers below are registered in both cases.
    """
    if not opts.sd_checkpoint_autoload or (shared.cmd_opts.ckpt is not None and shared.cmd_opts.ckpt.lower() != 'none'):
        log.debug('Model auto load disabled')
    else:
        shared.state.begin('load')
        # Reading these attributes presumably triggers lazy loading of each model, run concurrently — TODO confirm.
        thread_model = Thread(target=lambda: shared.sd_model)
        thread_model.start()
        thread_refiner = Thread(target=lambda: shared.sd_refiner)
        thread_refiner.start()
        # NOTE(review): state.end() runs before the joins, so the 'load' job may be reported
        # as finished while loading is still in progress — confirm this is intended.
        shared.state.end()
        thread_model.join()
        thread_refiner.join()
    # Reload weights when the related options change; each handler is wrapped with wrap_queued_call.
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='model')), call=False)
    shared.opts.onchange("sd_model_refiner", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='refiner')), call=False)
    shared.opts.onchange("sd_model_dict", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(op='dict')), call=False)
    # NOTE(review): initialize() also registers an "sd_vae" handler; confirm onchange replaces rather than appends.
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_backend", wrap_queued_call(lambda: modules.sd_models.change_backend()), call=False)
    timer.startup.record("checkpoint")
def create_api(app):
    """Build and return the REST ``Api`` object attached to *app*, using the module's ``queue_lock``."""
    log.debug('Creating API')
    from modules.api.api import Api # deferred: only needed once an app exists
    return Api(app, queue_lock)
def async_policy():
    """Install an asyncio event-loop policy whose ``get_event_loop`` works from any thread.

    The policy is built on the selector policy on Windows (when available) and the
    default policy elsewhere. When the base policy cannot provide a loop
    (``RuntimeError``/``AssertionError``), a new loop is created and set as current
    instead of raising. The loop obtained at construction gets an exception handler
    that logs unhandled loop errors.
    """
    _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy") else asyncio.DefaultEventLoopPolicy

    class AnyThreadEventLoopPolicy(_BasePolicy):
        def handle_exception(self, _loop, context):
            # asyncio calls exception handlers as handler(loop, context); a (self, context)
            # signature made every call fail with TypeError, so nothing was ever logged here.
            msg = context.get("exception", context["message"])
            log.error(f"AsyncIO loop: {msg}")

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                self.loop = super().get_event_loop()
            except (RuntimeError, AssertionError):
                self.loop = self.new_event_loop()
                self.set_event_loop(self.loop)
            return self.loop

        def __init__(self):
            super().__init__()
            self.loop = self.get_event_loop()
            self.loop.set_exception_handler(self.handle_exception)

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def start_common():
    """Run the startup steps shared by webui() and api_only().

    Logs any custom data/models directory, creates paths, installs the asyncio
    policy, runs initialize(), and optionally cleans the temp directory.
    """
    log.debug('Entering start sequence')
    data_dir = shared.cmd_opts.data_dir
    if data_dir is not None and len(data_dir) > 0:
        log.info(f'Using data path: {data_dir}')
    models_dir = shared.cmd_opts.models_dir
    if models_dir is not None and len(models_dir) > 0 and models_dir != 'models':
        log.info(f'Using models path: {models_dir}')
    create_paths(opts)
    async_policy()
    initialize()
    if not shared.opts.clean_temp_dir_at_start:
        return
    gr_tempdir.cleanup_tmpdr()
    timer.startup.record("cleanup")
def start_ui():
    """Build the Gradio UI, launch its server and attach API routes to the resulting app.

    Side effects: sets ``shared.demo`` and ``shared.api``, assigns the module-level
    ``local_url``, registers middleware, progress, extra-network and script routes on
    the FastAPI app returned by ``launch()``, and records startup timings.

    Credentials come from ``cmd_opts.auth`` (comma-separated) and ``cmd_opts.auth_file``
    (comma-separated per line). Each entry is ``user:password``.
    """
    global local_url # pylint: disable=global-statement
    log.debug('Creating UI')
    modules.script_callbacks.before_ui_callback()
    timer.startup.record("before-ui")
    shared.demo = modules.ui.create_ui(timer.startup)
    timer.startup.record("ui")
    if cmd_opts.disable_queue:
        log.info('Server queues disabled')
        shared.demo.progress_tracking = False
    else:
        shared.demo.queue(concurrency_count=64)

    gradio_auth_creds = []
    if cmd_opts.auth:
        gradio_auth_creds += [x.strip() for x in cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()]
    if cmd_opts.auth_file:
        if not os.path.exists(cmd_opts.auth_file):
            log.error(f"Invalid path to auth file: '{cmd_opts.auth_file}'")
        else:
            with open(cmd_opts.auth_file, 'r', encoding="utf8") as file:
                for line in file:
                    gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
    if len(gradio_auth_creds) > 0:
        log.info(f'Authentication enabled: users={len(gradio_auth_creds)}')
    # Split on the first ':' only, so a password may itself contain ':' and each entry
    # still yields a (username, password) pair.
    gradio_auth = [tuple(cred.split(':', 1)) for cred in gradio_auth_creds] if gradio_auth_creds else None

    # launch() console output is redirected into this buffer, which is never read (i.e. discarded)
    stdout = io.StringIO()
    allowed_paths = [os.path.dirname(__file__)]
    if cmd_opts.data_dir is not None and os.path.isdir(cmd_opts.data_dir):
        allowed_paths.append(cmd_opts.data_dir)
    if cmd_opts.allowed_paths is not None:
        allowed_paths += [p for p in cmd_opts.allowed_paths if os.path.isdir(p)]
    shared.log.debug(f'Root paths: {allowed_paths}')
    with contextlib.redirect_stdout(stdout):
        app, local_url, share_url = shared.demo.launch( # app is FastAPI(Starlette) instance
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port if cmd_opts.port != 7860 else None,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            ssl_verify=not cmd_opts.tls_selfsign,
            debug=False,
            auth=gradio_auth,
            prevent_thread_lock=True,
            max_threads=64,
            show_api=False,
            quiet=True,
            favicon_path='html/logo.ico',
            allowed_paths=allowed_paths,
            app_kwargs=fastapi_args,
            _frontend=True and cmd_opts.share,
        )
    if cmd_opts.data_dir is not None:
        gr_tempdir.register_tmp_file(shared.demo, os.path.join(cmd_opts.data_dir, 'x'))
    shared.log.info(f'Local URL: {local_url}')
    if cmd_opts.docs:
        shared.log.info(f'API Docs: {local_url[:-1]}/docs') # pylint: disable=unsubscriptable-object
        shared.log.info(f'API ReDocs: {local_url[:-1]}/redocs') # pylint: disable=unsubscriptable-object
    if share_url is not None:
        shared.log.info(f'Share URL: {share_url}')
    shared.log.debug(f'Gradio functions: registered={len(shared.demo.fns)}')
    shared.demo.server.wants_restart = False
    setup_middleware(app, cmd_opts)

    if cmd_opts.subpath:
        import gradio
        gradio.mount_gradio_app(app, shared.demo, path=f"/{cmd_opts.subpath}")
        shared.log.info(f'Redirector mounted: /{cmd_opts.subpath}')

    timer.startup.record("launch")

    modules.progress.setup_progress_api(app)
    shared.api = create_api(app)
    timer.startup.record("api")

    ui_extra_networks.init_api(app)

    modules.script_callbacks.app_started_callback(shared.demo, app)
    timer.startup.record("app-started")

    time_setup = [f'{k}:{round(v,3)}' for (k,v) in modules.scripts.time_setup.items() if v > 0.005]
    shared.log.debug(f'Scripts setup: {time_setup}')
    time_component = [f'{k}:{round(v,3)}' for (k,v) in modules.scripts.time_component.items() if v > 0.005]
    if len(time_component) > 0:
        shared.log.debug(f'Scripts components: {time_component}')
def webui(restart=False):
    """Start (or, with ``restart=True``, restart) the full UI server.

    Runs the shared start sequence, builds and launches the UI, loads models, saves
    options and logs startup timings. On first start it points non-uvicorn/non-sd
    loggers at the main logger's handlers and may open a browser. On restart it
    reloads every ``modules.ui*`` module instead.

    Returns:
        ``shared.demo.server``.
    """
    if restart:
        modules.script_callbacks.app_reload_callback()
        modules.script_callbacks.script_unloaded_callback()

    start_common()
    start_ui()
    modules.script_callbacks.after_ui_callback()
    modules.sd_models.write_metadata()
    load_model()
    shared.opts.save(shared.config_filename)

    if cmd_opts.profile:
        for event, registered in modules.script_callbacks.callback_map.items():
            shared.log.debug(f'Registered callbacks: {event}={len(registered)} {[c.script for c in registered]}')

    def _silent(*args, **kwargs):
        return None

    trace_requested = os.environ.get('SD_SCRIPT_DEBUG', None) is not None
    debug = log.trace if trace_requested else _silent
    debug('Trace: SCRIPTS')
    for script_entry in modules.scripts.scripts_data:
        debug(f' {script_entry}')
    debug('Loaded postprocessing scripts:')
    for script_entry in modules.scripts.postprocessing_scripts_data:
        debug(f' {script_entry}')
    modules.script_callbacks.print_timers()

    log.info(f"Startup time: {timer.startup.summary()}")
    timer.startup.reset()

    if restart:
        # snapshot first: reloading may touch sys.modules while we iterate
        ui_modules = [mod for mod_name, mod in sys.modules.items() if mod_name.startswith("modules.ui")]
        for ui_module in ui_modules:
            importlib.reload(ui_module)
    else:
        # override all loggers to use the same handlers as the main logger
        all_loggers = [logging.getLogger(logger_name) for logger_name in logging.root.manager.loggerDict] # pylint: disable=no-member
        for other in all_loggers:
            if not other.name.startswith(('uvicorn', 'sd')):
                other.handlers = log.handlers
        # autolaunch only on initial start
        if cmd_opts.autolaunch and local_url is not None:
            cmd_opts.autolaunch = False
            shared.log.info('Launching browser')
            import webbrowser
            webbrowser.open(local_url, new=2, autoraise=True)

    return shared.demo.server
def api_only():
    """Start in API-only mode: a bare FastAPI app with middleware and the REST API, no UI.

    Returns:
        Whatever ``shared.api.launch()`` returns.
    """
    start_common()
    from fastapi import FastAPI
    application = FastAPI(**fastapi_args)
    setup_middleware(application, cmd_opts)
    shared.api = create_api(application)
    shared.api.wants_restart = False
    modules.script_callbacks.app_started_callback(None, application)
    modules.sd_models.write_metadata()
    log.info(f"Startup time: {timer.startup.summary()}")
    return shared.api.launch()
if __name__ == "__main__":
    if cmd_opts.api_only:
        # NOTE(review): whether shared.api.launch() blocks is not visible here — confirm the
        # process stays alive in API-only mode when this file is run directly.
        api_only()
    else:
        import time
        webui()
        # start_ui() launches Gradio with prevent_thread_lock=True, so launch() returns at once and
        # the server would stop as soon as this script ends. Keep the main thread alive; Ctrl+C
        # exits through the SIGINT handler installed by initialize().
        while True:
            time.sleep(1)