TallisHe commited on
Commit
9645472
1 Parent(s): 56568f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +482 -0
app.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import time
6
+ import importlib
7
+ import signal
8
+ import re
9
+ import warnings
10
+ import json
11
+ from threading import Thread
12
+ from typing import Iterable
13
+
14
+ from fastapi import FastAPI
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from fastapi.middleware.gzip import GZipMiddleware
17
+ from packaging import version
18
+
19
+ import logging
20
+
21
+ # We can't use cmd_opts for this because it will not have been initialized at this point.
22
+ log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
23
+ if log_level:
24
+ log_level = getattr(logging, log_level.upper(), None) or logging.INFO
25
+ logging.basicConfig(
26
+ level=log_level,
27
+ format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
28
+ datefmt='%Y-%m-%d %H:%M:%S',
29
+ )
30
+
31
+ logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
32
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
33
+
34
+ from modules import timer
35
+ startup_timer = timer.startup_timer
36
+ startup_timer.record("launcher")
37
+
38
+ import torch
39
+ import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
40
+ warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
41
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
42
+ startup_timer.record("import torch")
43
+
44
+ import gradio # noqa: F401
45
+ startup_timer.record("import gradio")
46
+
47
+ from modules import paths, timer, import_hook, errors, devices # noqa: F401
48
+ startup_timer.record("setup paths")
49
+
50
+ import ldm.modules.encoders.modules # noqa: F401
51
+ startup_timer.record("import ldm")
52
+
53
+ from modules import extra_networks
54
+ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
55
+
56
+ # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
57
+ if ".dev" in torch.__version__ or "+git" in torch.__version__:
58
+ torch.__long_version__ = torch.__version__
59
+ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
60
+
61
+ from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
62
+ import modules.codeformer_model as codeformer
63
+ import modules.face_restoration
64
+ import modules.gfpgan_model as gfpgan
65
+ import modules.img2img
66
+
67
+ import modules.lowvram
68
+ import modules.scripts
69
+ import modules.sd_hijack
70
+ import modules.sd_hijack_optimizations
71
+ import modules.sd_models
72
+ import modules.sd_vae
73
+ import modules.sd_unet
74
+ import modules.txt2img
75
+ import modules.script_callbacks
76
+ import modules.textual_inversion.textual_inversion
77
+ import modules.progress
78
+
79
+ import modules.ui
80
+ from modules import modelloader
81
+ from modules.shared import cmd_opts
82
+ import modules.hypernetworks.hypernetwork
83
+
84
+ startup_timer.record("other imports")
85
+
86
+
87
+ if cmd_opts.server_name:
88
+ server_name = cmd_opts.server_name
89
+ else:
90
+ server_name = "0.0.0.0" if cmd_opts.listen else None
91
+
92
+
93
+ def fix_asyncio_event_loop_policy():
94
+ """
95
+ The default `asyncio` event loop policy only automatically creates
96
+ event loops in the main threads. Other threads must create event
97
+ loops explicitly or `asyncio.get_event_loop` (and therefore
98
+ `.IOLoop.current`) will fail. Installing this policy allows event
99
+ loops to be created automatically on any thread, matching the
100
+ behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
101
+ """
102
+
103
+ import asyncio
104
+
105
+ if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
106
+ # "Any thread" and "selector" should be orthogonal, but there's not a clean
107
+ # interface for composing policies so pick the right base.
108
+ _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
109
+ else:
110
+ _BasePolicy = asyncio.DefaultEventLoopPolicy
111
+
112
+ class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
113
+ """Event loop policy that allows loop creation on any thread.
114
+ Usage::
115
+
116
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
117
+ """
118
+
119
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
120
+ try:
121
+ return super().get_event_loop()
122
+ except (RuntimeError, AssertionError):
123
+ # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
124
+ # and changed to a RuntimeError in 3.4.3.
125
+ # "There is no current event loop in thread %r"
126
+ loop = self.new_event_loop()
127
+ self.set_event_loop(loop)
128
+ return loop
129
+
130
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
131
+
132
+
133
+ def check_versions():
134
+ if shared.cmd_opts.skip_version_check:
135
+ return
136
+
137
+ expected_torch_version = "2.0.0"
138
+
139
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
140
+ errors.print_error_explanation(f"""
141
+ You are running torch {torch.__version__}.
142
+ The program is tested to work with torch {expected_torch_version}.
143
+ To reinstall the desired version, run with commandline flag --reinstall-torch.
144
+ Beware that this will cause a lot of large files to be downloaded, as well as
145
+ there are reports of issues with training tab on the latest version.
146
+
147
+ Use --skip-version-check commandline argument to disable this check.
148
+ """.strip())
149
+
150
+ expected_xformers_version = "0.0.20"
151
+ if shared.xformers_available:
152
+ import xformers
153
+
154
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
155
+ errors.print_error_explanation(f"""
156
+ You are running xformers {xformers.__version__}.
157
+ The program is tested to work with xformers {expected_xformers_version}.
158
+ To reinstall the desired version, run with commandline flag --reinstall-xformers.
159
+
160
+ Use --skip-version-check commandline argument to disable this check.
161
+ """.strip())
162
+
163
+
164
+ def restore_config_state_file():
165
+ config_state_file = shared.opts.restore_config_state_file
166
+ if config_state_file == "":
167
+ return
168
+
169
+ shared.opts.restore_config_state_file = ""
170
+ shared.opts.save(shared.config_filename)
171
+
172
+ if os.path.isfile(config_state_file):
173
+ print(f"*** About to restore extension state from file: {config_state_file}")
174
+ with open(config_state_file, "r", encoding="utf-8") as f:
175
+ config_state = json.load(f)
176
+ config_states.restore_extension_config(config_state)
177
+ startup_timer.record("restore extension config")
178
+ elif config_state_file:
179
+ print(f"!!! Config state backup not found: {config_state_file}")
180
+
181
+
182
+ def validate_tls_options():
183
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
184
+ return
185
+
186
+ try:
187
+ if not os.path.exists(cmd_opts.tls_keyfile):
188
+ print("Invalid path to TLS keyfile given")
189
+ if not os.path.exists(cmd_opts.tls_certfile):
190
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
191
+ except TypeError:
192
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
193
+ print("TLS setup invalid, running webui without TLS")
194
+ else:
195
+ print("Running with TLS")
196
+ startup_timer.record("TLS")
197
+
198
+
199
+ def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
200
+ """
201
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
202
+ an iterable of (username, password) tuples.
203
+ """
204
+ def process_credential_line(s) -> tuple[str, ...] | None:
205
+ s = s.strip()
206
+ if not s:
207
+ return None
208
+ return tuple(s.split(':', 1))
209
+
210
+ if cmd_opts.gradio_auth:
211
+ for cred in cmd_opts.gradio_auth.split(','):
212
+ cred = process_credential_line(cred)
213
+ if cred:
214
+ yield cred
215
+
216
+ if cmd_opts.gradio_auth_path:
217
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
218
+ for line in file.readlines():
219
+ for cred in line.strip().split(','):
220
+ cred = process_credential_line(cred)
221
+ if cred:
222
+ yield cred
223
+
224
+
225
+ def configure_sigint_handler():
226
+ # make the program just exit at ctrl+c without waiting for anything
227
+ def sigint_handler(sig, frame):
228
+ print(f'Interrupted with signal {sig} in {frame}')
229
+ os._exit(0)
230
+
231
+ if not os.environ.get("COVERAGE_RUN"):
232
+ # Don't install the immediate-quit handler when running under coverage,
233
+ # as then the coverage report won't be generated.
234
+ signal.signal(signal.SIGINT, sigint_handler)
235
+
236
+
237
+ def configure_opts_onchange():
238
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
239
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
240
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
241
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
242
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
243
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
244
+ startup_timer.record("opts onchange")
245
+
246
+
247
+ def initialize():
248
+ fix_asyncio_event_loop_policy()
249
+ validate_tls_options()
250
+ configure_sigint_handler()
251
+ check_versions()
252
+ modelloader.cleanup_models()
253
+ configure_opts_onchange()
254
+
255
+ modules.sd_models.setup_model()
256
+ startup_timer.record("setup SD model")
257
+
258
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
259
+ startup_timer.record("setup codeformer")
260
+
261
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
262
+ startup_timer.record("setup gfpgan")
263
+
264
+ initialize_rest(reload_script_modules=False)
265
+
266
+
267
+ def initialize_rest(*, reload_script_modules=False):
268
+ """
269
+ Called both from initialize() and when reloading the webui.
270
+ """
271
+ sd_samplers.set_samplers()
272
+ extensions.list_extensions()
273
+ startup_timer.record("list extensions")
274
+
275
+ restore_config_state_file()
276
+
277
+ if cmd_opts.ui_debug_mode:
278
+ shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
279
+ modules.scripts.load_scripts()
280
+ return
281
+
282
+ modules.sd_models.list_models()
283
+ startup_timer.record("list SD models")
284
+
285
+ localization.list_localizations(cmd_opts.localizations_dir)
286
+
287
+ with startup_timer.subcategory("load scripts"):
288
+ modules.scripts.load_scripts()
289
+
290
+ if reload_script_modules:
291
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
292
+ importlib.reload(module)
293
+ startup_timer.record("reload script modules")
294
+
295
+ modelloader.load_upscalers()
296
+ startup_timer.record("load upscalers")
297
+
298
+ modules.sd_vae.refresh_vae_list()
299
+ startup_timer.record("refresh VAE")
300
+ modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
301
+ startup_timer.record("refresh textual inversion templates")
302
+
303
+ modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
304
+ modules.sd_hijack.list_optimizers()
305
+ startup_timer.record("scripts list_optimizers")
306
+
307
+ modules.sd_unet.list_unets()
308
+ startup_timer.record("scripts list_unets")
309
+
310
+ def load_model():
311
+ """
312
+ Accesses shared.sd_model property to load model.
313
+ After it's available, if it has been loaded before this access by some extension,
314
+ its optimization may be None because the list of optimizaers has neet been filled
315
+ by that time, so we apply optimization again.
316
+ """
317
+
318
+ shared.sd_model # noqa: B018
319
+
320
+ if modules.sd_hijack.current_optimizer is None:
321
+ modules.sd_hijack.apply_optimizations()
322
+
323
+ Thread(target=load_model).start()
324
+
325
+ Thread(target=devices.first_time_calculation).start()
326
+
327
+ shared.reload_hypernetworks()
328
+ startup_timer.record("reload hypernetworks")
329
+
330
+ ui_extra_networks.initialize()
331
+ ui_extra_networks.register_default_pages()
332
+
333
+ extra_networks.initialize()
334
+ extra_networks.register_default_extra_networks()
335
+ startup_timer.record("initialize extra networks")
336
+
337
+
338
+ def setup_middleware(app):
339
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
340
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
341
+ configure_cors_middleware(app)
342
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
343
+
344
+
345
+ def configure_cors_middleware(app):
346
+ cors_options = {
347
+ "allow_methods": ["*"],
348
+ "allow_headers": ["*"],
349
+ "allow_credentials": True,
350
+ }
351
+ if cmd_opts.cors_allow_origins:
352
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
353
+ if cmd_opts.cors_allow_origins_regex:
354
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
355
+ app.add_middleware(CORSMiddleware, **cors_options)
356
+
357
+
358
+ def create_api(app):
359
+ from modules.api.api import Api
360
+ api = Api(app, queue_lock)
361
+ return api
362
+
363
+
364
+ def api_only():
365
+ initialize()
366
+
367
+ app = FastAPI()
368
+ setup_middleware(app)
369
+ api = create_api(app)
370
+
371
+ modules.script_callbacks.app_started_callback(None, app)
372
+
373
+ print(f"Startup time: {startup_timer.summary()}.")
374
+ api.launch(
375
+ server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
376
+ port=cmd_opts.port if cmd_opts.port else 7861,
377
+ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
378
+ )
379
+
380
+
381
+ def webui():
382
+ launch_api = cmd_opts.api
383
+ initialize()
384
+
385
+ while 1:
386
+ if shared.opts.clean_temp_dir_at_start:
387
+ ui_tempdir.cleanup_tmpdr()
388
+ startup_timer.record("cleanup temp dir")
389
+
390
+ modules.script_callbacks.before_ui_callback()
391
+ startup_timer.record("scripts before_ui_callback")
392
+
393
+ shared.demo = modules.ui.create_ui()
394
+ startup_timer.record("create ui")
395
+
396
+ if not cmd_opts.no_gradio_queue:
397
+ shared.demo.queue(64)
398
+
399
+ gradio_auth_creds = list(get_gradio_auth_creds()) or None
400
+
401
+ app, local_url, share_url = shared.demo.launch(
402
+ share=cmd_opts.share,
403
+ server_name=server_name,
404
+ server_port=cmd_opts.port,
405
+ ssl_keyfile=cmd_opts.tls_keyfile,
406
+ ssl_certfile=cmd_opts.tls_certfile,
407
+ ssl_verify=cmd_opts.disable_tls_verify,
408
+ debug=cmd_opts.gradio_debug,
409
+ auth=gradio_auth_creds,
410
+ inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
411
+ prevent_thread_lock=True,
412
+ allowed_paths=cmd_opts.gradio_allowed_path,
413
+ app_kwargs={
414
+ "docs_url": "/docs",
415
+ "redoc_url": "/redoc",
416
+ },
417
+ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
418
+ )
419
+
420
+ # after initial launch, disable --autolaunch for subsequent restarts
421
+ cmd_opts.autolaunch = False
422
+
423
+ startup_timer.record("gradio launch")
424
+
425
+ # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
426
+ # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
427
+ # running web ui and do whatever the attacker wants, including installing an extension and
428
+ # running its code. We disable this here. Suggested by RyotaK.
429
+ app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
430
+
431
+ setup_middleware(app)
432
+
433
+ modules.progress.setup_progress_api(app)
434
+ modules.ui.setup_ui_api(app)
435
+
436
+ if launch_api:
437
+ create_api(app)
438
+
439
+ ui_extra_networks.add_pages_to_demo(app)
440
+
441
+ startup_timer.record("add APIs")
442
+
443
+ with startup_timer.subcategory("app_started_callback"):
444
+ modules.script_callbacks.app_started_callback(shared.demo, app)
445
+
446
+ timer.startup_record = startup_timer.dump()
447
+ print(f"Startup time: {startup_timer.summary()}.")
448
+
449
+ try:
450
+ while True:
451
+ server_command = shared.state.wait_for_server_command(timeout=5)
452
+ if server_command:
453
+ if server_command in ("stop", "restart"):
454
+ break
455
+ else:
456
+ print(f"Unknown server command: {server_command}")
457
+ except KeyboardInterrupt:
458
+ print('Caught KeyboardInterrupt, stopping...')
459
+ server_command = "stop"
460
+
461
+ if server_command == "stop":
462
+ print("Stopping server...")
463
+ # If we catch a keyboard interrupt, we want to stop the server and exit.
464
+ shared.demo.close()
465
+ break
466
+
467
+ print('Restarting UI...')
468
+ shared.demo.close()
469
+ time.sleep(0.5)
470
+ startup_timer.reset()
471
+ modules.script_callbacks.app_reload_callback()
472
+ startup_timer.record("app reload callback")
473
+ modules.script_callbacks.script_unloaded_callback()
474
+ startup_timer.record("scripts unloaded callback")
475
+ initialize_rest(reload_script_modules=True)
476
+
477
+
478
+ if __name__ == "__main__":
479
+ if cmd_opts.nowebui:
480
+ api_only()
481
+ else:
482
+ webui()