yamildiego committed on
Commit: 1ad6f34
1 Parent(s): 317bb70

test float 16

Files changed (3):
  1. handler.py +4 -4
  2. ip_adapter/ip_adapter.py +11 -11
  3. ip_adapter/utils.py +1 -1
handler.py CHANGED
@@ -22,10 +22,10 @@ from diffusers import (
 # global variable
 MAX_SEED = np.iinfo(np.int32).max
 device = "cuda" if torch.cuda.is_available() else "cpu"
-# dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
+# dtype = torch.float16 if str(device).__contains__("cuda") else torch.float16

 # device = torch.device("cpu")
-dtype = torch.float32
+dtype = torch.float16

 # initialization
 base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -48,13 +48,13 @@ class EndpointHandler():
         self.ip_ckpt = os.path.join("sdxl_models", "ip-adapter_sdxl.safetensors")

         self.controlnet = ControlNetModel.from_pretrained(
-            controlnet_path, use_safetensors=False, torch_dtype=torch.float32
+            controlnet_path, use_safetensors=False, torch_dtype=torch.float16
         ).to(device)

         self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
             base_model_path,
             controlnet=self.controlnet,
-            torch_dtype=torch.float32,
+            torch_dtype=torch.float16,
             variant="fp16",
             add_watermarker=False,
         ).to(device)
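
For reference, the commented-out conditional above hints at deriving the dtype from the device rather than hard-coding it. A minimal sketch of that pattern (the float32 fallback on CPU is an assumption, not part of this commit):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # assumption: half precision only on CUDA; CPU inference stays in float32
    dtype = torch.float16 if device == "cuda" else torch.float32
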
ip_adapter/ip_adapter.py CHANGED
@@ -102,7 +102,7 @@ class IPAdapter:

         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float32
+            self.device, dtype=torch.float16
         )

         self.clip_image_processor = CLIPImageProcessor()
@@ -117,7 +117,7 @@ class IPAdapter:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float32)
+        ).to(self.device, dtype=torch.float16)
         return image_proj_model

     def set_ip_adapter(self):
@@ -147,7 +147,7 @@ class IPAdapter:
                         cross_attention_dim=cross_attention_dim,
                         scale=1.0,
                         num_tokens=self.num_tokens,
-                    ).to(self.device, dtype=torch.float32)
+                    ).to(self.device, dtype=torch.float16)
                 else:
                     attn_procs[name] = IPAttnProcessor(
                         hidden_size=hidden_size,
@@ -155,7 +155,7 @@ class IPAdapter:
                         scale=1.0,
                         num_tokens=self.num_tokens,
                         skip=True
-                    ).to(self.device, dtype=torch.float32)
+                    ).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
         if hasattr(self.pipe, "controlnet"):
             if isinstance(self.pipe.controlnet, MultiControlNetModel):
@@ -185,9 +185,9 @@ class IPAdapter:
             if isinstance(pil_image, Image.Image):
                 pil_image = [pil_image]
             clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
+            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
         else:
-            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32)
+            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

         if content_prompt_embeds is not None:
             clip_image_embeds = clip_image_embeds - content_prompt_embeds
@@ -367,7 +367,7 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float32)
+        ).to(self.device, dtype=torch.float16)
         return image_proj_model

     @torch.inference_mode()
@@ -375,7 +375,7 @@ class IPAdapterPlus(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float32)
+        clip_image = clip_image.to(self.device, dtype=torch.float16)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
@@ -392,7 +392,7 @@ class IPAdapterFull(IPAdapterPlus):
         image_proj_model = MLPProjModel(
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.hidden_size,
-        ).to(self.device, dtype=torch.float32)
+        ).to(self.device, dtype=torch.float16)
         return image_proj_model


@@ -409,7 +409,7 @@ class IPAdapterPlusXL(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float32)
+        ).to(self.device, dtype=torch.float16)
         return image_proj_model

     @torch.inference_mode()
@@ -417,7 +417,7 @@ class IPAdapterPlusXL(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float32)
+        clip_image = clip_image.to(self.device, dtype=torch.float16)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
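
These changes keep the image encoder, the projection model, and the attention processors in the same precision as the SDXL pipeline. A quick way to confirm that every learned module ended up in the same dtype after a switch like this (a hedged sketch; ip_model stands for any already-constructed IPAdapter instance):

    import torch

    def check_dtypes(ip_model):
        # print the parameter dtype of each learned module; all should read torch.float16
        modules = {
            "image_encoder": ip_model.image_encoder,
            "image_proj_model": ip_model.image_proj_model,
            "unet": ip_model.pipe.unet,
        }
        for name, module in modules.items():
            print(name, next(module.parameters()).dtype)
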
ip_adapter/utils.py CHANGED
@@ -35,7 +35,7 @@ def upscale(attn_map, target_size):
     attn_map = attn_map.view(attn_map.shape[0], *temp_size)

     attn_map = F.interpolate(
-        attn_map.unsqueeze(0).to(dtype=torch.float32),
+        attn_map.unsqueeze(0).to(dtype=torch.float16),
         size=target_size,
         mode='bilinear',
         align_corners=False
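
The cast in upscale() could also follow the pipeline's dtype instead of hard-coding it; a small sketch of that variant (the name upscale_to and the dtype parameter are hypothetical, not part of this commit):

    import torch
    import torch.nn.functional as F

    def upscale_to(attn_map, target_size, dtype=torch.float16):
        # interpolate in the requested dtype; attn_map is expected as (N, H, W)
        out = F.interpolate(
            attn_map.unsqueeze(0).to(dtype=dtype),
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
        return out.squeeze(0)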