# image_processing_umm.py — UMM image processor (Cheers v1.0 release).
# NOTE(review): the lines above this comment were model-hub page residue
# ("Cheers / image_processing_umm.py", uploader name, commit message) that made
# the file invalid Python; they have been converted into this comment header.
from networkx import to_numpy_array
import numpy as np
import torch
from PIL import Image, ImageOps
import math
from functools import partial, reduce
from transformers.image_transforms import (
convert_to_rgb,
center_crop,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, ChannelDimension, PILImageResampling, to_numpy_array
class UMMImageProcessor(BaseImageProcessor):
    """Image processor that turns a single PIL image into model inputs.

    Produces two tensors:
      * ``pixel_values`` — shape ``(1, C, H, W)``, rescaled to ``[0, 1]``.
      * ``grid_hws``     — shape ``(1, 2)``, the ``(H // patch_size, W // patch_size)``
        patch-grid size of the processed image.

    Two preprocessing paths exist:
      * understanding (``und=True``): aspect-preserving downscale to fit inside
        ``scale_resolution`` followed by black padding to a square.
      * generation (``und=False``): aspect-preserving upscale of the short side
        to ``scale_resolution`` followed by a center crop to a square.
    """

    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(256, 256),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=256,
        patch_size=16,
        **kwargs,
    ):
        """Store preprocessing configuration.

        NOTE(review): ``image_mean``/``image_std``, ``size`` and ``crop_size``
        are stored but never read by either preprocessing path below (no
        normalize/resize/center-crop uses them) — confirm whether
        normalization was intended before relying on these values.
        """
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size

    def preprocess(self, image, max_resolution=None, return_tensors='pt', und=True, **kwargs) -> BatchFeature:
        """Preprocess one image into ``pixel_values`` and ``grid_hws``.

        Args:
            image: a single PIL image (must not be ``None``).
            max_resolution: optional override for ``self.scale_resolution``.
            return_tensors: tensor type passed to ``BatchFeature`` (default ``'pt'``).
            und: ``True`` selects the understanding (pad) path, ``False`` the
                generation (center-crop) path.

        Returns:
            ``BatchFeature`` with keys ``pixel_values`` and ``grid_hws``.

        Raises:
            ValueError: if ``image`` is ``None``. (The original code fell
            through to an undefined-name ``NameError`` in that case.)
        """
        if image is None:
            raise ValueError("`image` must not be None.")
        # Per-call resolution override falls back to the configured default.
        scale_resolution = max_resolution if max_resolution is not None else self.scale_resolution
        if und:
            processed = self._preprocess_und(image, scale_resolution)
        else:
            processed = self._preprocess_gen(image, scale_resolution)
        if not torch.is_tensor(processed):
            processed = torch.tensor(processed)
        # Processed image is channel-first (C, H, W) per self.data_format.
        _, height, width = processed.shape
        grid_hw = (int(height // self.patch_size), int(width // self.patch_size))
        data = {
            # Add a leading batch dimension of 1 for the single image.
            "pixel_values": torch.stack([processed], dim=0),
            "grid_hws": torch.tensor([grid_hw]),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess_gen(self, source_image, scale_resolution):
        """Generation path: scale the SHORT side up/down to ``scale_resolution``,
        then center-crop to a ``scale_resolution`` square, RGB-convert, rescale
        to [0, 1] and emit channel-first. Returns a numpy array (C, H, W)."""
        w, h = source_image.size
        scale = scale_resolution / min(h, w)
        new_h = int(round(h * scale))
        new_w = int(round(w * scale))
        source_image = source_image.resize((new_w, new_h), Image.Resampling.BICUBIC)
        source_image = [source_image]
        transforms = [
            convert_to_rgb,
            to_numpy_array,
        ]
        transforms.append(partial(center_crop, size=(scale_resolution, scale_resolution)))
        transforms.append(partial(rescale, scale=self.rescale_factor, data_format=self.data_format))
        transforms.append(partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format))
        # Apply each transform to every image in the (single-element) list.
        image = reduce(lambda x, f: [*map(f, x)], transforms, source_image)
        return image[0] if len(image) == 1 else image

    def _preprocess_und(self, source_image, scale_resolution):
        """Understanding path: fit the image inside a ``scale_resolution`` square
        (aspect preserved), pad the borders with black to exactly square, then
        RGB-convert, rescale to [0, 1] and emit channel-first.
        Returns a numpy array (C, H, W)."""
        w, h = source_image.size
        # Fit-inside scale: the larger side ends up at scale_resolution.
        scale = min(scale_resolution / h, scale_resolution / w)
        new_h = int(round(h * scale))
        new_w = int(round(w * scale))
        resized_image = source_image.resize((new_w, new_h), Image.Resampling.BICUBIC)
        # Symmetric black padding; any odd remainder goes to the right/bottom.
        pad_w = scale_resolution - new_w
        pad_h = scale_resolution - new_h
        left = pad_w // 2
        right = pad_w - left
        top = pad_h // 2
        bottom = pad_h - top
        new_image = ImageOps.expand(resized_image, border=(left, top, right, bottom), fill=(0, 0, 0))
        source_image = [new_image]
        transforms = [
            convert_to_rgb,
            to_numpy_array
        ]
        transforms.append(partial(rescale, scale=self.rescale_factor, data_format=self.data_format))
        transforms.append(partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format))
        image = reduce(lambda x, f: [*map(f, x)], transforms, source_image)
        return image[0] if len(image) == 1 else image
__all__ = ["UMMImageProcessor"]