nyanko7 committed
Commit 28c56c2
1 Parent(s): fd46938

chore: add model

Files changed (1)
  1. app.py +28 -35
app.py CHANGED
@@ -29,10 +29,9 @@ models = [
     ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF", 2),
     ("Pastal Mix", "andite/pastel-mix", 2),
     ("Basil Mix", "nuigurumi/basil_mix", 2),
+    ("OpenJourney V2", "prompthero/openjourney-v2", 1),
 ]

-alt_models = []
-
 base_name, base_model, clip_skip = models[0]

 samplers_k_diffusion = [
@@ -102,57 +101,55 @@ def get_model_list():
     return models + alt_models

 te_cache = {
-    base_name: text_encoder
+    base_model: text_encoder
 }

 unet_cache = {
-    base_name: unet
+    base_model: unet
 }

 lora_cache = {
-    base_name: LoRANetwork(text_encoder, unet)
+    base_model: LoRANetwork(text_encoder, unet)
 }

 te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
 original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
+current_model = base_model

 def setup_model(name, lora_state=None, lora_scale=1.0):
-    global pipe
+    global pipe, current_model

     keys = [k[0] for k in models]
-    if name not in unet_cache:
-        if name not in keys:
-            raise ValueError(name)
-        else:
-
-            text_encoder = CLIPTextModel.from_pretrained(
-                models[keys.index(name)][1],
-                subfolder="text_encoder",
-                torch_dtype=torch.float16,
-            )
-            unet = UNet2DConditionModel.from_pretrained(
-                models[keys.index(name)][1],
-                subfolder="unet",
-                torch_dtype=torch.float16,
-            )
-
-            if torch.cuda.is_available():
-                unet.to("cuda")
-                text_encoder.to("cuda")
-
-            unet_cache[name] = unet
-            te_cache[name] = text_encoder
-            lora_cache[name] = LoRANetwork(text_encoder, unet)
-
-    local_te, local_unet, local_lora, = te_cache[name], unet_cache[name], lora_cache[name]
+    model = models[keys.index(name)][1]
+    if model not in unet_cache:
+        unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
+        text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16)
+
+        unet_cache[model] = unet
+        te_cache[model] = text_encoder
+        lora_cache[model] = LoRANetwork(text_encoder, unet)
+
+    if current_model != model:
+        # offload current model
+        unet_cache[current_model].to("cpu")
+        te_cache[current_model].to("cpu")
+        lora_cache[current_model].to("cpu")
+        current_model = model
+
+    local_te, local_unet, local_lora, = te_cache[model], unet_cache[model], lora_cache[model]
     local_unet.set_attn_processor(CrossAttnProcessor())
     local_lora.reset()
     clip_skip = models[keys.index(name)][2]

+    if torch.cuda.is_available():
+        local_unet.to("cuda")
+        local_te.to("cuda")
+
     if lora_state is not None and lora_state != "":
         local_lora.load(lora_state, lora_scale)
         local_lora.to(local_unet.device, dtype=local_unet.dtype)

+    pipe.text_encoder, pipe.unet = local_te, local_unet
     pipe.setup_unet(local_unet)
     pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
     pipe.tokenizer.added_tokens_encoder = {}
@@ -160,10 +157,6 @@ def setup_model(name, lora_state=None, lora_scale=1.0):
     pipe.setup_text_encoder(clip_skip, local_te)
     return pipe

-# precache on huggingface
-for model in models:
-    setup_model(model[0])
-
 def error_str(error, title="Error"):
     return (
         f"""#### {title}