| import os |
| import warnings |
|
|
| from modules.logging_colors import logger |
| from modules.block_requests import OpenMonkeyPatch, RequestBlocker |
|
|
| os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' |
| os.environ['BITSANDBYTES_NOWELCOME'] = '1' |
| warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') |
|
|
| with RequestBlocker(): |
| import gradio as gr |
|
|
| import matplotlib |
| matplotlib.use('Agg') |
|
|
| import json |
| import os |
| import sys |
| import time |
| from functools import partial |
| from pathlib import Path |
| from threading import Lock |
|
|
| import yaml |
|
|
| import modules.extensions as extensions_module |
| from modules import ( |
| chat, |
| shared, |
| training, |
| ui, |
| ui_chat, |
| ui_default, |
| ui_file_saving, |
| ui_model_menu, |
| ui_notebook, |
| ui_parameters, |
| ui_session, |
| utils, |
| ) |
| from modules.extensions import apply_extensions |
| from modules.LoRA import add_lora_to_model |
| from modules.models import load_model |
| from modules.models_settings import ( |
| get_model_settings_from_yamls, |
| update_model_parameters |
| ) |
| from modules.utils import gradio |
|
|
|
|
| def create_interface(): |
|
|
| title = 'Text generation web UI' |
|
|
| |
| auth = [] |
| if shared.args.gradio_auth: |
| auth.extend(x.strip() for x in shared.args.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()) |
| if shared.args.gradio_auth_path: |
| with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: |
| auth.extend(x.strip() for line in file for x in line.split(',') if x.strip()) |
| auth = [tuple(cred.split(':')) for cred in auth] |
|
|
| |
| if shared.args.extensions is not None and len(shared.args.extensions) > 0: |
| extensions_module.load_extensions() |
|
|
| |
| shared.persistent_interface_state.update({ |
| 'loader': shared.args.loader or 'Transformers', |
| 'mode': shared.settings['mode'], |
| 'character_menu': shared.args.character or shared.settings['character'], |
| 'instruction_template': shared.settings['instruction_template'], |
| 'prompt_menu-default': shared.settings['prompt-default'], |
| 'prompt_menu-notebook': shared.settings['prompt-notebook'], |
| }) |
|
|
| if Path("cache/pfp_character.png").exists(): |
| Path("cache/pfp_character.png").unlink() |
|
|
| |
| css = ui.css |
| js = ui.js |
| css += apply_extensions('css') |
| js += apply_extensions('js') |
|
|
| |
| shared.input_elements = ui.list_interface_input_elements() |
|
|
| with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']: |
|
|
| |
| shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) |
|
|
| |
| if Path("notification.mp3").exists(): |
| shared.gradio['audio_notification'] = gr.Audio(interactive=False, value="notification.mp3", elem_id="audio_notification", visible=False) |
|
|
| |
| ui_file_saving.create_ui() |
|
|
| |
| shared.gradio['temporary_text'] = gr.Textbox(visible=False) |
|
|
| |
| ui_chat.create_ui() |
| ui_default.create_ui() |
| ui_notebook.create_ui() |
|
|
| ui_parameters.create_ui(shared.settings['preset']) |
| ui_model_menu.create_ui() |
| training.create_ui() |
| ui_session.create_ui() |
|
|
| |
| ui_chat.create_event_handlers() |
| ui_default.create_event_handlers() |
| ui_notebook.create_event_handlers() |
|
|
| |
| ui_file_saving.create_event_handlers() |
| ui_parameters.create_event_handlers() |
| ui_model_menu.create_event_handlers() |
|
|
| |
| if shared.settings['dark_theme']: |
| shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')") |
|
|
| shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}") |
| shared.gradio['interface'].load(None, gradio('show_controls'), None, _js=f'(x) => {{{ui.show_controls_js}; toggle_controls(x)}}') |
| shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False) |
| shared.gradio['interface'].load(chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display')) |
|
|
| extensions_module.create_extensions_tabs() |
| extensions_module.create_extensions_block() |
|
|
| |
| shared.gradio['interface'].queue(concurrency_count=64) |
| with OpenMonkeyPatch(): |
| shared.gradio['interface'].launch( |
| prevent_thread_lock=True, |
| share=shared.args.share, |
| server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'), |
| server_port=shared.args.listen_port, |
| inbrowser=shared.args.auto_launch, |
| auth=auth or None, |
| ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True, |
| ssl_keyfile=shared.args.ssl_keyfile, |
| ssl_certfile=shared.args.ssl_certfile |
| ) |
|
|
|
|
| if __name__ == "__main__": |
|
|
| |
| settings_file = None |
| if shared.args.settings is not None and Path(shared.args.settings).exists(): |
| settings_file = Path(shared.args.settings) |
| elif Path('settings.yaml').exists(): |
| settings_file = Path('settings.yaml') |
| elif Path('settings.json').exists(): |
| settings_file = Path('settings.json') |
|
|
| if settings_file is not None: |
| logger.info(f"Loading settings from {settings_file}...") |
| file_contents = open(settings_file, 'r', encoding='utf-8').read() |
| new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents) |
| shared.settings.update(new_settings) |
|
|
| |
| shared.model_config['.*'] = { |
| 'wbits': 'None', |
| 'model_type': 'None', |
| 'groupsize': 'None', |
| 'pre_layer': 0, |
| 'mode': shared.settings['mode'], |
| 'skip_special_tokens': shared.settings['skip_special_tokens'], |
| 'custom_stopping_strings': shared.settings['custom_stopping_strings'], |
| 'truncation_length': shared.settings['truncation_length'], |
| 'n_gqa': 0, |
| 'rms_norm_eps': 0, |
| 'rope_freq_base': 0, |
| } |
|
|
| shared.model_config.move_to_end('.*', last=False) |
|
|
| |
| extensions_module.available_extensions = utils.get_available_extensions() |
| for extension in shared.settings['default_extensions']: |
| shared.args.extensions = shared.args.extensions or [] |
| if extension not in shared.args.extensions: |
| shared.args.extensions.append(extension) |
|
|
| available_models = utils.get_available_models() |
|
|
| |
| if shared.args.model is not None: |
| shared.model_name = shared.args.model |
|
|
| |
| elif shared.args.model_menu: |
| if len(available_models) == 0: |
| logger.error('No models are available! Please download at least one.') |
| sys.exit(0) |
| else: |
| print('The following models are available:\n') |
| for i, model in enumerate(available_models): |
| print(f'{i+1}. {model}') |
|
|
| print(f'\nWhich one do you want to load? 1-{len(available_models)}\n') |
| i = int(input()) - 1 |
| print() |
|
|
| shared.model_name = available_models[i] |
|
|
| |
| if shared.model_name != 'None': |
| model_settings = get_model_settings_from_yamls(shared.model_name) |
| shared.settings.update(model_settings) |
| update_model_parameters(model_settings, initial=True) |
|
|
| |
| shared.model, shared.tokenizer = load_model(shared.model_name) |
| if shared.args.lora: |
| add_lora_to_model(shared.args.lora) |
|
|
| shared.generation_lock = Lock() |
|
|
| |
| create_interface() |
| while True: |
| time.sleep(0.5) |
| if shared.need_restart: |
| shared.need_restart = False |
| time.sleep(0.5) |
| shared.gradio['interface'].close() |
| time.sleep(0.5) |
| create_interface() |
|
|