nyanko7 committed
Commit ebb4814
Parent: 49ba457

chore: sync with upstream

Files changed (3):
  1. app.py +38 -24
  2. modules/lora.py +6 -4
  3. modules/model.py +7 -5
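
In short: app.py replaces the per-request `get_model` helper with `setup_model`, which caches text encoders alongside UNets (new `te_cache`) and handles LoRA loading and clip-skip configuration in one place; modules/lora.py threads the LoRA multiplier through `LoRAModule.resize` and only prints unexpected keys when there are any; modules/model.py replaces `set_clip_skip` with a `setup_text_encoder` method that can also swap in a new encoder.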
app.py CHANGED
@@ -94,12 +94,17 @@ pipe = StableDiffusionPipeline(
 )
 
 unet.set_attn_processor(CrossAttnProcessor)
+pipe.setup_text_encoder(clip_skip, text_encoder)
 if torch.cuda.is_available():
     pipe = pipe.to("cuda")
 
 def get_model_list():
     return models + alt_models
 
+te_cache = {
+    base_name: text_encoder
+}
+
 unet_cache = {
     base_name: unet
 }
@@ -108,35 +113,50 @@ lora_cache = {
     base_name: LoRANetwork(text_encoder, unet)
 }
 
-def get_model(name):
-    local_models = models + alt_models
-    keys = [k[0] for k in local_models]
+def setup_model(name, lora_state=None, lora_scale=1.0):
+    global pipe
+
+    keys = [k[0] for k in models]
     if name not in unet_cache:
         if name not in keys:
             raise ValueError(name)
         else:
+
+            text_encoder = CLIPTextModel.from_pretrained(
+                models[keys.index(name)][1],
+                subfolder="text_encoder",
+                torch_dtype=torch.float16,
+            )
             unet = UNet2DConditionModel.from_pretrained(
-                local_models[keys.index(name)][1],
+                models[keys.index(name)][1],
                 subfolder="unet",
                 torch_dtype=torch.float16,
             )
-            unet.to("cuda")
+
+            if torch.cuda.is_available():
+                unet.to("cuda")
+                text_encoder.to("cuda")
+
             unet_cache[name] = unet
-            lora_cache[name] = LoRANetwork(lora_cache[base_name].text_encoder_loras, unet)
-
-    g_unet = unet_cache[name]
-    g_lora = lora_cache[name]
-    g_unet.set_attn_processor(CrossAttnProcessor())
-    g_lora.reset()
-    clip_skip = local_models[keys.index(name)][2]
-    if torch.cuda.is_available():
-        g_unet.to("cuda")
-        g_lora.to("cuda")
-    return g_unet, g_lora, clip_skip
+            te_cache[name] = text_encoder
+            lora_cache[name] = LoRANetwork(text_encoder, unet)
+
+    local_te, local_unet, local_lora = te_cache[name], unet_cache[name], lora_cache[name]
+    local_unet.set_attn_processor(CrossAttnProcessor())
+    local_lora.reset()
+    clip_skip = models[keys.index(name)][2]
+
+    if lora_state is not None and lora_state != "":
+        local_lora.load(lora_state, lora_scale)
+        local_lora.to(local_unet.device, dtype=local_unet.dtype)
+
+    pipe.setup_unet(local_unet)
+    pipe.setup_text_encoder(clip_skip, local_te)
+    return pipe
 
 # precache on huggingface
 for model in models:
-    get_model(model[0])
+    setup_model(model[0])
 
 def error_str(error, title="Error"):
     return (
@@ -218,13 +238,7 @@ def inference(
     restore_all()
     generator = torch.Generator("cuda").manual_seed(int(seed))
 
-    local_unet, local_lora, clip_skip = get_model(model)
-    pipe.set_clip_skip(clip_skip)
-    if lora_state is not None and lora_state != "":
-        local_lora.load(lora_state, lora_scale)
-        local_lora.to(local_unet.device, dtype=local_unet.dtype)
-
-    pipe.setup_unet(local_unet)
+    setup_model(model, lora_state, lora_scale)
     sampler_name, sampler_opt = None, None
     for label, funcname, options in samplers_k_diffusion:
         if label == sampler:
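
For orientation, a minimal sketch of how the new entry point is driven per request (the model name and LoRA path below are hypothetical placeholders; `setup_model` is the function added above):

    # Hypothetical call, mirroring what inference() now does: the name must
    # match an entry in `models`, and the LoRA path/scale are optional.
    pipe = setup_model("Some Model", lora_state="./lora/style.safetensors", lora_scale=0.8)
    # On a cache miss this downloads the UNet and text encoder once; afterwards
    # the shared global pipeline is simply re-pointed at the cached weights.
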
modules/lora.py CHANGED
@@ -55,8 +55,9 @@ class LoRAModule(torch.nn.Module):
         self.org_module = org_module  # remove in applying
         self.enable = False
 
-    def resize(self, rank, alpha):
+    def resize(self, rank, alpha, multiplier):
         self.alpha = torch.tensor(alpha)
+        self.multiplier = multiplier
         self.scale = alpha / rank
         if self.lora_down.__class__.__name__ == "Conv2d":
             in_dim = self.lora_down.in_channels
@@ -172,10 +173,11 @@ class LoRANetwork(torch.nn.Module):
             weights_to_modify += self.unet_loras
 
         for lora in self.text_encoder_loras + self.unet_loras:
-            lora.resize(network_dim, network_alpha)
+            lora.resize(network_dim, network_alpha, scale)
             if lora in weights_to_modify:
                 lora.enable = True
 
         info = self.load_state_dict(weights, False)
-        print(f"Weights are loaded. Unexpect keys={info.unexpected_keys}")
-
+        if len(info.unexpected_keys) > 0:
+            print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
+
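
The `multiplier` now stored by `resize` is the user-facing LoRA strength; together with the rank normalization `alpha / rank` it scales the low-rank delta added to each frozen layer. A minimal sketch of the forward pass this implies, assuming Linear layers only (the real `LoRAModule` also handles Conv2d and an `enable` flag):

    import torch

    class LoRASketch(torch.nn.Module):
        # Illustrative only; names mirror modules/lora.py but the body is simplified.
        def __init__(self, org_module, rank, alpha, multiplier):
            super().__init__()
            self.lora_down = torch.nn.Linear(org_module.in_features, rank, bias=False)
            self.lora_up = torch.nn.Linear(rank, org_module.out_features, bias=False)
            self.org_module = org_module  # the frozen original layer
            self.multiplier = multiplier  # user-facing LoRA strength
            self.scale = alpha / rank     # rank normalization, as in resize()

        def forward(self, x):
            # frozen output plus the scaled low-rank correction
            return self.org_module(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
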
modules/model.py CHANGED
@@ -185,11 +185,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.setup_unet(self.unet)
-        self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(
-            self.tokenizer, self.text_encoder
-        )
-
-    def set_clip_skip(self, n):
+        self.setup_text_encoder()
+
+    def setup_text_encoder(self, n=1, new_encoder=None):
+        if new_encoder is not None:
+            self.text_encoder = new_encoder
+
+        self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
         self.prompt_parser.CLIP_stop_at_last_layers = n
 
     def setup_unet(self, unet):
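
`CLIP_stop_at_last_layers` is the usual "clip skip" knob: n=1 uses the text encoder's final hidden layer, n=2 the penultimate one, and so on. A sketch of the core idea, assuming a Hugging Face CLIPTextModel (the actual logic lives inside `FrozenCLIPEmbedderWithCustomWords`):

    # Illustrative only: how a clip-skip value selects a hidden layer.
    def encode_with_clip_skip(text_encoder, input_ids, clip_skip=1):
        out = text_encoder(input_ids, output_hidden_states=True)
        if clip_skip <= 1:
            return out.last_hidden_state
        # hidden_states[-clip_skip] is the output of the clip_skip-th layer from
        # the end; re-apply the final layer norm, as diffusers-style
        # implementations commonly do.
        return text_encoder.text_model.final_layer_norm(out.hidden_states[-clip_skip])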