supertori commited on
Commit
fa26b56
1 Parent(s): 77409b9

Upload composable_lora_script.py

Browse files
Files changed (1) hide show
  1. composable_lora_script.py +57 -0
composable_lora_script.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Composable-Diffusion with Lora
3
+ #
4
+ import torch
5
+ import gradio as gr
6
+
7
+ import composable_lora
8
+ import modules.scripts as scripts
9
+ from modules import script_callbacks
10
+ from modules.processing import StableDiffusionProcessing
11
+
12
+
13
+ def unload():
14
+ torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
15
+ torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
16
+
17
+
18
+ if not hasattr(torch.nn, 'Linear_forward_before_lora'):
19
+ torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
20
+
21
+ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
22
+ torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
23
+
24
+ torch.nn.Linear.forward = composable_lora.lora_Linear_forward
25
+ torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward
26
+
27
+ script_callbacks.on_script_unloaded(unload)
28
+
29
+
30
+ class ComposableLoraScript(scripts.Script):
31
+ def title(self):
32
+ return "Composable Lora"
33
+
34
+ def show(self, is_img2img):
35
+ return scripts.AlwaysVisible
36
+
37
+ def ui(self, is_img2img):
38
+ with gr.Group():
39
+ with gr.Accordion("Composable Lora", open=False):
40
+ enabled = gr.Checkbox(value=False, label="Enabled")
41
+ opt_uc_text_model_encoder = gr.Checkbox(value=False, label="Use Lora in uc text model encoder")
42
+ opt_uc_diffusion_model = gr.Checkbox(value=False, label="Use Lora in uc diffusion model")
43
+
44
+ return [enabled, opt_uc_text_model_encoder, opt_uc_diffusion_model]
45
+
46
+ def process(self, p: StableDiffusionProcessing, enabled: bool, opt_uc_text_model_encoder: bool, opt_uc_diffusion_model: bool):
47
+ composable_lora.enabled = enabled
48
+ composable_lora.opt_uc_text_model_encoder = opt_uc_text_model_encoder
49
+ composable_lora.opt_uc_diffusion_model = opt_uc_diffusion_model
50
+
51
+ composable_lora.num_batches = p.batch_size
52
+
53
+ prompt = p.all_prompts[0]
54
+ composable_lora.load_prompt_loras(prompt)
55
+
56
+ def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
57
+ composable_lora.reset_counters()