Shap_Editor_demo / shap_e /util /image_util.py
silentchen's picture
first commit
19c4ddf
raw
history blame contribute delete
No virus
5.22 kB
import random
from typing import Any, List, Optional, Union
import blobfile as bf
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def center_crop(
img: Union[Image.Image, torch.Tensor, np.ndarray]
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
"""
Center crops an image.
"""
if isinstance(img, (np.ndarray, torch.Tensor)):
height, width = img.shape[:2]
else:
width, height = img.size
size = min(width, height)
left, top = (width - size) // 2, (height - size) // 2
right, bottom = left + size, top + size
if isinstance(img, (np.ndarray, torch.Tensor)):
img = img[top:bottom, left:right]
else:
img = img.crop((left, top, right, bottom))
return img
def resize(
img: Union[Image.Image, torch.Tensor, np.ndarray],
*,
height: int,
width: int,
min_value: Optional[Any] = None,
max_value: Optional[Any] = None,
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
"""
:param: img: image in HWC order
:return: currently written for downsampling
"""
orig, cls = img, type(img)
if isinstance(img, Image.Image):
img = np.array(img)
dtype = img.dtype
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
ndim = img.ndim
if img.ndim == 2:
img = img.unsqueeze(-1)
if min_value is None and max_value is None:
# .clamp throws an error when both are None
min_value = -np.inf
img = img.permute(2, 0, 1)
size = (height, width)
img = (
F.interpolate(img[None].float(), size=size, mode="area")[0]
.clamp(min_value, max_value)
.to(img.dtype)
.permute(1, 2, 0)
)
if ndim < img.ndim:
img = img.squeeze(-1)
if not isinstance(orig, torch.Tensor):
img = img.numpy()
img = img.astype(dtype)
if isinstance(orig, Image.Image):
img = Image.fromarray(img)
return img
def get_alpha(img: Image.Image) -> Image.Image:
"""
:return: the alpha channel separated out as a grayscale image
"""
img_arr = np.asarray(img)
if img_arr.shape[2] == 4:
alpha = img_arr[:, :, 3]
else:
alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)
alpha = Image.fromarray(alpha)
return alpha
def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
"""
No op if the image doesn't have an alpha channel.
:param: mode: Defaults to "random" but has an option to use a "black" or
"white" background
:return: image with alpha removed
"""
img_arr = np.asarray(img)
if img_arr.shape[2] == 4:
# Add bg to get rid of alpha channel
if mode == "random":
height, width = img_arr.shape[:2]
bg = Image.fromarray(
random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])(height, width)
)
bg.paste(img, mask=img)
img = bg
elif mode == "black" or mode == "white":
img_arr = img_arr.astype(float)
rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255
background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255)
rgb = rgb * alpha + background * (1 - alpha)
img = Image.fromarray(np.round(rgb).astype(np.uint8))
return img
def _black_bg(h: int, w: int) -> np.ndarray:
return np.zeros([h, w, 3], dtype=np.uint8)
def _gray_bg(h: int, w: int) -> np.ndarray:
return (np.zeros([h, w, 3]) + np.random.randint(low=0, high=256)).astype(np.uint8)
def _checker_bg(h: int, w: int) -> np.ndarray:
checker_size = np.ceil(np.exp(np.random.uniform() * np.log(min(h, w))))
c1 = np.random.randint(low=0, high=256)
c2 = np.random.randint(low=0, high=256)
xs = np.arange(w)[None, :, None] + np.random.randint(low=0, high=checker_size + 1)
ys = np.arange(h)[:, None, None] + np.random.randint(low=0, high=checker_size + 1)
fields = np.logical_xor((xs // checker_size) % 2 == 0, (ys // checker_size) % 2 == 0)
return np.where(fields, np.array([c1] * 3), np.array([c2] * 3)).astype(np.uint8)
def _noise_bg(h: int, w: int) -> np.ndarray:
return np.random.randint(low=0, high=256, size=[h, w, 3]).astype(np.uint8)
def load_image(image_path: str) -> Image.Image:
with bf.BlobFile(image_path, "rb") as thefile:
img = Image.open(thefile)
img.load()
return img
def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image:
"""
to test, run
>>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)]))
"""
images = list(map(np.array, images))
size = images[0].shape[0]
n = round_up(len(images), columns)
n_blanks = n - len(images)
images.extend([np.zeros((size, size, 3), dtype=np.uint8)] * n_blanks)
images = (
np.array(images)
.reshape(n // columns, columns, size, size, 3)
.transpose([0, 2, 1, 3, 4])
.reshape(n // columns * size, columns * size, 3)
)
return Image.fromarray(images)
def round_up(n: int, b: int) -> int:
return (n + b - 1) // b * b