JUNJIE99 committed
Commit c25733e · verified · 1 Parent(s): 2eb6ce7

Upload folder using huggingface_hub

Files changed (1):
  1. modeling_MMRet_CLIP.py  +14 -11
modeling_MMRet_CLIP.py CHANGED

@@ -38,10 +38,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
-
+from transformers import CLIPProcessor
 
 if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
 logger = logging.get_logger(__name__)
@@ -50,7 +50,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "MMRet_CLIP"
 
 # Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "JUNJIE/MMRet-large"
+_IMAGE_CLASS_CHECKPOINT = "JUNJIE99/MMRet-base"
 _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
 
 
@@ -1160,6 +1160,9 @@ class CLIPModel(CLIPPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def set_processor(self, model_name):
+        self.processor = CLIPProcessor.from_pretrained(model_name)
+
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
         self,
@@ -1258,18 +1261,18 @@ class CLIPModel(CLIPPreTrainedModel):
 
 
     def encode_image(self, images):
-        embeddings = self.model.get_image_features(images)
+        embeddings = self.get_image_features(images)
         embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
         return embeddings
 
     def encode_text(self, text):
-        embeddings = self.model.get_text_features(**text)
+        embeddings = self.get_text_features(**text)
         embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
         return embeddings
 
     def encode_multimodal(self, images, text):
-        text_embeddings = self.model.get_text_features(**text)
-        image_embeddings = self.model.get_image_features(images)
+        text_embeddings = self.get_text_features(**text)
+        image_embeddings = self.get_image_features(images)
 
         embeddings = text_embeddings + image_embeddings
         embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
@@ -1278,7 +1281,7 @@ class CLIPModel(CLIPPreTrainedModel):
 
     def data_process(self, images=None, text=None):
         if images is None and text is not None:
-            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.model.device)
+            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
 
             return images, text, "text"
         elif images is not None and text is None:
@@ -1286,7 +1289,7 @@ class CLIPModel(CLIPPreTrainedModel):
             images = Image.open(images).convert("RGB")
         elif isinstance(images, list):
             images = [Image.open(image).convert("RGB") for image in images]
-            images = self.processor(images=images, return_tensors="pt").to(self.model.device)
+            images = self.processor(images=images, return_tensors="pt").to(self.device)
             images = images["pixel_values"]
             return images, text, "images"
         elif images is not None and text is not None:
@@ -1296,9 +1299,9 @@ class CLIPModel(CLIPPreTrainedModel):
         elif isinstance(images, list):
             assert len(images) == len(text), "images and text must be lists of the same length when use list"
             images = [Image.open(image).convert("RGB") for image in images]
-            images = self.processor(images=images, return_tensors="pt").to(self.model.device)
+            images = self.processor(images=images, return_tensors="pt").to(self.device)
             images = images["pixel_values"]
-            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.model.device)
+            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
             return images, text, "multimodal"
         else:
             raise ValueError("images and text cannot both be None")
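
The net effect of the commit: CLIPModel now owns its processor (set_processor), calls its own feature extractors (self.get_image_features / self.get_text_features rather than going through a self.model indirection), moves inputs to self.device (a property every PreTrainedModel exposes), and uses the absolute transformers.modeling_flash_attention_utils import, since the relative "from ..." form only resolves when the file lives inside the transformers package itself. A minimal usage sketch of the updated interface, assuming the repository loads through AutoModel with trust_remote_code=True; the checkpoint id and image path below are placeholders:

import torch
from transformers import AutoModel

# Placeholder checkpoint id and image path; trust_remote_code=True is assumed
# because modeling_MMRet_CLIP.py ships with the model repo, not with transformers.
model = AutoModel.from_pretrained("JUNJIE99/MMRet-base", trust_remote_code=True)
model.set_processor("JUNJIE99/MMRet-base")  # loads the CLIPProcessor added in this commit
model.eval()

with torch.no_grad():
    # data_process routes inputs and returns (images, text, mode)
    _, text_inputs, mode = model.data_process(text=["a photo of a cat"])
    text_emb = model.encode_text(text_inputs)        # unit-norm text embedding

    pixel_values, _, mode = model.data_process(images="cat.jpg")
    image_emb = model.encode_image(pixel_values)     # unit-norm image embedding

# encode_* L2-normalizes its output, so the dot product is cosine similarity
print((text_emb @ image_emb.T).item())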