# magiv2-crop-embedder / modelling_magiv2.py
from transformers import PreTrainedModel, ViTMAEModel, ViTImageProcessor
from .configuration_magiv2 import Magiv2Config
import torch
import numpy as np
import PIL.Image


def move_to_device(inputs, device):
    # Recursively move tensors to the target device, preserving the structure of
    # dicts, lists and tuples; numpy arrays are converted to tensors along the way.
    if hasattr(inputs, "keys"):
        return {k: move_to_device(v, device) for k, v in inputs.items()}
    elif isinstance(inputs, list):
        return [move_to_device(v, device) for v in inputs]
    elif isinstance(inputs, tuple):
        return tuple([move_to_device(v, device) for v in inputs])
    elif isinstance(inputs, np.ndarray):
        return torch.from_numpy(inputs).to(device)
    else:
        return inputs.to(device)
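

# Illustrative example (not part of the original file): a call such as
#   move_to_device({"pixel_values": np.zeros((1, 3, 224, 224)), "ids": [np.arange(4)]}, "cpu")
# returns the same dict/list structure with every numpy array converted to a torch tensor on
# the requested device; objects that are already tensors are moved via .to(device).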


class Magiv2Model(PreTrainedModel):
    config_class = Magiv2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # image preprocessor and ViT-MAE encoder used to embed individual crops
        self.processor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
        self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)

    def move_to_device(self, input):
        return move_to_device(input, self.device)
    def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        # nothing to embed: return an empty (0, hidden_size) tensor
        if len(images) == 0:
            return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))

        assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"

        # convert to greyscale and back to RGB so every crop has three identical channels
        images = [np.array(image.convert("L").convert("RGB")) for image in images]
        images = self.processor(images, return_tensors="pt").pixel_values
        images = move_to_device_fn(images)

        # temporarily change the mask ratio from the default to the one specified
        old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
        self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio

        # process the crops in batches to avoid OOM
        embeddings = []
        for i in range(0, len(images), batch_size):
            crops = images[i:i + batch_size]
            # use the [CLS] token of the last hidden state as the crop embedding
            embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
            embeddings.append(embeddings_per_batch)
        embeddings = torch.cat(embeddings, dim=0)

        # restore the mask ratio to the default
        self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio

        return embeddings
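

# Usage sketch (illustrative): the checkpoint id and file names below are assumptions based on
# the repository name, not taken from this file; adapt them to wherever the weights actually live.
#
#   from PIL import Image
#
#   model = Magiv2Model.from_pretrained("ragavsachdeva/magiv2-crop-embedder").eval()
#   crops = [Image.open("crop_0.png"), Image.open("crop_1.png")]  # hypothetical crop images
#   with torch.no_grad():
#       embeddings = model(crops)  # (num_crops, hidden_size) CLS-token embeddings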