Upload folder using huggingface_hub
modeling_MMRet_CLIP.py
CHANGED (+14 -11)
@@ -38,10 +38,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
-
+from transformers import CLIPProcessor

 if is_flash_attn_2_available():
-    from
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward


 logger = logging.get_logger(__name__)
@@ -50,7 +50,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "MMRet_CLIP"

 # Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "
+_IMAGE_CLASS_CHECKPOINT = "JUNJIE99/MMRet-base"
 _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"


@@ -1160,6 +1160,9 @@ class CLIPModel(CLIPPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    def set_processor(self, model_name):
+        self.processor = CLIPProcessor.from_pretrained(model_name)
+
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
         self,
@@ -1258,18 +1261,18 @@ class CLIPModel(CLIPPreTrainedModel):


     def encode_image(self, images):
-        embeddings = self.
+        embeddings = self.get_image_features(images)
         embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
         return embeddings

     def encode_text(self, text):
-        embeddings = self.
+        embeddings = self.get_text_features(**text)
         embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
         return embeddings

     def encode_multimodal(self, images, text):
-        text_embeddings = self.
-        image_embeddings = self.
+        text_embeddings = self.get_text_features(**text)
+        image_embeddings = self.get_image_features(images)

         embeddings = text_embeddings + image_embeddings
         embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
@@ -1278,7 +1281,7 @@ class CLIPModel(CLIPPreTrainedModel):

     def data_process(self, images=None, text=None):
         if images is None and text is not None:
-            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.
+            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)

             return images, text, "text"
         elif images is not None and text is None:
@@ -1286,7 +1289,7 @@ class CLIPModel(CLIPPreTrainedModel):
                 images = Image.open(images).convert("RGB")
             elif isinstance(images, list):
                 images = [Image.open(image).convert("RGB") for image in images]
-            images = self.processor(images=images, return_tensors="pt").to(self.
+            images = self.processor(images=images, return_tensors="pt").to(self.device)
             images = images["pixel_values"]
             return images, text, "images"
         elif images is not None and text is not None:
@@ -1296,9 +1299,9 @@ class CLIPModel(CLIPPreTrainedModel):
             elif isinstance(images, list):
                 assert len(images) == len(text), "images and text must be lists of the same length when use list"
                 images = [Image.open(image).convert("RGB") for image in images]
-            images = self.processor(images=images, return_tensors="pt").to(self.
+            images = self.processor(images=images, return_tensors="pt").to(self.device)
             images = images["pixel_values"]
-            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.
+            text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
             return images, text, "multimodal"
         else:
            raise ValueError("images and text cannot both be None")
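The commit attaches a CLIPProcessor to the model (set_processor) and fills in the previously truncated bodies of encode_image, encode_text, encode_multimodal, and data_process. A minimal usage sketch follows, assuming the repo id JUNJIE99/MMRet-base named in the diff is loaded with trust_remote_code=True so this custom CLIPModel is picked up; the image file name is a placeholder.

import torch
from transformers import AutoModel

# Assumed loading path: the repo id comes from the diff above; trust_remote_code=True
# is assumed so that the custom modeling_MMRet_CLIP.py CLIPModel class is instantiated.
model = AutoModel.from_pretrained("JUNJIE99/MMRet-base", trust_remote_code=True)
model.set_processor("JUNJIE99/MMRet-base")   # attach the CLIPProcessor added in this commit
model.eval()

with torch.no_grad():
    # Text-only branch: data_process tokenizes and moves the batch to model.device.
    _, text_inputs, _ = model.data_process(text=["a photo of a cat"])
    text_emb = model.encode_text(text_inputs)        # L2-normalized text embedding

    # Image-only branch: files are opened, preprocessed, and returned as pixel values.
    pixel_values, _, _ = model.data_process(images=["cat.jpg"])   # placeholder path
    image_emb = model.encode_image(pixel_values)     # L2-normalized image embedding

    # Both outputs are normalized, so a dot product is cosine similarity.
    score = (text_emb @ image_emb.T).item()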
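encode_multimodal simply sums the text and image features and re-normalizes the result, which is the composed image-plus-text query path. A sketch continuing from the snippet above (same model, processor already attached); the file names and query text are illustrative only.

with torch.no_grad():
    # Both inputs given: data_process returns pixel values, tokenized text,
    # and the mode string "multimodal".
    pixel_values, text_inputs, mode = model.data_process(
        images=["sketch.jpg"],                                   # placeholder path
        text=["a realistic photo of the object in this sketch"],
    )
    query_emb = model.encode_multimodal(pixel_values, text_inputs)

    # Candidate gallery images are embedded with encode_image; since every
    # encode_* output is normalized, a dot product gives cosine scores.
    gallery, _, _ = model.data_process(images=["cand_1.jpg", "cand_2.jpg"])
    scores = query_emb @ model.encode_image(gallery).T           # shape (1, 2)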