import gradio as gr
import os
import argparse
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from lora_gui import lora_tab
from library.class_lora_tab import LoRATools

from library.custom_logging import setup_logging

# Set up logging
log = setup_logging()


def _read_optional_text(path):
    """Return the UTF-8 text of ``path``, or ``None`` when the file does not exist."""
    if not os.path.exists(path):
        return None
    with open(path, 'r', encoding='utf8') as file:
        return file.read()


def _build_launch_kwargs(kwargs):
    """Translate the keyword arguments given to ``UI()`` into ``launch()`` kwargs.

    Only options that were actually requested are forwarded; ``server_name``
    is always set (possibly to ``None``), matching the previous behaviour.
    """
    launch_kwargs = {'server_name': kwargs.get('listen')}

    username = kwargs.get('username')
    password = kwargs.get('password')
    # Authentication is enabled only when both credentials are non-empty.
    if username and password:
        launch_kwargs['auth'] = (username, password)

    # `or 0` guards against an explicit server_port=None, which would
    # otherwise raise TypeError on the comparison below.
    server_port = kwargs.get('server_port') or 0
    if server_port > 0:
        launch_kwargs['server_port'] = server_port

    inbrowser = kwargs.get('inbrowser', False)
    if inbrowser:
        launch_kwargs['inbrowser'] = inbrowser

    share = kwargs.get('share', False)
    if share:
        launch_kwargs['share'] = share

    return launch_kwargs


def UI(**kwargs):
    """Build the kohya_ss Gradio interface and launch it.

    Recognised keyword arguments:
        headless (bool): forwarded to every tab builder.
        username / password (str): enable basic auth when both are non-empty.
        server_port (int): port to bind; values <= 0 (or None) leave it to Gradio.
        inbrowser (bool): open the UI in a browser after launch.
        share (bool): request a public Gradio share link.
        listen (str): host/IP passed to Gradio as ``server_name``.
    """
    css = ''

    headless = kwargs.get('headless', False)
    log.info(f'headless: {headless}')

    style = _read_optional_text('./style.css')
    if style is not None:
        log.info('Load CSS...')
        css += style + '\n'

    # Default to empty text so a checkout without these optional files still
    # starts, instead of failing with NameError on an unbound name.
    release = _read_optional_text('./.release') or ''
    readme = _read_optional_text('./README.md') or ''

    interface = gr.Blocks(
        css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
    )

    with interface:
        with gr.Tab('Dreambooth'):
            # The Dreambooth tab exposes its folder inputs so the Utilities
            # tab can copy values into them.
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(headless=headless)
        with gr.Tab('LoRA'):
            lora_tab(headless=headless)
        with gr.Tab('Textual Inversion'):
            ti_tab(headless=headless)
        with gr.Tab('Finetuning'):
            finetune_tab(headless=headless)
        with gr.Tab('Utilities'):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
                headless=headless,
            )
            # LoRA tooling lives under Utilities so it does not produce a
            # second top-level tab with the same 'LoRA' label as training.
            with gr.Tab('LoRA'):
                _ = LoRATools(headless=headless)
        with gr.Tab('About'):
            gr.Markdown(f'kohya_ss GUI release {release}')
        with gr.Tab('README'):
            gr.Markdown(readme)

        html_str = f'\n{release}\n'
        gr.HTML(html_str)

    # Show the interface
    interface.launch(**_build_launch_kwargs(kwargs))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )
    parser.add_argument(
        '--headless', action='store_true', help='Is the server headless'
    )

    args = parser.parse_args()

    UI(
        username=args.username,
        password=args.password,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        share=args.share,
        listen=args.listen,
        headless=args.headless,
    )