yyk19 committed
Commit 5b7a49a
Parent: 4c34c9a

support xformers. modify configs.

config.yaml CHANGED
@@ -85,4 +85,4 @@ model:
   params:
     freeze: True
     layer: "penultimate"
-    device: "cpu" #TODO: specify
+    # device: "cpu" #TODO: specify
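
Note: with the hard-coded device: "cpu" commented out, the device is now resolved at runtime (see the modules.py hunk below). A minimal sketch of reading this setting with a CUDA fallback, assuming OmegaConf as used in stock stable-diffusion configs; the dict fragment and fallback expression are illustrative, not part of this commit:

    from omegaconf import OmegaConf
    import torch

    # Illustrative fragment mirroring the hunk above (not the full config.yaml).
    cfg = OmegaConf.create({
        "params": {
            "freeze": True,
            "layer": "penultimate",
            # "device": "cpu",  # commented out, as in the hunk above
        }
    })

    # With `device` absent, fall back to a runtime check, matching the change
    # made in ldm/modules/encoders/modules.py below.
    device = cfg.params.get("device", "cuda" if torch.cuda.is_available() else "cpu")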
ldm/modules/encoders/modules.py CHANGED
@@ -118,34 +118,6 @@ class FrozenT5Embedder(AbstractEncoder):
     def encode(self, text):
         return self(text)

-# class FrozenByT5Embedder(FrozenT5Embedder):
-#     """Uses the ByT5 transformer encoder for text"""
-#     def __init__(self, version="google/byt5-small", device="cuda", max_length=77, freeze=True):  # others are google/byt5-v1_1-xl and google/t5-v1_1-xxl
-#         # super(super()).__init__()
-#         nn.Module.__init__(self)
-#         print(dir(super()))
-#         # self.tokenizer = T5Tokenizer.from_pretrained(version)
-#         # self.transformer = T5EncoderModel.from_pretrained(version)
-#         self.tokenizer = AutoTokenizer.from_pretrained(version)
-#         self.model = T5ForConditionalGeneration.from_pretrained(version)
-#         self.tokenizer_new = ByT5Tokenizer.from_pretrained(version)
-#         self.transformer = T5EncoderModel.from_pretrained(version)
-#         self.device = device
-#         self.max_length = max_length  # TODO: typical value?
-#         if freeze:
-#             self.freeze()
-
-#     def forward(self, text):
-#         # code base: https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/tokenization_utils_base.py#L2414
-#         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
-#                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
-#         tokens = batch_encoding["input_ids"].to(self.device)
-#         outputs = self.transformer(input_ids=tokens)
-
-#         z = outputs.last_hidden_state
-#         return z
-
-
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
@@ -211,8 +183,12 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         # aa = model.encode_image(torch.zeros((1, 3,224,224)))
         del model.visual
         self.model = model
-
-        self.device = device
+
+        if not torch.cuda.is_available():
+            self.device = "cpu"
+        else:
+            self.device = device
+
         self.max_length = max_length
         if freeze:
             self.freeze()
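
Note: the added block is a CUDA fallback for FrozenOpenCLIPEmbedder, so the requested device is only honored when a GPU is actually visible and the encoder no longer assumes "cuda" on CPU-only machines. An equivalent one-liner, shown purely as an illustration (the helper name is not from this commit):

    import torch

    # Hypothetical helper expressing the same fallback as the added lines above.
    def resolve_device(requested: str = "cuda") -> str:
        return requested if torch.cuda.is_available() else "cpu"

    # e.g. FrozenOpenCLIPEmbedder(device=resolve_device())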
@@ -439,12 +415,6 @@ class FrozenOpenCLIPT5ByT5SepEncoder(FrozenOpenCLIPT5ByT5Encoder):
             clip_text = text[0]
             t5_text = text[1] if len(text) > 1 else text[0]
             byt5_text = text[-1]
-            # clip_z = self.clip_encoder.encode(text[0]) #B*77*1024
-            # t5_z = self.t5_encoder.encode(text[1]) #B*77*Z_1
-            # if len(text) == 2:
-            #     byt5_z = self.byt5_encoder.encode(text[1]) #B*77*Z_2
-            # else:
-            #     byt5_z = self.byt5_encoder.encode(text[2]) #B*77*Z_2
         else:
             clip_text = text
             t5_text = text
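
Note: the retained lines route a (possibly multi-part) prompt to the three text encoders; only the dead commented-out encode() calls are removed. A small illustration of that routing, with example prompts that are not from this commit:

    # `text` may be a single string or a list of one to three prompts.
    text = ["a neon shop sign at night", 'the word "OPEN" in block letters']

    clip_text = text[0]                               # CLIP gets the first prompt
    t5_text = text[1] if len(text) > 1 else text[0]   # T5 gets the second if present
    byt5_text = text[-1]                              # ByT5 always gets the last one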
@@ -460,17 +430,10 @@ class OpenCLIPImageEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for image
     """
-    # LAYERS = [
-    #     #"pooled",
-    #     "last",
-    #     "penultimate"
-    # ]
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", #stage = "train", # max_length=77, layer = "last",
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
                  freeze=True, set_grad_checkpointing = True):
         super().__init__()
-        # assert layer in self.LAYERS
         model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
-        # del model.visual
         self.image_mean = model.visual.image_mean
         self.image_std = model.visual.image_std
         del model.transformer
@@ -480,29 +443,16 @@ class OpenCLIPImageEmbedder(AbstractEncoder):
         del model.text_projection
         del model.logit_scale
         # only model.visual is left
-        # open_clip.model._build_vision_tower()

         self.model = model
         self.device = device

         if not freeze and set_grad_checkpointing:
             self.model.visual.set_grad_checkpointing(True)
-        # if freeze:
-        #     self.freeze()
-        # else:
-        #     if set_grad_checkpointing:
-        #         self.model.visual.set_grad_checkpointing(True)
         self.freeze_model = freeze

-    # def freeze(self):
-    #     self.model = self.model.eval()
-    #     for param in self.parameters(): #392
-    #         param.requires_grad = False
-
     def forward(self, img):
         z = self.model.encode_image(img) # 2.0.2 , normalize=False) 2.7.0
-        # tokens = open_clip.tokenize(text)
-        # z = self.encode_with_transformer(tokens.to(self.device))
         return z

     def encode(self, img):
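
Note: after this cleanup, OpenCLIPImageEmbedder keeps only the vision tower (all text components are deleted), exposes the preprocessing statistics via image_mean/image_std, and enables gradient checkpointing only when training unfrozen. A minimal usage sketch under assumed input conventions (224x224, already resized; the explicit normalization step is an assumption based on the stored statistics):

    import torch

    embedder = OpenCLIPImageEmbedder(arch="ViT-H-14", version="laion2b_s32b_b79k", freeze=True)

    imgs = torch.rand(2, 3, 224, 224)  # batch of RGB images in [0, 1], already resized
    mean = torch.tensor(embedder.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(embedder.image_std).view(1, 3, 1, 1)

    with torch.no_grad():
        z = embedder((imgs - mean) / std)  # CLIP image embeddings, one row per image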
 
requirements.txt CHANGED
@@ -6,4 +6,6 @@ gradio
 einops
 pytorch-lightning==1.6.5
 transformers
-open_clip_torch
+open_clip_torch
+ninja
+git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
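
Note: these pins back the commit message: ninja is required so the xformers CUDA extensions can be compiled when installing from the pinned Git source. Stable-diffusion-style codebases then typically gate the memory-efficient attention path behind an import check like the sketch below; the flag name is illustrative and not taken from this commit:

    try:
        import xformers
        import xformers.ops
        XFORMERS_IS_AVAILABLE = True
    except ImportError:
        XFORMERS_IS_AVAILABLE = False

    # When available, attention can be routed through the fused kernel, e.g.
    #   out = xformers.ops.memory_efficient_attention(q, k, v)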