File size: 7,668 Bytes
ef9fd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os

from modules.call_queue import wrap_gradio_call
from modules.hypernetworks.ui import keys
import modules.scripts as scripts
from modules import script_callbacks, shared, sd_hijack
import gradio as gr

from modules.paths import script_path
from modules.ui import create_refresh_button, gr_show
import patches.clip_hijack as clip_hijack
import patches.textual_inversion as textual_inversion
import patches.ui as ui
import patches.shared as shared_patch
import patches.external_pr.ui as external_patch_ui
from webui import wrap_gradio_gpu_call

setattr(shared.opts,'pin_memory', False)

def create_extension_tab(params=None):
    with gr.Tab(label="Create Beta hypernetwork") as create_beta:
        new_hypernetwork_name = gr.Textbox(label="Name")
        new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1024", "1280"],
                                                  choices=["768", "320", "640", "1024", "1280"])
        new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure",
                                                      placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
        new_hypernetwork_activation_func = gr.Dropdown(value="linear",
                                                       label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)",
                                                       choices=keys)
        new_hypernetwork_initialization_option = gr.Dropdown(value="Normal",
                                                             label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise",
                                                             choices=["Normal", "KaimingUniform", "KaimingNormal",
                                                                      "XavierUniform", "XavierNormal"])
        show_additional_options = gr.Checkbox(
            label='Show advanced options')
        with gr.Row(visible=False) as weight_options:
            generation_seed = gr.Number(label='Weight initialization seed, set -1 for default', value=-1, precision=0)
            normal_std = gr.Textbox(label="Standard Deviation for Normal weight initialization", placeholder="must be positive float", value="0.01")
        show_additional_options.change(
            fn=lambda show: gr_show(show),
            inputs=[show_additional_options],
            outputs=[weight_options],)
        new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
        new_hypernetwork_use_dropout = gr.Checkbox(
            label="Use dropout. Might improve training when dataset is small / limited.")
        new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0",
                                                        label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15",
                                                        placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
        skip_connection = gr.Checkbox(label="Use skip-connection. Won't work without extension!")
        optional_info = gr.Textbox("", label="Optional information about Hypernetwork", placeholder="Training information, dateset, etc")
        overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")

        with gr.Row():
            with gr.Column(scale=3):
                gr.HTML(value="")

            with gr.Column():
                create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
        setting_name = gr.Textbox(label="Setting file name", value="")
        save_setting = gr.Button(value="Save hypernetwork setting to file")
        ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
        ti_outcome = gr.HTML(elem_id="ti_error2", value="")



        save_setting.click(
            fn=wrap_gradio_call(external_patch_ui.save_hypernetwork_setting),
            inputs=[
                setting_name,
                new_hypernetwork_sizes,
                overwrite_old_hypernetwork,
                new_hypernetwork_layer_structure,
                new_hypernetwork_activation_func,
                new_hypernetwork_initialization_option,
                new_hypernetwork_add_layer_norm,
                new_hypernetwork_use_dropout,
                new_hypernetwork_dropout_structure,
                optional_info,
                generation_seed if generation_seed.visible else None,
                normal_std if normal_std.visible else 0.01,
                skip_connection],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )
        create_hypernetwork.click(
            fn=ui.create_hypernetwork,
            inputs=[
                new_hypernetwork_name,
                new_hypernetwork_sizes,
                overwrite_old_hypernetwork,
                new_hypernetwork_layer_structure,
                new_hypernetwork_activation_func,
                new_hypernetwork_initialization_option,
                new_hypernetwork_add_layer_norm,
                new_hypernetwork_use_dropout,
                new_hypernetwork_dropout_structure,
                optional_info,
                generation_seed if generation_seed.visible else None,
                normal_std if normal_std.visible else 0.01,
                skip_connection
            ],
            outputs=[
                new_hypernetwork_name,
                ti_output,
                ti_outcome,
            ]
        )
    return [(create_beta, "Create_beta", "create_beta")]


def create_extension_tab2(params=None):
    with gr.Blocks(analytics_enabled=False) as CLIP_test_interface:
        with gr.Tab(label="CLIP-test") as clip_test:
            with gr.Row():
                clipTextModelPath = gr.Textbox("openai/clip-vit-large-patch14", label="CLIP Text models. Set to empty to not change.")
                # see https://huggingface.co/openai/clip-vit-large-patch14 and related pages to find model.
                change_model = gr.Checkbox(label="Enable clip model change. This will be triggered from next model changes.")
                change_model.change(
                    fn=clip_hijack.trigger_sd_hijack,
                    inputs=[
                        change_model,
                        clipTextModelPath
                    ],
                    outputs=[]
                )
    return [(CLIP_test_interface, "CLIP_test", "clip_test")]

def on_ui_settings():
    shared.opts.add_option("disable_ema",
        shared.OptionInfo(False, "Detach grad from conditioning models",
        section=('training', "Training")))
    if not hasattr(shared.opts, 'training_enable_tensorboard'):
        shared.opts.add_option("training_enable_tensorboard",
                               shared.OptionInfo(False, "Enable tensorboard logging",
                                                 section=('training', "Training")))

#script_callbacks.on_ui_train_tabs(create_training_tab)   # Deprecate Beta Training
script_callbacks.on_ui_train_tabs(create_extension_tab)
script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_gamma_tab)
script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_tuning)
script_callbacks.on_ui_tabs(create_extension_tab2)
script_callbacks.on_ui_settings(on_ui_settings)
class Script(scripts.Script):
    def title(self):
        return "Hypernetwork Monkey Patch"

    def show(self, _):
        return scripts.AlwaysVisible