Spaces:
Build error
Build error
File size: 3,358 Bytes
f889344 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import argparse
import gradio as gr
import os
from kohya_gui.dreambooth_gui import dreambooth_tab
from kohya_gui.utilities import utilities_tab
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript
# Module-level logger returned by kohya_gui.custom_logging.setup_logging();
# UI() writes its startup messages ("headless: ...", "Load CSS...") through it.
log = setup_logging()
def _load_css(path="./assets/style.css"):
    """Return the stylesheet at *path* followed by a newline, or "" if absent."""
    if not os.path.exists(path):
        return ""
    with open(path, "r", encoding="utf8") as file:
        log.info("Load CSS...")
        return file.read() + "\n"


def _build_launch_kwargs(options):
    """Translate the UI options mapping into keyword arguments for ``launch()``.

    ``server_name`` is always forwarded (taken from the ``listen`` option);
    every other setting is only included when it carries a meaningful value.
    """
    launch_kwargs = {"server_name": options.get("listen")}

    username = options.get("username")
    password = options.get("password")
    # Authentication is only enabled when both credentials are non-empty.
    if username and password:
        launch_kwargs["auth"] = (username, password)

    # ``or 0`` also covers an explicit ``server_port=None``, which would
    # otherwise make the comparison below raise TypeError.
    server_port = options.get("server_port", 0) or 0
    if server_port > 0:
        launch_kwargs["server_port"] = server_port

    inbrowser = options.get("inbrowser", False)
    if inbrowser:
        launch_kwargs["inbrowser"] = inbrowser

    share = options.get("share", False)
    if share:
        launch_kwargs["share"] = share

    return launch_kwargs


def UI(**kwargs):
    """Build the Kohya_ss Gradio interface and launch it.

    Recognised keyword arguments (all optional):
        language: passed to ``add_javascript``.
        headless: forwarded to the Dreambooth and Utilities tabs
            (default ``False``).
        listen: value used as the ``server_name`` launch argument.
        username, password: when both are non-empty, passed to ``launch``
            as the ``auth`` tuple.
        server_port: forwarded to ``launch`` only when greater than zero;
            ``None`` or a missing value is treated as 0.
        inbrowser, share: forwarded to ``launch`` only when truthy.
    """
    add_javascript(kwargs.get("language"))

    headless = kwargs.get("headless", False)
    log.info(f"headless: {headless}")

    css = _load_css()

    interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default())

    with interface:
        with gr.Tab("Dreambooth"):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(headless=headless)
        with gr.Tab("Utilities"):
            # The utilities tab receives the directory inputs created by the
            # Dreambooth tab above.
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
                headless=headless,
            )

    # Show the interface
    interface.launch(**_build_launch_kwargs(kwargs))
if __name__ == "__main__":
    # torch.cuda.set_per_process_memory_fraction(0.48)
    cli = argparse.ArgumentParser()

    # Network binding and credentials.
    cli.add_argument(
        "--listen",
        default="127.0.0.1",
        type=str,
        help="IP to listen on for connections to Gradio",
    )
    cli.add_argument(
        "--username", default="", type=str, help="Username for authentication"
    )
    cli.add_argument(
        "--password", default="", type=str, help="Password for authentication"
    )
    cli.add_argument(
        "--server_port",
        default=0,
        type=int,
        help="Port to run the server listener on",
    )

    # Plain on/off switches, registered in the same order as before.
    for switch, description in (
        ("--inbrowser", "Open in browser"),
        ("--share", "Share the gradio UI"),
        ("--headless", "Is the server headless"),
    ):
        cli.add_argument(switch, action="store_true", help=description)

    cli.add_argument(
        "--language", default=None, type=str, help="Set custom language"
    )

    options = cli.parse_args()

    # The parsed namespace holds exactly the keywords UI() reads (username,
    # password, inbrowser, server_port, share, listen, headless, language),
    # so it can be forwarded as-is.
    UI(**vars(options))
|