zwl commited on
Commit
89a405e
1 Parent(s): 44fe15d

sdxl pipeline

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -1,4 +1,6 @@
1
  from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler
 
 
2
  import gradio as gr
3
  import torch
4
  from PIL import Image
@@ -43,9 +45,16 @@ if torch.cuda.is_available():
43
  vae = AutoencoderKL.from_pretrained(current_model.path, subfolder="vae", torch_dtype=torch.float16, use_auth_token=auth_token)
44
  for model in models:
45
  try:
 
 
 
 
 
 
 
46
  unet = UNet2DConditionModel.from_pretrained(model.path, subfolder="unet", torch_dtype=torch.float16, use_auth_token=auth_token)
47
- model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler, use_auth_token=auth_token)
48
- model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler, use_auth_token=auth_token)
49
  except:
50
  models.remove(model)
51
  pipe = models[0].pipe_t2i
 
1
  from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler
2
+ from diffusers import StableDiffusionXLPipeline
3
+ from diffusers import StableDiffusionXLImg2ImgPipeline
4
  import gradio as gr
5
  import torch
6
  from PIL import Image
 
45
  vae = AutoencoderKL.from_pretrained(current_model.path, subfolder="vae", torch_dtype=torch.float16, use_auth_token=auth_token)
46
  for model in models:
47
  try:
48
+ if 'XL'in model.name:
49
+ PipeClass = StableDiffusionXLPipeline
50
+ PipeI2IClass = StableDiffusionXLImg2ImgPipeline
51
+ else:
52
+ PipeClass = StableDIffusionXPipeline
53
+ PipeI2IClass = StableDiffusionImg2ImgPipeline
54
+
55
  unet = UNet2DConditionModel.from_pretrained(model.path, subfolder="unet", torch_dtype=torch.float16, use_auth_token=auth_token)
56
+ model.pipe_t2i = PipeClass.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler, use_auth_token=auth_token)
57
+ model.pipe_i2i = PipeI2IClass.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler, use_auth_token=auth_token)
58
  except:
59
  models.remove(model)
60
  pipe = models[0].pipe_t2i