nyanko7 commited on
Commit
a7a1fa6
1 Parent(s): a18a27e

feat: add more models and sideload

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -25,13 +25,22 @@ import modules.safe as _
25
  from modules.lora import LoRANetwork
26
 
27
  models = [
28
- ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF"),
29
- ("Basil Mix", "nuigurumi/basil_mix"),
30
- ("Pastal Mix", "andite/pastel-mix"),
31
  ]
32
 
33
- base_name, base_model, = models[0]
34
- clip_skip = 2
 
 
 
 
 
 
 
 
 
 
35
 
36
  samplers_k_diffusion = [
37
  ("Euler a", "sample_euler_ancestral", {}),
@@ -91,11 +100,9 @@ pipe = StableDiffusionPipeline(
91
  )
92
 
93
  unet.set_attn_processor(CrossAttnProcessor)
94
- pipe.set_clip_skip(clip_skip)
95
  if torch.cuda.is_available():
96
  pipe = pipe.to("cuda")
97
 
98
-
99
  def get_model_list():
100
  return models
101
 
@@ -128,10 +135,10 @@ def get_model(name):
128
  g_lora = lora_cache[name]
129
  g_unet.set_attn_processor(CrossAttnProcessor())
130
  g_lora.reset()
131
- return g_unet, g_lora
132
 
133
  # precache on huggingface
134
- for model in get_model_list():
135
  get_model(model[0])
136
 
137
  def error_str(error, title="Error"):
@@ -214,7 +221,8 @@ def inference(
214
  restore_all()
215
  generator = torch.Generator("cuda").manual_seed(int(seed))
216
 
217
- local_unet, local_lora = get_model(model)
 
218
  if lora_state is not None and lora_state != "":
219
  local_lora.load(lora_state, lora_scale)
220
  local_lora.to(local_unet.device, dtype=local_unet.dtype)
@@ -279,6 +287,9 @@ def inference(
279
  if embs is not None and len(embs) > 0:
280
  restore_all()
281
 
 
 
 
282
  end_time = time.time()
283
  vram_free, vram_total = torch.cuda.mem_get_info()
284
  print(f"done: res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
 
25
  from modules.lora import LoRANetwork
26
 
27
  models = [
28
+ ("AbyssOrangeMix2 (fast)", "Korakoe/AbyssOrangeMix2-HF", 2),
29
+ ("Pastal Mix (fast)", "andite/pastel-mix", 2),
 
30
  ]
31
 
32
+ alt_models = [
33
+ ("ACertainModel", "JosephusCheung/ACertainModel", 2),
34
+ ("Basil Mix", "nuigurumi/basil_mix", 2),
35
+ ("Stable Diffusion V1.5", "runwayml/stable-diffusion-v1-5", 1),
36
+ ("Anything V3.0", "Linaqruf/anything-v3.0", 2),
37
+ ("Open Journey", "prompthero/openjourney", 1),
38
+ ("Eimis AnimeDiffusion", "eimiss/EimisAnimeDiffusion_1.0v", 2)
39
+ ("Dreamlike Photoreal 2.0", "dreamlike-art/dreamlike-photoreal-2.0", 1)
40
+ ("Redshift Diffusion", "nitrosocke/redshift-diffusion", 1)
41
+ ]
42
+
43
+ base_name, base_model, clip_skip = models[0]
44
 
45
  samplers_k_diffusion = [
46
  ("Euler a", "sample_euler_ancestral", {}),
 
100
  )
101
 
102
  unet.set_attn_processor(CrossAttnProcessor)
 
103
  if torch.cuda.is_available():
104
  pipe = pipe.to("cuda")
105
 
 
106
  def get_model_list():
107
  return models
108
 
 
135
  g_lora = lora_cache[name]
136
  g_unet.set_attn_processor(CrossAttnProcessor())
137
  g_lora.reset()
138
+ return g_unet, g_lora, models[keys.index(name)][2]
139
 
140
  # precache on huggingface
141
+ for model in models:
142
  get_model(model[0])
143
 
144
  def error_str(error, title="Error"):
 
221
  restore_all()
222
  generator = torch.Generator("cuda").manual_seed(int(seed))
223
 
224
+ local_unet, local_lora, clip_skip = get_model(model)
225
+ pipe.set_clip_skip(clip_skip)
226
  if lora_state is not None and lora_state != "":
227
  local_lora.load(lora_state, lora_scale)
228
  local_lora.to(local_unet.device, dtype=local_unet.dtype)
 
287
  if embs is not None and len(embs) > 0:
288
  restore_all()
289
 
290
+ if model in [key[0] for key in alt_models]:
291
+ local_unet.to("cpu")
292
+
293
  end_time = time.time()
294
  vram_free, vram_total = torch.cuda.mem_get_info()
295
  print(f"done: res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")