anzorq commited on
Commit
b8843c9
·
1 Parent(s): 5983023

Pre-load all models in RAM

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -13,6 +13,8 @@ class Model:
13
  self.name = name
14
  self.path = path
15
  self.prefix = prefix
 
 
16
 
17
  models = [
18
  Model("Custom model", "", ""),
@@ -27,17 +29,24 @@ models = [
27
  Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""),
28
  Model("Robo Diffusion", "nousr/robo-diffusion", ""),
29
  Model("Cyberpunk Anime", "DGSpitzer/Cyberpunk-Anime-Diffusion", "dgs illustration style "),
30
- Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy ")
31
  ]
32
 
33
  last_mode = "txt2img"
34
  current_model = models[1]
35
  current_model_path = current_model.path
36
- pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16)
37
- # pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16)
38
- if torch.cuda.is_available():
39
- pipe = pipe.to("cuda")
40
- # pipe_i2i = pipe_i2i.to("cuda")
 
 
 
 
 
 
 
41
 
42
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
43
 
@@ -69,7 +78,12 @@ def txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, g
69
  if model_path != current_model_path or last_mode != "txt2img":
70
  current_model_path = model_path
71
 
72
- pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16)
 
 
 
 
 
73
  if torch.cuda.is_available():
74
  pipe = pipe.to("cuda")
75
  last_mode = "txt2img"
@@ -95,7 +109,11 @@ def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, w
95
  if model_path != current_model_path or last_mode != "img2img":
96
  current_model_path = model_path
97
 
98
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16)
 
 
 
 
99
 
100
  if torch.cuda.is_available():
101
  pipe = pipe.to("cuda")
 
13
  self.name = name
14
  self.path = path
15
  self.prefix = prefix
16
+ self.pipe_t2i = None
17
+ self.pipe_i2i = None
18
 
19
  models = [
20
  Model("Custom model", "", ""),
 
29
  Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""),
30
  Model("Robo Diffusion", "nousr/robo-diffusion", ""),
31
  Model("Cyberpunk Anime", "DGSpitzer/Cyberpunk-Anime-Diffusion", "dgs illustration style "),
32
+ Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy")
33
  ]
34
 
35
  last_mode = "txt2img"
36
  current_model = models[1]
37
  current_model_path = current_model.path
38
+
39
+ if is_colab:
40
+ pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16)
41
+ if torch.cuda.is_available():
42
+ pipe = pipe.to("cuda")
43
+
44
+ else: # download all models
45
+ vae = AutoencoderKL.from_pretrained(current_model, subfolder="vae", torch_dtype=torch.float16)
46
+ for model in models[1:]:
47
+ unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
48
+ model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model, unet=unet, vae=vae, torch_dtype=torch.float16)
49
+ model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model, unet=unet, vae=vae, torch_dtype=torch.float16)
50
 
51
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
52
 
 
78
  if model_path != current_model_path or last_mode != "txt2img":
79
  current_model_path = model_path
80
 
81
+ if is_colab:
82
+ pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16)
83
+ else:
84
+ pipe = pipe.to("cpu")
85
+ pipe = current_model.pipe_t2i
86
+
87
  if torch.cuda.is_available():
88
  pipe = pipe.to("cuda")
89
  last_mode = "txt2img"
 
109
  if model_path != current_model_path or last_mode != "img2img":
110
  current_model_path = model_path
111
 
112
+ if is_colab:
113
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16)
114
+ else:
115
+ pipe = pipe.to("cpu")
116
+ pipe = current_model.pipe_t2i
117
 
118
  if torch.cuda.is_available():
119
  pipe = pipe.to("cuda")