update preprocessing

Files changed:
- demo.ipynb +3 -5
- image_processing_blip_3.py +3 -18
- vlm.py +1 -126
demo.ipynb
CHANGED
@@ -253,10 +253,10 @@
 " for fn in sample['image_path']:\n",
 " img = PIL.Image.open(fn)\n",
 " display.display(Image(filename=fn, width=300))\n",
-" image_list.append(image_processor([img], image_aspect_ratio='anyres')[\"pixel_values\"])\n",
+" image_list.append(image_processor([img], image_aspect_ratio='anyres')[\"pixel_values\"].cuda())\n",
 " image_sizes.append(img.size)\n",
 " inputs = {\n",
-" \"pixel_values\": image_list\n",
+" \"pixel_values\": [image_list]\n",
 " }\n",
 " for query in sample['question']:\n",
 " prompt = apply_prompt_template(query)\n",
@@ -266,9 +266,7 @@
 " for name, value in inputs.items():\n",
 " if isinstance(value, torch.Tensor):\n",
 " inputs[name] = value.cuda()\n",
-"
-" inputs[name] = [v.cuda() for v in value]\n",
-" generated_text = model.generate(**inputs, image_size=image_sizes,\n",
+" generated_text = model.generate(**inputs, image_size=[image_sizes],\n",
 " pad_token_id=tokenizer.pad_token_id,\n",
 " do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1,\n",
 " )\n",
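Taken together, these notebook changes move each preprocessed image tensor to the GPU right away and wrap both pixel_values and image_size in an outer list, so one (possibly multi-image) sample is passed as a batch of size one. Below is a minimal sketch of the updated calling convention, assuming the demo's image_processor, tokenizer, model, and apply_prompt_template objects are already loaded; the file paths, query text, and the tokenization step are illustrative rather than copied from the diff.

import PIL.Image

image_paths = ["photo_0.jpg", "photo_1.jpg"]   # placeholder paths

image_list, image_sizes = [], []
for fn in image_paths:
    img = PIL.Image.open(fn)
    # Per-image patch tensor from the processor, moved to the GPU immediately.
    image_list.append(
        image_processor([img], image_aspect_ratio="anyres")["pixel_values"].cuda()
    )
    image_sizes.append(img.size)

# Outer list = batch of one sample; inner list = the images belonging to that sample.
inputs = {"pixel_values": [image_list]}

prompt = apply_prompt_template("Describe the images.")
language_inputs = tokenizer([prompt], return_tensors="pt")
inputs.update({name: value.cuda() for name, value in language_inputs.items()})

generated_text = model.generate(
    **inputs,
    image_size=[image_sizes],   # nested to mirror the batched pixel_values
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_new_tokens=1024,
    top_p=None,
    num_beams=1,
)
print(tokenizer.decode(generated_text[0], skip_special_tokens=True))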
image_processing_blip_3.py
CHANGED
@@ -109,26 +109,11 @@ class Blip3ImageProcessor(BaseImageProcessor):
 
         if all(x.shape == new_images[0].shape for x in new_images):
             new_images = torch.stack(new_images, dim=0)
-        if image_aspect_ratio == '
-            new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0).unsqueeze(0)}, tensor_type=return_tensors)
-        else:
+        if image_aspect_ratio == 'anyres':
             new_images = BatchFeature(data={"pixel_values": new_images}, tensor_type=return_tensors)
+        else:
+            new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type=return_tensors)
         return new_images
-    # def preprocess(self,
-    #                images: ImageInput,
-    #                return_tensors: Optional[Union[str, TensorType]] = None,
-    #                **kwargs) -> BatchFeature:
-    #     transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
-    #     transforms.extend([
-    #         self.convert_rgb,
-    #         ToTensor(),
-    #         Normalize(mean=self.image_mean, std=self.image_std)
-    #     ])
-    #     composed_transforms = Compose(transforms)
-    #     images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
-    #     encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
-    #     return encoded_outputs
-
 
 class ResizeKeepRatio:
     """ Resize and Keep Ratio
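The reworked return branch changes the pixel_values shape only for non-anyres modes: with image_aspect_ratio='anyres' the stacked patch tensor is now returned unchanged, while other modes get unsqueeze(1).unsqueeze(0) (instead of the old unsqueeze(0).unsqueeze(0) on the anyres path), which lines up with the (B, T_img, F, C, H, W) layout described in the forward docstring removed from vlm.py below. A small shape check with a stand-in tensor; the patch count and resolution here are illustrative.

import torch

# Stand-in for the stacked, transformed images inside the processor.
new_images = torch.randn(5, 3, 384, 384)  # e.g. 5 anyres patches of one image

# 'anyres' branch after this commit: patches are returned as-is.
anyres_pixel_values = new_images
print(anyres_pixel_values.shape)  # torch.Size([5, 3, 384, 384])

# Other aspect-ratio modes: insert frame and batch axes.
other_pixel_values = new_images.unsqueeze(1).unsqueeze(0)
print(other_pixel_values.shape)   # torch.Size([1, 5, 1, 3, 384, 384]) ~ (B, T_img, F, C, H, W)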
vlm.py
CHANGED
@@ -1043,10 +1043,6 @@ class VLMWithLanguageStream(VLM):
                 multimodal_labels.append(labels[i].clone())
                 continue
 
-            # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
-            for j, img_idx in enumerate(image_token_idxs):
-                image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j  # FIXME: different offset for any resolution encoding when has multiple images.
-
             # loop through the image_token_idxs and insert the vision tokens
             new_embed = lang_embeds[i].clone()
             new_attention_mask = (
@@ -1056,9 +1052,6 @@ class VLMWithLanguageStream(VLM):
             new_label = labels[i].clone()
 
             for img_num, img_idx in enumerate(image_token_idxs):
-                if img_num > 0:
-                    # FIXME: hardcoded as such to avoid assertion error, but this only works for single image samples.
-                    break
                 # Get vision token attention mask for padded llava-style any resolution image tokens.
                 if self.image_aspect_ratio =='anyres':
                     num_vis_tokens = vision_tokens[i][img_num].shape[0]
@@ -1078,7 +1071,6 @@ class VLMWithLanguageStream(VLM):
                     vis_attention_mask = torch.ones(
                         num_vis_tokens, dtype=torch.long
                     ).to(attention_mask.device)
-
 
             new_embed = torch.cat(
                 (
@@ -1275,123 +1267,6 @@ class XGenMMPerceiver(VLMWithLanguageStream):
         """
         return True
 
-    def forward(
-        self,
-        vision_x: Optional[torch.Tensor],
-        lang_x: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        image_size: Optional[Tuple] = None,
-        past_key_values: Optional[
-            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
-        ] = None,
-        past_media_locations: Optional[torch.Tensor] = None,
-        past_vision_tokens: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = False,
-        **kwargs,
-    ):
-        """
-        Args:
-            vision_x: Vision input
-                shape (B, T_img, F, C, H, W) with F=1
-                only F = 1 is supported (single-frame videos)
-                if T_img > the number of media tokens in the corresponding input_ids (lang_x),
-                only the first number of media tokens in lang_x are used
-            lang_x: Language input ids, with media tokens denoting where
-                visual media should be inserted.
-                shape (B, T_txt)
-            attention_mask: Attention mask. Defaults to None.
-            labels: Labels. Defaults to None.
-                shape (B, T_txt)
-            past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
-                list of length = number of decoder layers in the LM
-                exact implementation depends on LM, see Hugging Face docs
-            past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
-                shape (B, T_txt)
-            past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
-            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
-                If True, includes key_values, media_locations, and vision_tokens in the output.
-        """
-        assert not (past_vision_tokens is None) ^ (
-            past_media_locations is None
-        ), "past_vision_tokens and past_media_locations must both be None or both be not None"
-
-        # convert pixels to vision tokens
-        vision_attention_mask = None
-        if vision_x is not None:
-            if self.image_aspect_ratio == 'anyres':
-                input_dict = dict(image=vision_x, image_size=image_size)
-                vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
-            else:
-                vision_features = self._encode_vision_x(vision_x=vision_x)
-                vision_attn_masks = None
-            # Same for attention masks: [b, Np, v] -> [b*Np, v]
-            if self.anyres_patch_sampling:
-                split_sizes = [feature.shape[0] for feature in vision_features]
-                # Nested splits for multi-image samples.
-                if isinstance(vision_x[0], list):
-                    nt_images = [len(images) for images in vision_x]
-                    split_split_sizes = []
-                    img_id = 0
-                    for nt in nt_images:
-                        split_split_sizes.append(split_sizes[img_id:img_id+nt])
-                        img_id += nt
-                else:
-                    nt_images = [1] * len(vision_x)
-                    split_split_sizes = split_sizes
-                vision_features = torch.cat(vision_features, dim=0)
-                vision_features = vision_features[:, None, None, :, :]  # Expand dimensions.
-                vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
-            # TODO: add an option that allows restoring the T dimension for video tokenization.
-            vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
-
-            # Post-processing: Split the batches into groups of patches and concatenate them together.
-            if self.anyres_patch_sampling:
-                # assert isinstance(vision_x, list)
-                if isinstance(vision_x[0], list):
-                    vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
-                    vision_tokens = []
-
-                    for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
-                        patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0)  # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
-                        flatten_vision_tokens = []
-                        for image_vis_token in patch_vis_token_groups:
-                            image_vis_token = image_vis_token.flatten(0, 2)  # [Np, 1, v, d] -> [Np*v, d]
-                            flatten_vision_tokens.append(image_vis_token)
-                        vision_tokens_i = flatten_vision_tokens
-                        vision_tokens.append(vision_tokens_i)
-                else:
-                    vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
-                    vision_tokens = []
-                    for patch_vis_tokens in vision_token_groups:
-                        patch_vis_tokens = patch_vis_tokens.flatten(0, 2)  # [Np, 1, v, d] -> [Np*v, d]
-                        vision_tokens.append(patch_vis_tokens.unsqueeze(0))  # Add the nt dimension.
-        else:
-            vision_tokens = None
-
-        # fuse the vision and language tokens
-        new_inputs = self._prepare_inputs_for_forward(
-            vision_tokens=vision_tokens,
-            lang_x=lang_x,
-            attention_mask=attention_mask,
-            vision_attention_mask=vision_attention_mask,
-            labels=labels,
-            past_key_values=past_key_values,
-            past_media_locations=past_media_locations,
-            padding_side="right",
-            past_vision_tokens=past_vision_tokens,
-        )
-        output = self.lang_model(
-            **new_inputs,
-            use_cache=use_cache,
-            past_key_values=past_key_values,
-            **kwargs,
-        )
-
-        # postforward hooks
-        self._post_forward_hook()
-        return output
-
     def generate(
         self,
         vision_x: torch.Tensor,
@@ -1429,7 +1304,7 @@ class XGenMMPerceiver(VLMWithLanguageStream):
         else:
            vision_features = self._encode_vision_x(vision_x=vision_x)
            vision_attn_masks = None
-        #
+        # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
         # Same for attention masks: [b, Np, v] -> [b*Np, v]
         if self.anyres_patch_sampling:
             split_sizes = [feature.shape[0] for feature in vision_features]
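The new comment in generate describes the same split-then-flatten bookkeeping that the deleted forward override used: per-image patch counts drive torch.split over the projected vision tokens, and each group is flattened from [Np, 1, v, d] to [Np*v, d] before being handed to the language stream. A self-contained sketch of that pattern with dummy tensors; the token count and hidden size are illustrative, not taken from the model config.

import torch

# Two images: the first split into 3 anyres patches, the second into 5.
split_sizes = [3, 5]                  # patches per image, as in split_sizes above
v, d = 128, 4096                      # illustrative tokens-per-patch and hidden size
vision_tokens = torch.randn(sum(split_sizes), 1, v, d)   # [sum(Np_i), 1, v, d]

# Regroup the flat patch axis into per-image groups ...
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)

# ... then flatten each group's patch and token axes into one sequence per image.
per_image_tokens = [
    group.flatten(0, 2).unsqueeze(0)  # [Np, 1, v, d] -> [1, Np*v, d]
    for group in vision_token_groups
]
print([t.shape for t in per_image_tokens])
# [torch.Size([1, 384, 4096]), torch.Size([1, 640, 4096])]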