yamildiego committed

Commit 43e69fc
1 Parent(s): 850b601

change float 16 to 32

Files changed (4)
  1. .gitignore +4 -1
  2. handler.py +20 -11
  3. ip_adapter/ip_adapter.py +11 -11
  4. test.py +12 -0
.gitignore CHANGED
@@ -1 +1,4 @@
-/sdxl_models/*
+/sdxl_models/*
+**/__pycache__
+
+**/.DS_Store
handler.py CHANGED
@@ -23,8 +23,11 @@ from diffusers import (
 
 # global variable
 MAX_SEED = np.iinfo(np.int32).max
-device = "cuda" if torch.cuda.is_available() else "cpu"
-dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+# dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
+
+device = torch.device("cpu")
+dtype = torch.float32
 
 # initialization
 base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -39,7 +42,6 @@ class EndpointHandler():
 
         repo_id = "h94/IP-Adapter"
 
-        # Download the entire contents of the image_encoder directory
         local_repo_path = snapshot_download(repo_id=repo_id)
         # image_encoder_local_path = os.path.join(local_repo_path, "image_encoder")
         self.image_encoder_local_path = os.path.join(local_repo_path, "sdxl_models", "image_encoder")
@@ -47,7 +49,7 @@ class EndpointHandler():
 
 
         self.controlnet = ControlNetModel.from_pretrained(
-            controlnet_path, use_safetensors=False, torch_dtype=torch.float16
+            controlnet_path, use_safetensors=False, torch_dtype=torch.float32
         ).to(device)
 
         # load SDXL lightnining
@@ -55,7 +57,7 @@ class EndpointHandler():
         self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
             base_model_path,
             controlnet=self.controlnet,
-            torch_dtype=torch.float16,
+            torch_dtype=torch.float32,
             variant="fp16",
             add_watermarker=False,
         ).to(device)
@@ -63,14 +65,21 @@ class EndpointHandler():
         self.pipe.scheduler = EulerDiscreteScheduler.from_config(
             self.pipe.scheduler.config, timestep_spacing="trailing", prediction_type="epsilon"
         )
-        self.pipe.unet.load_state_dict(
-            load_file(
-                hf_hub_download(
-                    "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
-                ),
-                device="cuda",
+        # self.pipe.unet.load_state_dict(
+        #     load_file(
+        #         hf_hub_download(
+        #             "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
+        #         ),
+        #         device="cuda",
+        #     )
+        # )
+        state_dict = load_file(
+            hf_hub_download(
+                "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
            )
        )
+        self.pipe.unet.load_state_dict(state_dict)
+        self.pipe.unet.to(device)
 
         self.ip_model = IPAdapterXL(
             self.pipe,
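For context: float16 inference in this stack relies on GPU kernels, so once the handler pins device = torch.device("cpu") every module has to be loaded in float32, and the Lightning UNet weights can no longer be loaded with device="cuda". Below is a minimal sanity-check sketch, not part of the commit, assuming the attribute names used in handler.py above (handler.controlnet, handler.pipe.unet) and the constructor call from test.py:

import torch

from handler import EndpointHandler

# Hypothetical check (not in the commit): verify that every parameter of the
# ControlNet and the Lightning UNet ended up as float32 on the CPU.
handler = EndpointHandler(model_dir="./")

for name, module in [("controlnet", handler.controlnet), ("unet", handler.pipe.unet)]:
    dtypes = {p.dtype for p in module.parameters()}
    devices = {p.device.type for p in module.parameters()}
    assert dtypes == {torch.float32}, f"{name} still holds {dtypes}"
    assert devices == {"cpu"}, f"{name} still sits on {devices}"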
ip_adapter/ip_adapter.py CHANGED
@@ -102,7 +102,7 @@ class IPAdapter:
 
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float16
+            self.device, dtype=torch.float32
         )
 
         self.clip_image_processor = CLIPImageProcessor()
@@ -117,7 +117,7 @@ class IPAdapter:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     def set_ip_adapter(self):
@@ -147,7 +147,7 @@ class IPAdapter:
                         cross_attention_dim=cross_attention_dim,
                         scale=1.0,
                         num_tokens=self.num_tokens,
-                    ).to(self.device, dtype=torch.float16)
+                    ).to(self.device, dtype=torch.float32)
                 else:
                     attn_procs[name] = IPAttnProcessor(
                         hidden_size=hidden_size,
@@ -155,7 +155,7 @@ class IPAdapter:
                         scale=1.0,
                         num_tokens=self.num_tokens,
                         skip=True
-                    ).to(self.device, dtype=torch.float16)
+                    ).to(self.device, dtype=torch.float32)
         unet.set_attn_processor(attn_procs)
         if hasattr(self.pipe, "controlnet"):
             if isinstance(self.pipe.controlnet, MultiControlNetModel):
@@ -185,9 +185,9 @@ class IPAdapter:
             if isinstance(pil_image, Image.Image):
                 pil_image = [pil_image]
             clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
+            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
         else:
-            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
+            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32)
 
         if content_prompt_embeds is not None:
             clip_image_embeds = clip_image_embeds - content_prompt_embeds
@@ -367,7 +367,7 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     @torch.inference_mode()
@@ -375,7 +375,7 @@ class IPAdapterPlus(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.float32)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
@@ -392,7 +392,7 @@ class IPAdapterFull(IPAdapterPlus):
         image_proj_model = MLPProjModel(
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.hidden_size,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
 
@@ -409,7 +409,7 @@ class IPAdapterPlusXL(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     @torch.inference_mode()
@@ -417,7 +417,7 @@ class IPAdapterPlusXL(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.float32)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
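Note: the hard-coded torch.float16 casts in ip_adapter.py are why a dtype switch touches every call site in the file. A possible alternative, sketched below and not part of this commit, is to derive device and dtype from the already-loaded image encoder (the .device/.dtype properties that transformers models expose); the helper name encode_image is hypothetical:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

def encode_image(image_encoder: CLIPVisionModelWithProjection,
                 processor: CLIPImageProcessor,
                 pil_image: Image.Image) -> torch.Tensor:
    # Hypothetical refactor sketch: follow the encoder's own device/dtype instead of
    # hard-coding torch.float16 or torch.float32 at every call site.
    clip_image = processor(images=[pil_image], return_tensors="pt").pixel_values
    clip_image = clip_image.to(image_encoder.device, dtype=image_encoder.dtype)
    return image_encoder(clip_image).image_embeds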
test.py ADDED
@@ -0,0 +1,12 @@
+from handler import EndpointHandler
+
+# Create an instance of the handler
+handler = EndpointHandler(model_dir="./")
+
+# Call the handler with test data
+data = {
+    "inputs": "A photo of a cat"
+}
+resultado = handler(data=data)
+
+print(resultado)