|
import gradio as gr |
|
import os |
|
import argparse |
|
from dreambooth_gui import dreambooth_tab |
|
from finetune_gui import finetune_tab |
|
from textual_inversion_gui import ti_tab |
|
from library.utilities import utilities_tab |
|
from lora_gui import lora_tab |
|
from library.class_lora_tab import LoRATools |
|
|
|
import os |
|
from library.custom_logging import setup_logging |
|
|
|
|
|
# Module-level logger obtained from the project's custom logging helper;
# UI() below reports its progress through it (e.g. 'headless: ...').
log = setup_logging()
|
|
|
|
|
def _read_text_file(path, default=None):
    """Return the UTF-8 text of *path*, or *default* when it is not a regular file.

    Used for optional assets (stylesheet, release marker, README) that may be
    missing depending on how the GUI was checked out or installed.
    """
    if not os.path.isfile(path):
        return default
    with open(path, 'r', encoding='utf8') as file:
        return file.read()


def _build_launch_kwargs(kwargs):
    """Translate UI() keyword options into arguments for ``Blocks.launch``.

    Only options that are actually set are forwarded, so Gradio's own defaults
    apply otherwise. ``server_name`` is always forwarded (possibly as None).
    """
    launch_kwargs = {'server_name': kwargs.get('listen')}

    username = kwargs.get('username')
    password = kwargs.get('password')
    # Authentication is enabled only when both credentials are non-empty.
    if username and password:
        launch_kwargs['auth'] = (username, password)

    # `or 0` tolerates an explicit None; a non-positive port means
    # "do not pass server_port" so Gradio chooses one.
    server_port = kwargs.get('server_port', 0) or 0
    if server_port > 0:
        launch_kwargs['server_port'] = server_port

    inbrowser = kwargs.get('inbrowser', False)
    if inbrowser:
        launch_kwargs['inbrowser'] = inbrowser

    share = kwargs.get('share', False)
    if share:
        launch_kwargs['share'] = share

    return launch_kwargs


def UI(**kwargs):
    """Build the kohya_ss Gradio interface and launch it.

    Keyword Args:
        headless: Forwarded to every tab builder. Defaults to False.
        username, password: Enable Gradio auth when both are non-empty.
        server_port: Port to listen on; 0 or None leaves the choice to Gradio.
        inbrowser: Forwarded to ``launch`` when truthy.
        share: Forwarded to ``launch`` when truthy.
        listen: Forwarded to ``launch`` as ``server_name``.

    Optional files read from the current working directory:
    ``./style.css`` (extra CSS), ``./.release`` (version string shown in the
    title and About tab) and ``./README.md`` (README tab). A missing file
    yields an empty value instead of crashing.
    """
    headless = kwargs.get('headless', False)
    log.info(f'headless: {headless}')

    css = ''
    style = _read_text_file('./style.css')
    if style is not None:
        log.info('Load CSS...')
        css += style + '\n'

    # Both were previously bound only when their file existed, which made a
    # missing file fail later with UnboundLocalError. The release string is
    # stripped so a trailing newline does not leak into the window title.
    release = _read_text_file('./.release', '').strip()
    readme = _read_text_file('./README.md', '')

    interface = gr.Blocks(
        css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
    )

    with interface:
        with gr.Tab('Dreambooth'):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(headless=headless)
        with gr.Tab('LoRA'):
            lora_tab(headless=headless)
        with gr.Tab('Textual Inversion'):
            ti_tab(headless=headless)
        with gr.Tab('Finetuning'):
            finetune_tab(headless=headless)
        with gr.Tab('Utilities'):
            # The folder components returned by the Dreambooth tab are handed
            # to the utilities tab.
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
                headless=headless,
            )
            # NOTE(review): the original indentation was lost in this copy.
            # The LoRA tools are nested under Utilities so the UI does not
            # show two sibling top-level tabs both labelled 'LoRA'.
            with gr.Tab('LoRA'):
                _ = LoRATools(headless=headless)
        with gr.Tab('About'):
            gr.Markdown(f'kohya_ss GUI release {release}')
        with gr.Tab('README'):
            gr.Markdown(readme)

        html_str = f"""
        <html>
            <body>
                <div class="ver-class">{release}</div>
            </body>
        </html>
        """
        gr.HTML(html_str)

    interface.launch(**_build_launch_kwargs(kwargs))
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: each option's dest name matches a UI() keyword,
    # so the parsed namespace can be splatted straight into UI().
    cli = argparse.ArgumentParser()

    cli.add_argument(
        '--listen',
        type=str,
        default='127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    cli.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    cli.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    cli.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )

    # Plain on/off switches, registered in the same order as before so the
    # --help listing is unchanged.
    switches = (
        ('--inbrowser', 'Open in browser'),
        ('--share', 'Share the gradio UI'),
        ('--headless', 'Is the server headless'),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)

    options = cli.parse_args()
    UI(**vars(options))
|
|