| from typing import Any, Dict, List, Optional, Union
|
|
|
| import numpy as np
|
| import torch
|
| from PIL import Image
|
|
|
| from transformers import ImageProcessingMixin
|
|
|
|
|
def _to_rgb(img: Image.Image) -> Image.Image:
    """Return *img* in RGB mode, converting only when necessary.

    An image whose mode is already ``"RGB"`` is returned unchanged (the same
    object, no copy). Any other mode is passed through ``img.convert("RGB")``.

    NOTE(review): for modes carrying alpha (e.g. RGBA), PIL's conversion
    discards transparency rather than compositing onto a background —
    confirm that is acceptable for the expected inputs.
    """
    return img if img.mode == "RGB" else img.convert("RGB")
|
|
|
|
|
class UpscalerImageProcessor(ImageProcessingMixin):
    """
    Minimal image processor for an upscaler (super-resolution) model.

    - input: a single PIL image, or a list of PIL images of identical size
    - output: ``{"pixel_values": tensor}`` where the tensor is float32 with
      values in [0, 1] and shape (B, 3, H, W)

    No ImageNet mean/std normalization is applied (recommended for SR models
    trained on inputs scaled to [0, 1]).
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        # No processor-specific options; all kwargs go to the transformers mixin.
        super().__init__(**kwargs)

    def _pil_to_tensor_01(self, img: Image.Image) -> torch.FloatTensor:
        """Convert one PIL image to a (3, H, W) float32 tensor in [0, 1]."""
        img = _to_rgb(img)
        # After RGB conversion each channel is 8-bit, so /255 maps to [0, 1].
        arr = np.array(img, dtype=np.float32) / 255.0
        # numpy yields (H, W, C); reorder to channels-first (C, H, W) and make
        # the memory layout contiguous after the permute.
        return torch.from_numpy(arr).permute(2, 0, 1).contiguous()

    def __call__(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Convert one image or a batch of images into model inputs.

        Args:
            images: A single PIL image, or an iterable of PIL images that all
                share the same width and height (they are stacked into one
                batch).
            return_tensors: ``None`` or ``"pt"``; both produce a PyTorch
                tensor.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``{"pixel_values": tensor}`` with a float32 tensor of shape
            (B, 3, H, W) and values in [0, 1].

        Raises:
            ValueError: if ``return_tensors`` is unsupported, ``images`` is
                empty, or the images do not all have the same size.
        """
        # Validate the output format before doing any per-image work.
        if return_tensors is not None and return_tensors != "pt":
            raise ValueError("Only return_tensors=None or 'pt' is supported.")

        if isinstance(images, Image.Image):
            images = [images]
        else:
            images = list(images)

        # torch.stack on an empty list fails with an unhelpful torch error.
        if not images:
            raise ValueError("Expected at least one image, got an empty sequence.")

        tensors = [self._pil_to_tensor_01(im) for im in images]

        # torch.stack requires identical shapes; report which image differs.
        expected_shape = tensors[0].shape
        for idx, tensor in enumerate(tensors[1:], start=1):
            if tensor.shape != expected_shape:
                raise ValueError(
                    "All images in a batch must have the same size; image 0 "
                    f"has (C, H, W)={tuple(expected_shape)} but image {idx} "
                    f"has (C, H, W)={tuple(tensor.shape)}."
                )

        return {"pixel_values": torch.stack(tensors, dim=0)}