# Unique3D / gradio_app /utils.py
# Wuvin's picture
# rename files
# 5a3e910
import torch
import numpy as np
from PIL import Image
import gc
import numpy as np
import numpy as np
from PIL import Image
from scripts.refine_lr_to_sr import run_sr_fast
# Path string for a Gradio cache directory (inferred from the name only).
# NOTE(review): nothing in the visible part of this file reads it; presumably
# it is imported by other modules - confirm before changing or removing.
GRADIO_CACHE = "/tmp/gradio/"
def clean_up():
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before the CUDA caching allocator is asked to hand its
    unused blocks back; in the opposite order those blocks would stay cached
    until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
def remove_color(arr, threshold=80):
    """Turn the background colour of an image array transparent.

    The pixel at ``arr[0, 0]`` is taken as the background colour. Every pixel
    whose summed absolute per-channel difference from that colour is at most
    ``threshold`` is painted 255 in all colour channels and given alpha 0;
    every other pixel keeps its colour and gets alpha 255.

    Args:
        arr: NumPy array indexed as ``arr[row, col, channel]``. If the last
            axis has 4 entries only the first three are used. The array is
            not modified.
        threshold: Largest summed channel difference that still counts as
            background. Defaults to 80, the value that used to be hard-coded.

    Returns:
        A new array whose last axis holds the colour channels followed by the
        computed alpha channel (values 0 or 255).
    """
    # Work on a private copy: slicing off the alpha channel only yields a
    # view, so the in-place write below would otherwise modify the caller's
    # array (and fail outright on read-only arrays).
    rgb = np.array(arr[..., :3] if arr.shape[-1] == 4 else arr)

    # Widen to int32 before subtracting so unsigned pixel values cannot wrap.
    base = rgb[0, 0].astype(np.int32)
    diffs = np.abs(rgb.astype(np.int32) - base).sum(axis=-1)

    background = diffs <= threshold
    rgb[background] = 255

    # Foreground pixels are opaque, background pixels fully transparent.
    alpha = (~background)[..., None].astype(np.int32) * 255
    return np.concatenate([rgb, alpha], axis=-1)
def simple_remove(imgs, run_sr=True):
    """Remove the background from one image or a list of images.

    Original note: "Only works for normal" (presumably normal-map images -
    confirm with callers).

    Args:
        imgs: A single image or a ``list`` of images. Anything that is not a
            ``list`` is treated as a single image.
        run_sr: When true, the images are first passed through
            ``run_sr_fast``.

    Returns:
        A single PIL image if a single image was given, otherwise a list of
        PIL images, each produced by ``remove_color``.
    """
    single_input = not isinstance(imgs, list)
    if single_input:
        imgs = [imgs]

    if run_sr:
        imgs = run_sr_fast(imgs)

    rets = [
        Image.fromarray(remove_color(np.array(img)).astype(np.uint8))
        for img in imgs
    ]
    return rets[0] if single_input else rets
def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    """Composite an image over a solid background and return it in RGB mode.

    Args:
        rgba: Image to flatten; it is also used as its own paste mask.
        bkgd: Fill colour of the background canvas, passed to ``Image.new``.

    Returns:
        A new image of the same size, converted to ``RGB``.
    """
    canvas = Image.new(mode="RGBA", size=rgba.size, color=bkgd)
    # Passing the image as the mask as well makes its transparency decide
    # where the background stays visible.
    canvas.paste(rgba, box=(0, 0), mask=rgba)
    return canvas.convert('RGB')
def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    """Replace the colour behind an image's transparency, keeping its alpha.

    The colour channels come from ``rgba_to_rgb(rgba, bkgd)``; the fourth
    channel of the input is appended unchanged.

    Args:
        rgba: Input image; its array form must have a fourth channel.
        bkgd: Background colour forwarded to ``rgba_to_rgb``.

    Returns:
        A new image built from the flattened colours plus the original
        fourth channel.
    """
    flattened = np.array(rgba_to_rgb(rgba, bkgd))
    # 3:4 keeps the channel axis so the two arrays can be concatenated.
    alpha_channel = np.array(rgba)[:, :, 3:4]
    combined = np.concatenate([flattened, alpha_channel], axis=-1)
    return Image.fromarray(combined)
def split_image(image, rows=None, cols=None):
    """Cut a grid image into square tiles (inverse of ``make_image_grid``).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method, e.g. a PIL image.
        rows: Number of tile rows. Derived from the image size when omitted.
        cols: Number of tile columns. Derived from the image size when
            omitted. With neither given, a single row is assumed.

    Returns:
        List of cropped tiles in row-major order.

    Raises:
        AssertionError: If the derived tile size does not divide the image
            dimension that is checked for the given argument combination.
    """
    dims = image.size  # (width, height)
    width, height = dims[0], dims[1]

    if rows is None and cols is None:
        # Single row of tiles whose side equals the image height.
        rows = 1
        cols = width // height
        assert cols * height == width
        tile = height
    elif rows is None:
        tile = width // cols
        rows = height // tile
        assert rows * tile == height
    elif cols is None:
        tile = height // rows
        cols = width // tile
        assert cols * tile == width
    else:
        tile = height // rows
        assert cols * tile == width

    return [
        image.crop((c * tile, r * tile, (c + 1) * tile, (r + 1) * tile))
        for r in range(rows)
        for c in range(cols)
    ]
def make_image_grid(images, rows=None, cols=None, resize=None):
    """Paste images into a single grid image, filling it row by row.

    Args:
        images: Sequence of PIL images. It is copied internally, so the
            caller's sequence is never extended with padding images and any
            iterable (not only a ``list``) is accepted.
        rows: Number of grid rows. When omitted it is derived from ``cols``
            (rounding up); when both are omitted a single row is used.
        cols: Number of grid columns. When omitted it is derived from
            ``rows`` (rounding up).
        resize: If given, every image is resized to ``(resize, resize)``
            before pasting.

    Returns:
        A new image of size ``(cols * w, rows * h)`` in the mode of the first
        image, where ``(w, h)`` is the size of the first image after the
        optional resize.

    Raises:
        IndexError: If ``images`` is empty.
    """
    # Private copy: padding below must not leak into the caller's list.
    images = list(images)
    count = len(images)

    if rows is None and cols is None:
        rows = 1
        cols = count
    if rows is None:
        rows = count // cols
        if count % cols != 0:
            rows += 1
    if cols is None:
        cols = count // rows
        if count % rows != 0:
            cols += 1

    # Pad with blank images (same mode and size as the first one) so every
    # grid cell has something to paste.
    missing = rows * cols - count
    if missing > 0:
        first = images[0]
        images.extend(Image.new(first.mode, first.size) for _ in range(missing))

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new(images[0].mode, size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid