import gc

import numpy as np
import torch
from PIL import Image

from scripts.refine_lr_to_sr import run_sr_fast

GRADIO_CACHE = "/tmp/gradio/"


def clean_up():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed. ``torch.cuda.empty_cache()`` can then release
    their now-unoccupied blocks from PyTorch's caching allocator.
    """
    gc.collect()
    torch.cuda.empty_cache()


def remove_color(arr, threshold=80):
    """Make pixels close to the top-left pixel's colour transparent.

    The colour at ``arr[0, 0]`` is treated as the background colour. A pixel
    counts as background when the sum of its absolute per-channel differences
    from that colour is ``<= threshold``. Background pixels are painted white
    (255) and get alpha 0; all other pixels get alpha 255.

    Args:
        arr: channel-last image array, expected shape ``(H, W, 3)`` or
            ``(H, W, 4)``. A 4th (alpha) channel is discarded. The input array
            is not modified.
        threshold: largest summed channel difference still treated as
            background. Defaults to 80.

    Returns:
        Array with an alpha channel appended along the last axis (0 or 255).
        Its dtype is promoted by the int32 alpha (int32 for uint8 input), so
        cast to ``uint8`` before building a PIL image.
    """
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # Copy so the caller's array is never written to. Slicing above only makes
    # a view, and arrays from ``np.asarray(pil_image)`` are read-only.
    arr = np.array(arr)

    # Summed absolute difference from the corner colour; int32 avoids uint8
    # wrap-around during subtraction.
    base = arr[0, 0]
    diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
    background = diffs <= threshold

    arr[background] = 255
    alpha = (~background)[..., None].astype(np.int32) * 255
    return np.concatenate([arr, alpha], axis=-1)


def simple_remove(imgs, run_sr=True):
    """Remove the flat background from one image or a list of images.

    Only works for normal images (presumably normal-map renders, whose
    background is a single flat colour in the top-left corner).

    Args:
        imgs: a PIL image or a list of PIL images.
        run_sr: if True, first pass the images through ``run_sr_fast``.

    Returns:
        An RGBA ``PIL.Image`` if a single image was given, otherwise a list of
        RGBA images in the same order.
    """
    single_input = not isinstance(imgs, list)
    if single_input:
        imgs = [imgs]

    if run_sr:
        imgs = run_sr_fast(imgs)

    rets = [
        Image.fromarray(remove_color(np.array(img)).astype(np.uint8))
        for img in imgs
    ]
    return rets[0] if single_input else rets


def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    """Composite an RGBA image over a solid ``bkgd`` colour and return RGB.

    ``rgba``'s own alpha channel is used as the paste mask, so it must carry
    an alpha band.
    """
    new_image = Image.new("RGBA", rgba.size, bkgd)
    new_image.paste(rgba, (0, 0), rgba)
    return new_image.convert('RGB')


def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    """Replace the colour under transparent pixels while keeping the alpha.

    The RGB channels are taken from ``rgba`` composited over ``bkgd``, and
    the original alpha channel is reattached unchanged.
    """
    rgb = rgba_to_rgb(rgba, bkgd)
    alpha = np.array(rgba)[:, :, 3:4]
    return Image.fromarray(np.concatenate([np.array(rgb), alpha], axis=-1))


def split_image(image, rows=None, cols=None):
    """Split a grid image into square tiles (inverse of ``make_image_grid``).

    The tile side length and any missing grid dimension are inferred:

    * neither given: a single row, and the tile side equals the image height;
    * only ``cols``: tile side = width // cols, and rows is inferred from height;
    * only ``rows``: tile side = height // rows, and cols is inferred from width;
    * both given: tile side = height // rows.

    The inferred dimension is checked with ``assert``, so the check is skipped
    under ``python -O``.

    Returns:
        List of cropped PIL images in row-major order.
    """
    width, height = image.size
    if rows is None and cols is None:
        rows = 1
        cols = width // height
        assert cols * height == width
        tile = height
    elif rows is None:
        tile = width // cols
        rows = height // tile
        assert rows * tile == height
    elif cols is None:
        tile = height // rows
        cols = width // tile
        assert cols * tile == width
    else:
        tile = height // rows
        assert cols * tile == width

    return [
        image.crop((c * tile, r * tile, (c + 1) * tile, (r + 1) * tile))
        for r in range(rows)
        for c in range(cols)
    ]


def make_image_grid(images, rows=None, cols=None, resize=None):
    """Paste ``images`` row-major into one ``rows x cols`` grid image.

    Args:
        images: non-empty sequence of PIL images. All images are assumed to
            share the first image's size unless ``resize`` is given. The
            caller's sequence is not modified.
        rows, cols: grid shape. If both are None, a single row is used. If
            one is None, it is the ceiling needed to fit all images.
        resize: if given, every tile is resized to ``(resize, resize)``.

    Returns:
        A new image in the first image's mode. Unused cells are filled with
        blank tiles created by ``Image.new`` in that mode and size.
    """
    # Work on a copy: the padding below must not grow the caller's list.
    # The copy also means tuples and other sequences are accepted.
    images = list(images)
    if rows is None and cols is None:
        rows = 1
        cols = len(images)
    if rows is None:
        rows = -(-len(images) // cols)  # ceiling division
    if cols is None:
        cols = -(-len(images) // rows)  # ceiling division

    total_imgs = rows * cols
    if total_imgs > len(images):
        images += [Image.new(images[0].mode, images[0].size)
                   for _ in range(total_imgs - len(images))]
    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new(images[0].mode, size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid