# pOps-space / model / pops_utils.py
# Author: pOpsPaper — commit 71d3bec ("Added space")
# (Hugging Face file-page header: raw | history | blame | "No virus" | 1.8 kB)
from typing import List, Optional, Tuple

import torch
from torch import nn
def preprocess(image_a: Optional[torch.Tensor], image_b: Optional[torch.Tensor], image_encoder: nn.Module,
               clip_mean: torch.Tensor, clip_std: torch.Tensor,
               should_drop_cond: Optional[List[Tuple[bool, bool]]] = None,
               concat_hidden_states: Optional[torch.Tensor] = None,
               image_list: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode and normalize images into one stacked hidden-state tensor.

    Images are gathered in the order ``image_a``, ``image_b`` (each only if not
    ``None``), then every entry of ``image_list``. Each image is handled in
    three steps:
    1. A 2-D tensor is treated as an already-computed embedding and used as-is.
       Any other tensor is passed through
       ``image_encoder(image, output_hidden_states=False)``, and its
       ``.image_embeds`` attribute is taken.
    2. The embedding is normalized as ``(emb - clip_mean) / clip_std``.
    3. A new axis is inserted at dim 1.

    All work runs under ``torch.no_grad()``.

    Args:
        image_a: First conditioning image (or precomputed embedding), or None.
        image_b: Second conditioning image (or precomputed embedding), or None.
        image_encoder: Module whose output exposes ``.image_embeds``
            (presumably a CLIP vision model with projection — confirm).
        clip_mean: Subtracted from every embedding (broadcast).
        clip_std: Divides every embedding after centering (broadcast).
        should_drop_cond: Optional per-sample ``(drop_a, drop_b)`` flags, one
            pair per batch row. A True flag zeroes that sample's embedding of
            ``image_a`` / ``image_b``. Flags for an image that was not given
            are ignored.
        concat_hidden_states: Optional tensor appended after the image
            embeddings along dim 1.
        image_list: Optional extra images/embeddings, appended after
            ``image_a`` / ``image_b``.

    Returns:
        ``(image_embeds, out_hidden_states)``, where:
        - ``out_hidden_states`` is the concatenation along dim 1 of all
          per-image embeddings and, if given, ``concat_hidden_states``.
        - ``image_embeds`` is an all-zeros tensor shaped like the first entry
          with dim 1 squeezed.

    Raises:
        ValueError: if there are no images and ``concat_hidden_states`` is None.
    """
    with torch.no_grad():
        # image_a / image_b come first so their slots are predictable.
        images = [img for img in (image_a, image_b) if img is not None]
        if image_list is not None:
            images.extend(image_list)
        if not images and concat_hidden_states is None:
            raise ValueError("preprocess() needs at least one image or concat_hidden_states")

        embeds_list = []
        for image in images:
            if image.ndim == 2:
                # Already an embedding vector per sample: skip the encoder.
                image_embeds = image
            else:
                image_embeds = image_encoder(image, output_hidden_states=False).image_embeds
            image_embeds = (image_embeds - clip_mean) / clip_std
            embeds_list.append(image_embeds.unsqueeze(1))

        if should_drop_cond is not None:
            # Actual positions of image_a / image_b in embeds_list. image_b is
            # at slot 0 when image_a is absent, so indices cannot be hard-coded.
            a_idx = 0 if image_a is not None else None
            b_idx = None if image_b is None else (1 if image_a is not None else 0)
            if a_idx is not None or b_idx is not None:
                for b_ind in range(embeds_list[0].shape[0]):
                    drop_a, drop_b = should_drop_cond[b_ind]
                    # In-place zeroing is safe: each embedding is a fresh
                    # tensor produced by the normalization above.
                    if drop_a and a_idx is not None:
                        embeds_list[a_idx][b_ind].zero_()
                    if drop_b and b_idx is not None:
                        embeds_list[b_idx][b_ind].zero_()

        if concat_hidden_states is not None:
            embeds_list.append(concat_hidden_states)
        out_hidden_states = torch.concat(embeds_list, dim=1)

        # Deliberately all zeros. It carries no image information; presumably
        # a placeholder input for the downstream model — confirm with callers.
        image_embeds = torch.zeros_like(embeds_list[0].squeeze(1))
    return image_embeds, out_hidden_states