aayushmnit committed
Commit 1173e62
Parent: 07e4d15

Update app.py

Files changed (1): app.py (+8 -7)
app.py CHANGED
@@ -16,10 +16,10 @@ def load_artifacts():
     '''
     A function to load all diffusion artifacts
     '''
-    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float32, use_auth_token=auth_token).to(device)
-    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float32, use_auth_token=auth_token).to(device)
-    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float32, use_auth_token=auth_token)
-    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float32, use_auth_token=auth_token).to(device)
+    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
+    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16, use_auth_token=auth_token)
+    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
     scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
     return vae, unet, tokenizer, text_encoder, scheduler
 
@@ -34,7 +34,7 @@ def pil_to_latents(image):
     Function to convert image to latents
     '''
     init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0
-    init_image = init_image.to(device=device, dtype=torch.float32)
+    init_image = init_image.to(device=device, dtype=torch.float16)
     init_latent_dist = vae.encode(init_image).latent_dist.sample() * 0.18215
     return init_latent_dist
 
@@ -57,7 +57,7 @@ def text_enc(prompts, maxlen=None):
     '''
     if maxlen is None: maxlen = tokenizer.model_max_length
     inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
-    return text_encoder(inp.input_ids.to(device))[0]
+    return text_encoder(inp.input_ids.to(device))[0].half()
 
 def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength=0.5, steps=50, dim=512):
     """
@@ -140,7 +140,8 @@ def improve_mask(mask):
 vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()
 pipe = StableDiffusionInpaintPipeline.from_pretrained(
     "runwayml/stable-diffusion-inpainting",
-    torch_dtype=torch.float32,
+    revision="fp16",
+    torch_dtype=torch.float16,
     use_auth_token=auth_token
 ).to(device)
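
The commit switches the hand-rolled diffusion artifacts from full (float32) to half (float16) precision, which roughly halves the GPU memory taken by weights and activations. The catch is dtype discipline: every tensor fed to the fp16-loaded VAE, UNet, or text encoder must itself be float16, which is why pil_to_latents now casts init_image and text_enc gains a trailing .half(). A minimal sketch of the pattern, not part of the commit; the tensors below are illustrative stand-ins for the app's data:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# float16 stores 2 bytes per element vs. 4 for float32, hence the
# roughly 2x memory saving for model weights and activations.
x = torch.randn(1, 4, 64, 64)             # torch.float32 by default
print(x.element_size())                   # 4
print(x.half().element_size())            # 2

# pil_to_latents' pattern: move and cast in one call so the image
# tensor matches the fp16 VAE; a float32 input against fp16 weights
# fails with a dtype-mismatch RuntimeError.
x = x.to(device=device, dtype=torch.float16)
print(x.dtype)                            # torch.float16

# text_enc's trailing .half() does the same for the embeddings the
# fp16 UNet consumes; .half() is shorthand for .to(torch.float16)
# and is a no-op when the tensor is already half precision.
emb = torch.randn(1, 77, 768, device=device)  # stand-in for CLIP output
print(emb.half().dtype)                   # torch.float16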
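For the inpainting pipeline, revision="fp16" fetches the half-precision weight branch that runwayml/stable-diffusion-inpainting hosted on the Hub at the time of this commit, so the checkpoint download is itself about half the size; torch_dtype=torch.float16 then keeps the loaded modules in half precision. A hedged sketch of how such a pipeline is typically invoked; the prompt and placeholder images are illustrative, and only the loading arguments come from the commit:

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Same loading pattern as the commit; auth token handling omitted here.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",              # half-precision weight branch on the Hub
    torch_dtype=torch.float16,    # keep modules in fp16 after loading
).to("cuda")

# Placeholder 512x512 inputs; in app.py these come from the user.
init_image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512), 255)   # white marks the region to repaint

result = pipe(prompt="a red sofa", image=init_image, mask_image=mask).images[0]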