|
from transformers import PreTrainedModel, ViTMAEModel |
|
from .configuration_magiv2 import Magiv2Config |
|
import torch |
|
import numpy as np |
|
from transformers import ViTImageProcessor |
|
import PIL |
|
|
|
def move_to_device(inputs, device):
    """Recursively move ``inputs`` (and anything nested inside it) to ``device``.

    Handling by type:
      * anything exposing ``keys`` (dicts, ``BatchFeature``-like objects) is
        rebuilt as a plain ``dict`` with each value moved;
      * lists are rebuilt as lists;
      * tuples are rebuilt as plain tuples, except namedtuples, which are
        rebuilt as the same namedtuple type so field access keeps working;
      * numpy arrays are converted with ``torch.from_numpy`` and then moved;
      * any other object with a ``.to`` method is moved with ``.to(device)``;
      * leaves that cannot be moved (``None``, ``str``, ``int``, ...) are
        returned unchanged instead of raising ``AttributeError``.

    Args:
        inputs: a tensor, numpy array, or arbitrarily nested container of them.
        device: target passed to ``.to`` (e.g. ``"cpu"`` or a ``torch.device``).

    Returns:
        A structure mirroring ``inputs`` with every movable leaf on ``device``.
    """
    if hasattr(inputs, "keys"):
        return {k: move_to_device(v, device) for k, v in inputs.items()}
    if isinstance(inputs, list):
        return [move_to_device(v, device) for v in inputs]
    if isinstance(inputs, tuple):
        moved = [move_to_device(v, device) for v in inputs]
        # Namedtuples provide the `_make` classmethod; use it so the result keeps
        # its type. Other tuple subclasses (whose constructors may differ) become
        # plain tuples, as before.
        if hasattr(type(inputs), "_make"):
            return type(inputs)._make(moved)
        return tuple(moved)
    if isinstance(inputs, np.ndarray):
        return torch.from_numpy(inputs).to(device)
    if hasattr(inputs, "to"):
        return inputs.to(device)
    # Nothing to move (None, strings, numbers, ...): pass through untouched.
    return inputs
|
|
|
class Magiv2Model(PreTrainedModel):
    """Embeds PIL image crops with a ViT-MAE encoder.

    The encoder (``crop_embedding_model``) and its image preprocessor
    (``processor``) are both built from the ``Magiv2Config`` passed in.
    """

    config_class = Magiv2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Preprocessor built from the config's preprocessing dict; it turns the
        # image arrays into the `pixel_values` tensor fed to the encoder.
        self.processor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
        self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)

    def move_to_device(self, input):
        """Move ``input`` (possibly a nested container) onto this model's device."""
        return move_to_device(input, self.device)

    def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
        """Compute one embedding vector per image crop.

        Args:
            images: list of ``PIL.Image.Image`` crops. Each crop is converted to
                grayscale ("L") and back to 3-channel "RGB" before preprocessing.
            move_to_device_fn: callable used to move tensors to the target
                device; defaults to ``self.move_to_device``.
            mask_ratio: value written to the encoder's embeddings config
                ``mask_ratio`` for the duration of this call only; the previous
                value is always restored, even if encoding raises.
            batch_size: number of crops passed to the encoder per call; must be
                a positive integer.

        Returns:
            Tensor of shape ``(len(images), hidden_size)`` holding, for each crop,
            the first token of the encoder's ``last_hidden_state`` (presumably the
            [CLS] token). An empty ``images`` yields a ``(0, hidden_size)`` zeros
            tensor.

        Raises:
            AssertionError: if any element of ``images`` is not a PIL image.
            ValueError: if ``batch_size`` is less than 1 (and ``images`` is non-empty).
        """
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        if len(images) == 0:
            return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))

        assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"
        # Validate before touching any model state; a non-positive step would
        # otherwise surface as an obscure range()/torch.cat error.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        arrays = [np.array(image.convert("L").convert("RGB")) for image in images]
        pixel_values = self.processor(arrays, return_tensors="pt").pixel_values
        pixel_values = move_to_device_fn(pixel_values)

        # The masking ratio lives on the encoder's embeddings config (model state),
        # so it is swapped in for this call only. Restoring it in `finally` ensures an
        # exception inside the encoder cannot leave the model with a modified ratio.
        embeddings_config = self.crop_embedding_model.embeddings.config
        old_mask_ratio = embeddings_config.mask_ratio
        embeddings_config.mask_ratio = mask_ratio
        try:
            per_batch = [
                self.crop_embedding_model(pixel_values[start:start + batch_size]).last_hidden_state[:, 0]
                for start in range(0, len(pixel_values), batch_size)
            ]
        finally:
            embeddings_config.mask_ratio = old_mask_ratio

        return torch.cat(per_batch, dim=0)