hpoghos committed
Commit c048c54
1 Parent(s): 3301107

add open_clip

t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py CHANGED
@@ -4,7 +4,7 @@ import torch
 from torchvision.transforms.functional import to_pil_image
 import torch.nn as nn
 import kornia
-# import open_clip
+import open_clip
 from transformers import CLIPVisionModelWithProjection, AutoProcessor
 from transformers.models.bit.image_processing_bit import BitImageProcessor
 from einops import rearrange, repeat
@@ -73,15 +73,15 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
         output_tokens=False,
     ):
         super().__init__()
-        # model, _, _ = open_clip.create_model_and_transforms(
-        #     arch,
-        #     device=torch.device("cpu"),
-        #     pretrained=version,
-        # )
-        # del model.transformer
-        # self.model = model
-        self.model_t = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
-        self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+        model, _, _ = open_clip.create_model_and_transforms(
+            arch,
+            device=torch.device("cpu"),
+            pretrained=version,
+        )
+        del model.transformer
+        self.model = model
+        # self.model_t = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+        # self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
 
         self.max_crops = num_image_crops
         self.pad_to_max_len = self.max_crops > 0
@@ -120,10 +120,10 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
         return x
 
     def freeze(self):
-        # self.model = self.model.eval()
+        self.model = self.model.eval()
         for param in self.parameters():
             param.requires_grad = False
-        self.model_t = self.model_t.eval()
+        # self.model_t = self.model_t.eval()
 
     def forward(self, image, no_dropout=False):
         z = self.encode_with_vision_transformer(image)
@@ -181,40 +181,40 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
     def encode_with_vision_transformer(self, img):
         if self.max_crops > 0:
             img = self.preprocess_by_cropping(img)
-        pil_img = to_pil_image(img[0]*0.5 + 0.5)
-        inputs = self.processor(images=pil_img, return_tensors="pt").to("cuda")
-        outputs = self.model_t(**inputs)
-        return outputs.image_embeds
-        # if img.dim() == 5:
-        #     assert self.max_crops == img.shape[1]
-        #     img = rearrange(img, "b n c h w -> (b n) c h w")
-        # img = self.preprocess(img)
-        # if not self.output_tokens:
-        #     assert not self.model.visual.output_tokens
-        #     x = self.model.visual(img)
-        #     tokens = None
-        # else:
-        #     assert self.model.visual.output_tokens
-        #     x, tokens = self.model.visual(img)
-        # if self.max_crops > 0:
-        #     x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
-        #     # drop out between 0 and all along the sequence axis
-        #     x = (
-        #         torch.bernoulli(
-        #             (1.0 - self.ucg_rate)
-        #             * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
-        #         )
-        #         * x
-        #     )
-        #     if tokens is not None:
-        #         tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-        #         print(
-        #             f"You are running very experimental token-concat in {self.__class__.__name__}. "
-        #             f"Check what you are doing, and then remove this message."
-        #         )
-        # if self.output_tokens:
-        #     return x, tokens
-        # return x
+        # pil_img = to_pil_image(img[0]*0.5 + 0.5)
+        # inputs = self.processor(images=pil_img, return_tensors="pt").to("cuda")
+        # outputs = self.model_t(**inputs)
+        # return outputs.image_embeds
+        if img.dim() == 5:
+            assert self.max_crops == img.shape[1]
+            img = rearrange(img, "b n c h w -> (b n) c h w")
+        img = self.preprocess(img)
+        if not self.output_tokens:
+            assert not self.model.visual.output_tokens
+            x = self.model.visual(img)
+            tokens = None
+        else:
+            assert self.model.visual.output_tokens
+            x, tokens = self.model.visual(img)
+        if self.max_crops > 0:
+            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+            # drop out between 0 and all along the sequence axis
+            x = (
+                torch.bernoulli(
+                    (1.0 - self.ucg_rate)
+                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+                )
+                * x
+            )
+            if tokens is not None:
+                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+                print(
+                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
+                    f"Check what you are doing, and then remove this message."
+                )
+        if self.output_tokens:
+            return x, tokens
+        return x
 
     def encode(self, text):
         return self(text)
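
For context (not part of the diff): the commit re-enables the open_clip path in FrozenOpenCLIPImageEmbedder and comments out the transformers CLIPVisionModelWithProjection/AutoProcessor path. Below is a minimal standalone sketch of that open_clip usage. The values "ViT-H-14" and "laion2b_s32b_b79k" are assumptions for arch and version, chosen to mirror the laion/CLIP-ViT-H-14-laion2B-s32B-b79K checkpoint referenced in the removed code, and "frame.png" is a placeholder input path.

# Sketch only: approximates what the class does with open_clip; not the repository's code.
import torch
import open_clip
from PIL import Image

# Assumed model/pretrained tags mirroring laion/CLIP-ViT-H-14-laion2B-s32B-b79K.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14",
    pretrained="laion2b_s32b_b79k",
    device=torch.device("cpu"),
)
del model.transformer  # drop the text tower; only the visual encoder is needed
model.eval()

with torch.no_grad():
    image = preprocess(Image.open("frame.png")).unsqueeze(0)  # placeholder image path
    embedding = model.visual(image)  # pooled image embedding, analogous to encode_with_vision_transformer
print(embedding.shape)  # expected torch.Size([1, 1024]) for ViT-H-14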