Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import numpy as np | |
from PIL import Image | |
import gc | |
import numpy as np | |
import numpy as np | |
from PIL import Image | |
from scripts.refine_lr_to_sr import run_sr_fast | |
# Directory path, presumably Gradio's temporary/cache folder judging by the
# name. It is not referenced by any function in this section — verify where it
# is used before changing it.
GRADIO_CACHE = "/tmp/gradio/"
def clean_up():
    """Free unreferenced Python objects, then release cached CUDA memory.

    The garbage collector runs first, so that tensors kept alive only by
    reference cycles are actually freed. PyTorch's caching allocator is then
    asked to release its now-unoccupied blocks. The previous order
    (``empty_cache`` before ``collect``) left memory freed by the collection
    parked in the cache until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
def remove_color(arr, threshold=80):
    """Make pixels that match the top-left pixel's colour transparent.

    The top-left pixel is taken as the background colour. A pixel counts as
    background when the sum of its absolute per-channel RGB differences from
    that colour is ``<= threshold``. Background pixels are painted white and
    get alpha 0; all other pixels keep their colour and get alpha 255.

    Args:
        arr: Image array indexed as ``arr[row, col, channel]`` with 3 or 4
            channels. A 4th (alpha) channel is discarded. The input array is
            NOT modified.
        threshold: Maximum summed RGB distance from the background colour for
            a pixel to be treated as background. Defaults to 80, the value
            previously hard-coded.

    Returns:
        A 4-channel array (RGB + alpha). Its dtype is the promotion of the
        input dtype with int32 (int32 for uint8 input), so cast to uint8
        before handing it to ``PIL.Image.fromarray``.
    """
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # Work on a copy. Slicing yields a view (and a 3-channel input is the
    # caller's own array), so writing background pixels below used to
    # overwrite the caller's data.
    arr = arr.copy()

    base = arr[0, 0].astype(np.int32)
    # Cast to int32 so uint8 subtraction cannot wrap around.
    diffs = np.abs(arr.astype(np.int32) - base).sum(axis=-1)
    background = diffs <= threshold
    arr[background] = 255

    alpha = (~background).astype(np.int32) * 255
    return np.concatenate([arr, alpha[..., None]], axis=-1)
def simple_remove(imgs, run_sr=True):
    """Cut the flat background out of one image or a list of images.

    Per the original author's note, this is only meant for normal-map
    images. Presumably that is because ``remove_color`` treats the top-left
    pixel's colour as the background.

    Args:
        imgs: A single image, or a ``list`` of images. Anything that is not a
            ``list`` is treated as a single image.
        run_sr: When true, the batch is first passed through ``run_sr_fast``.

    Returns:
        One ``PIL.Image`` for single input, otherwise a list of them. Each
        image is built from the uint8-cast output of ``remove_color``.
    """
    single_input = not isinstance(imgs, list)
    batch = [imgs] if single_input else imgs

    if run_sr:
        batch = run_sr_fast(batch)

    cutouts = [
        Image.fromarray(remove_color(np.array(frame)).astype(np.uint8))
        for frame in batch
    ]
    return cutouts[0] if single_input else cutouts
def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    """Flatten an image with transparency onto a solid background colour.

    Args:
        rgba: Source image, normally mode ``RGBA``. Images in alpha-bearing
            modes (``RGBA``, ``RGBa``, ``LA``) are used as-is, so their own
            alpha acts as the blend mask. Any other mode (``RGB``, ``P``,
            ``L``, ...) is converted to ``RGBA`` first.
            - Before this change, ``RGB``/``P`` inputs raised
              "bad transparency mask".
            - ``L``/``1`` inputs had their luminance misused as alpha.
        bkgd: Background colour; any colour value accepted by ``Image.new``.

    Returns:
        A new ``RGB`` image of the same size. Each pixel is the source
        blended over ``bkgd`` by the source's alpha.
    """
    if rgba.mode not in ("RGBA", "RGBa", "LA"):
        rgba = rgba.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, bkgd)
    # Using the image as its own mask blends it over the background by alpha.
    canvas.paste(rgba, (0, 0), rgba)
    return canvas.convert("RGB")
def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    """Replace the colour behind transparent pixels, keeping the alpha mask.

    The RGB channels come from compositing ``rgba`` over ``bkgd`` with
    ``rgba_to_rgb``. The alpha channel (index 3) is copied unchanged from the
    input, so the result is still a cut-out, but semi-transparent edges now
    carry ``bkgd`` instead of whatever colour was hidden under them.
    """
    composited = np.array(rgba_to_rgb(rgba, bkgd))
    original_alpha = np.array(rgba)[:, :, 3:4]
    stacked = np.concatenate((composited, original_alpha), axis=-1)
    return Image.fromarray(stacked)
def split_image(image, rows=None, cols=None):
    """Cut a grid image back into its square tiles (inverse of make_image_grid).

    Tiles are square. Their side length is derived from whichever grid
    dimension is known:

    * neither ``rows`` nor ``cols``: a single horizontal strip whose tile
      side equals the image height;
    * only ``cols``: side = width // cols, and the rows follow from the height;
    * only ``rows``: side = height // rows, and the columns follow from the
      width;
    * both: side = height // rows.

    An ``AssertionError`` is raised when the derived tiling does not exactly
    cover the width (or the height, when ``rows`` is derived).

    Returns:
        The list of cropped tiles in row-major order.
    """
    width, height = image.size

    if rows is None and cols is None:
        tile = height
        rows, cols = 1, width // height
        assert cols * tile == width
    elif rows is None:
        tile = width // cols
        rows = height // tile
        assert rows * tile == height
    elif cols is None:
        tile = height // rows
        cols = width // tile
        assert cols * tile == width
    else:
        tile = height // rows
        assert cols * tile == width

    return [
        image.crop((c * tile, r * tile, (c + 1) * tile, (r + 1) * tile))
        for r in range(rows)
        for c in range(cols)
    ]
def make_image_grid(images, rows=None, cols=None, resize=None):
    """Paste images into one row-major grid image.

    Args:
        images: Sequence of PIL images (list or tuple). The caller's sequence
            is never modified; padding tiles are added to a local copy only.
        rows: Number of grid rows. If omitted while ``cols`` is given, it is
            ``ceil(len(images) / cols)``.
        cols: Number of grid columns. If omitted while ``rows`` is given, it
            is ``ceil(len(images) / rows)``. If both are omitted, the grid is
            a single row.
        resize: If given, every tile is resized to ``(resize, resize)``
            before pasting.

    Returns:
        A new image in the first image's mode, of size ``(cols * w, rows * h)``.
        ``(w, h)`` is the (possibly resized) size of the first tile. Empty
        trailing cells are filled with blank ``Image.new`` tiles. If
        ``rows * cols`` is smaller than ``len(images)``, the surplus tiles
        land outside the canvas.
    """
    # Copy first. The old ``images += [...]`` appended padding tiles to the
    # caller's list in place, and raised TypeError for tuple input.
    images = list(images)

    if rows is None and cols is None:
        rows = 1
        cols = len(images)
    if rows is None:
        rows = len(images) // cols
        if len(images) % cols != 0:
            rows += 1
    if cols is None:
        cols = len(images) // rows
        if len(images) % rows != 0:
            cols += 1

    missing = rows * cols - len(images)
    if missing > 0:
        images.extend(Image.new(images[0].mode, images[0].size) for _ in range(missing))

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new(images[0].mode, size=(cols * w, rows * h))
    for i, img in enumerate(images):
        # Row-major placement: column i % cols, row i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid