import numpy as np

import torch

from PIL import Image

from transformers.image_processing_utils import BaseImageProcessor

from transformers.utils import logging


logger = logging.get_logger(__name__)


class VQModelImageProcessor(BaseImageProcessor):
    """Image processor for a VQ model.

    Resizes images to a fixed square size, optionally flattens
    transparency onto a white background, and converts between PIL
    images and float32 CHW tensors normalized to [-1, 1].
    """

    def __init__(
        self,
        size: int = 256,
        convert_rgb: bool = False,
        resample: Image.Resampling = Image.Resampling.LANCZOS,
        **kwargs,
    ) -> None:
        """
        Args:
            size: Target edge length; images are resized to ``(size, size)``.
            convert_rgb: If ``True``, composite the RGBA image onto a white
                background and emit 3 channels; otherwise 4 (RGBA) channels
                are kept.
            resample: PIL resampling filter used for all resizing.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        # Bug fix: the base class was never initialized, which silently
        # dropped any `kwargs` intended for BaseImageProcessor.
        super().__init__(**kwargs)
        self.size = size
        self.convert_rgb = convert_rgb
        self.resample = resample

    def __call__(self, image: Image.Image) -> dict:
        """Alias for :meth:`preprocess`."""
        return self.preprocess(image)

    def preprocess(self, image: Image.Image) -> dict:
        """Resize ``image`` and convert it to a normalized tensor.

        Args:
            image: Input PIL image (any mode; converted to RGBA internally).

        Returns:
            dict with keys:
                ``"image"``: float32 tensor of shape ``(C, size, size)`` in
                    ``[-1, 1]`` (C is 3 when ``convert_rgb`` else 4),
                ``"width"``/``"height"``: the ORIGINAL image dimensions, so
                    :meth:`postprocess` can restore them.
        """
        # Capture the original size before resizing so it can be
        # round-tripped through `postprocess`.
        width, height = image.size
        image = image.resize((self.size, self.size), resample=self.resample)
        image = image.convert("RGBA")

        if self.convert_rgb:
            # Flatten transparency onto a white background, using the
            # alpha channel (band index 3) as the paste mask.
            background = Image.new("RGB", image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = background

        return {
            "image": self.to_tensor(image),
            "width": width,
            "height": height,
        }

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a float32 CHW tensor scaled to [-1, 1]."""
        x = np.array(image) / 127.5 - 1.0  # uint8 [0, 255] -> float [-1, 1]
        x = x.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
        return torch.as_tensor(x)

    def postprocess(
        self,
        x: torch.Tensor,
        width: int | None = None,
        height: int | None = None,
    ) -> Image.Image:
        """Invert :meth:`to_tensor`: map a ``(C, H, W)`` tensor in
        ``[-1, 1]`` back to a PIL image, resized to ``(width, height)``.

        Args:
            x: Tensor of shape ``(C, H, W)`` with values in ``[-1, 1]``.
            width: Target width; defaults to ``self.size`` when ``None``.
            height: Target height; defaults to ``self.size`` when ``None``.

        Returns:
            The reconstructed PIL image at the requested size.
        """
        x_np = x.detach().cpu().numpy()
        x_np = x_np.transpose(1, 2, 0)  # CHW -> HWC
        x_np = (x_np + 1.0) * 127.5  # [-1, 1] -> [0, 255]
        x_np = np.clip(x_np, 0, 255).astype(np.uint8)
        image = Image.fromarray(x_np)

        width = width or self.size
        height = height or self.size
        image = image.resize((width, height), resample=self.resample)

        return image