import os
import importlib
import imageio
import torch
import rembg
import numpy as np
import PIL.Image
from PIL import Image
from typing import Any
from torchvision import transforms
def instantiate_from_config(config):
    # Instantiate an object from a config dict of the form
    # {"target": "module.path.ClassName", "params": {...}}.
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
    # Resolve a dotted path like "module.sub.ClassName" to the object itself.
    module, cls = string.rsplit(".", 1)
    if reload:
        # Re-import so edits to the module are picked up without restarting.
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
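
# Usage sketch (the config below is illustrative, not from this repo): any
# importable dotted path works as "target", and "params" is passed to its
# constructor as keyword arguments.
#
#   cfg = {"target": "torch.nn.Linear",
#          "params": {"in_features": 8, "out_features": 2}}
#   layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(8, 2)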
# def resize_without_crop(pil_image, target_width, target_height):
#     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
#     return np.array(resized_image)[:, :, :3]

# @torch.inference_mode()
# def numpy2pytorch(imgs):
#     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 255.0 * 2.0 - 1.0
#     h = h.movedim(-1, 1)
#     return h

# @torch.inference_mode()
# def remove_background(
#     image: PIL.Image.Image,
#     rembg: Any = None,
#     force: bool = False,
#     **rembg_kwargs,
# ) -> PIL.Image.Image:
#     do_remove = True
#     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
#         do_remove = False
#     do_remove = do_remove or force
#     if do_remove:
#         W, H = image.size
#         k = (256.0 / float(H * W)) ** 0.5
#         feed = resize_without_crop(image, int(64 * round(W * k)), int(64 * round(H * k)))
#         feed = numpy2pytorch([feed]).to(device=rembg.device, dtype=torch.float32)
#         alpha = rembg(feed)[0][0]
#         alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
#         alpha = alpha.squeeze().clamp(0, 1)
#         alpha = (alpha * 255).cpu().data.numpy().astype(np.uint8)
#         alpha = Image.fromarray(alpha)
#         no_bg_image = Image.new("RGBA", alpha.size, (0, 0, 0, 0))
#         no_bg_image.paste(image, mask=alpha)
#         image = no_bg_image
#     return image
def remove_background(
    image: PIL.Image.Image,
    rembg: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    # Skip removal if the image already carries a non-trivial alpha channel,
    # unless `force` is set.
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        # Resize to the network's 1024x1024 input and apply ImageNet
        # normalization.
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image = image.convert("RGB")
        input_images = transform_image(image).unsqueeze(0).to(rembg.device)
        with torch.no_grad():
            # The model returns a list of logit maps; the last one is the
            # final prediction.
            preds = rembg(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # Convert the predicted matte back to a PIL mask at the original
        # resolution and attach it as the alpha channel.
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image.size)
        image.putalpha(mask)
    return image
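
# Usage sketch, assuming `rembg` is a BiRefNet-style matting network whose
# forward pass returns a list of logit maps (last one final) -- an assumption
# inferred from the `[-1].sigmoid()` call above. The model id and loader are
# hypothetical, not confirmed by this repo:
#
#   from transformers import AutoModelForImageSegmentation
#   net = AutoModelForImageSegmentation.from_pretrained(
#       "ZhengPeng7/BiRefNet", trust_remote_code=True).eval().to("cuda")
#   cutout = remove_background(Image.open("input.png"), rembg=net)  # RGBA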
# def remove_background(image: PIL.Image.Image,
#                       rembg_session: Any = None,
#                       force: bool = False,
#                       **rembg_kwargs,
# ) -> PIL.Image.Image:
#     do_remove = True
#     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
#         do_remove = False
#     do_remove = do_remove or force
#     if do_remove:
#         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
#     return image
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    image = np.array(image)
    assert image.shape[-1] == 4, "expected an RGBA image"
    # Bounding box of all pixels with non-zero alpha.
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # Crop the foreground (+1 so the max row/column is included).
    fg = image[y1:y2 + 1, x1:x2 + 1]
    # Pad to a square.
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    # Pad again on both sides so the foreground occupies `ratio` of the canvas.
    new_size = int(new_image.shape[0] / ratio)
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    return Image.fromarray(new_image)
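
# Usage sketch: recenter an RGBA cutout so the object fills about 85% of a
# square canvas (the ratio and file name are illustrative):
#
#   rgba = Image.open("cutout.png")  # RGBA, e.g. from remove_background
#   centered = resize_foreground(rgba, ratio=0.85)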
def rgba_to_white_background(image: PIL.Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
    # Composite an RGBA image onto a white background; returns the composited
    # RGB tensor and the alpha channel, both in [0, 1].
    image = np.asarray(image, dtype=np.float32) / 255.0
    image = torch.from_numpy(image).movedim(2, 0).float()  # (4, H, W)
    image, alpha = image.split([3, 1], dim=0)
    image = image * alpha + torch.ones_like(image) * (1 - alpha)
    return image, alpha
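
# Usage sketch: composite onto white and batch for a network (shapes assumed):
#
#   rgb, alpha = rgba_to_white_background(Image.open("cutout.png"))
#   batch = rgb.unsqueeze(0)  # rgb: (3, H, W) in [0, 1] -> (1, 3, H, W)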
def save_video(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    # frames: (N, C, H, W), values in [0, 1].
    frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
    writer = imageio.get_writer(output_path, mode="I", fps=fps, codec="libx264")
    for frame in frames:
        writer.append_data(frame)
    writer.close()
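
# Usage sketch: random frames stand in for rendered views here; real inputs
# must already be scaled to [0, 1]. Writing MP4 via libx264 assumes the
# imageio ffmpeg backend is installed.
#
#   frames = torch.rand(90, 3, 256, 256)  # hypothetical turntable frames
#   save_video(frames, "orbit.mp4", fps=30)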