File size: 4,049 Bytes
cff1674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import gradio as gr
import os

from kohya_gui.utilities import utilities_tab
from kohya_gui.lora_gui import lora_tab

from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript

# Module-wide logger obtained from the project's setup_logging() helper; UI()
# uses it for status messages. NOTE(review): handler/level configuration lives
# in kohya_gui.custom_logging and is not visible from this file.
log = setup_logging()


def _load_css(path="./assets/style.css"):
    """Return the GUI stylesheet contents plus a trailing newline, or "" if absent.

    Args:
        path: Stylesheet location, resolved against the current working directory.
    """
    if not os.path.exists(path):
        return ""
    with open(path, "r", encoding="utf8") as file:
        log.info("Load CSS...")
        return file.read() + "\n"


def _build_launch_kwargs(kwargs):
    """Map UI() keyword arguments onto keyword arguments for ``interface.launch``.

    ``server_name`` is always set (from ``listen``, possibly None). ``auth``,
    ``server_port``, ``inbrowser`` and ``share`` are added only when enabled,
    so Gradio's own defaults apply otherwise.

    Args:
        kwargs: The keyword-argument dict received by UI().

    Returns:
        A dict to be splatted into ``interface.launch(**...)``.
    """
    username = kwargs.get("username")
    password = kwargs.get("password")
    # `or 0` also tolerates an explicit server_port=None, which would otherwise
    # raise TypeError in the `> 0` comparison below.
    server_port = kwargs.get("server_port") or 0
    inbrowser = kwargs.get("inbrowser", False)
    share = kwargs.get("share", False)

    launch_kwargs = {"server_name": kwargs.get("listen")}
    if username and password:
        launch_kwargs["auth"] = (username, password)
    elif username or password:
        # Only one credential supplied: authentication stays off, so say so
        # rather than silently serving an unauthenticated UI.
        log.warning(
            "Authentication disabled: both username and password are required."
        )
    if server_port > 0:
        launch_kwargs["server_port"] = server_port
    if inbrowser:
        launch_kwargs["inbrowser"] = inbrowser
    if share:
        launch_kwargs["share"] = share
    return launch_kwargs


def _redact_launch_kwargs(launch_kwargs):
    """Return a copy of launch_kwargs that is safe to log (auth password masked)."""
    redacted = dict(launch_kwargs)
    if "auth" in redacted:
        redacted["auth"] = (redacted["auth"][0], "********")
    return redacted


def UI(**kwargs):
    """Build the Kohya_ss GUI (LoRA and Utilities tabs) and serve it with Gradio.

    Recognised keyword arguments (all optional):
        language: forwarded to ``add_javascript``.
        headless: forwarded to ``lora_tab`` and ``utilities_tab`` (default False).
        username, password: ``auth`` is enabled only when both are non-empty.
        server_port: forwarded to ``launch`` only when > 0.
        inbrowser, share: forwarded to ``launch`` only when truthy.
        listen: forwarded to ``launch`` as ``server_name``.

    Blocks inside ``interface.launch`` until the server stops. A Ctrl+C that
    reaches this function is caught and reported instead of propagating.
    """
    try:
        add_javascript(kwargs.get("language"))

        headless = kwargs.get("headless", False)
        log.info(f"headless: {headless}")

        css = _load_css()
        interface = gr.Blocks(
            css=css, title="Kohya_ss GUI", theme=gr.themes.Default()
        )

        with interface:
            with gr.Tab("LoRA"):
                # The directory inputs created by lora_tab are shared with the
                # Utilities tab below.
                (
                    train_data_dir_input,
                    reg_data_dir_input,
                    output_dir_input,
                    logging_dir_input,
                ) = lora_tab(headless=headless)
            with gr.Tab("Utilities"):
                utilities_tab(
                    train_data_dir_input=train_data_dir_input,
                    reg_data_dir_input=reg_data_dir_input,
                    output_dir_input=output_dir_input,
                    logging_dir_input=logging_dir_input,
                    enable_copy_info_button=True,
                    headless=headless,
                )

        launch_kwargs = _build_launch_kwargs(kwargs)
        # Never write the password to the log in clear text.
        log.info(_redact_launch_kwargs(launch_kwargs))
        # Launch exactly once. launch() blocks until the server stops (Gradio
        # typically handles Ctrl+C itself and then returns), so looping here
        # would rebuild and relaunch the UI instead of letting the program exit.
        interface.launch(**launch_kwargs)
    except KeyboardInterrupt:
        # Reached when Ctrl+C lands outside Gradio's own handler, e.g. while
        # the interface is still being built.
        print("You pressed Ctrl+C!")


if __name__ == "__main__":
    # Command-line entry point. Every option's dest (listen, username, password,
    # server_port, inbrowser, share, headless, language) is a keyword that UI()
    # reads, so the parsed namespace is handed to UI() unchanged.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--listen",
        type=str,
        default="127.0.0.1",
        help="IP to listen on for connections to Gradio",
    )
    # Optional credentials; empty strings mean "not set".
    for credential_flag, credential_help in (
        ("--username", "Username for authentication"),
        ("--password", "Password for authentication"),
    ):
        cli.add_argument(credential_flag, type=str, default="", help=credential_help)
    cli.add_argument(
        "--server_port",
        type=int,
        default=0,
        help="Port to run the server listener on",
    )
    # Plain on/off switches.
    for switch_flag, switch_help in (
        ("--inbrowser", "Open in browser"),
        ("--share", "Share the gradio UI"),
        ("--headless", "Is the server headless"),
    ):
        cli.add_argument(switch_flag, action="store_true", help=switch_help)
    cli.add_argument(
        "--language", type=str, default=None, help="Set custom language"
    )

    UI(**vars(cli.parse_args()))