"""Utility helpers: image preprocessing, config-driven instantiation,
an AdamW optimizer that tracks EMA weights, and checkpoint loading."""
import importlib

import torchvision
import torch
from torch import optim
import numpy as np

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont

import os
import matplotlib.pyplot as plt
import time
import cv2
from carvekit.api.high import HiInterface
import PIL


def pil_rectangle_crop(im):
    """Crop the largest centered square out of a PIL image.

    Args:
        im: PIL image (any object exposing ``size`` and ``crop``).

    Returns:
        The result of ``im.crop`` on a centered box whose side equals
        ``min(width, height)``.
    """
    width, height = im.size
    if width <= height:
        # Portrait (or square): keep the full width, trim top and bottom.
        left, right = 0, width
        top = (height - width) / 2
        bottom = (height + width) / 2
    else:
        # Landscape: keep the full height, trim left and right.
        # `right` must be bound here; binding `bottom` a second time instead
        # leaves `right` undefined and makes the crop call fail.
        top, bottom = 0, height
        left = (width - height) / 2
        right = (width + height) / 2

    return im.crop((left, top, right, bottom))


def add_margin(pil_img, color, size=256):
    """Paste ``pil_img`` centered on a new ``size`` x ``size`` canvas.

    Args:
        pil_img: PIL image to place on the canvas.
        color: Fill color of the canvas.
        size: Side length of the square canvas in pixels.

    Returns:
        The new canvas image with ``pil_img`` pasted in the middle.
    """
    width, height = pil_img.size
    # NOTE(review): the constructor call was truncated in the reviewed text;
    # reconstructed as Image.new with the input image's mode -- confirm.
    result = Image.new(pil_img.mode, (size, size), color)
    result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
    return result


def create_carvekit_interface():
    """Build the carvekit ``HiInterface`` used for background removal.

    Runs on CUDA when ``torch.cuda.is_available()`` and on CPU otherwise.
    See the ``HiInterface`` docstrings for the meaning of each option.
    """
    interface = HiInterface(
        object_type="object",  # Can be "object" or "hairs-like".
        batch_size_seg=5,
        batch_size_matting=1,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    )
    return interface


def load_and_preprocess(interface, input_im):
    """Remove the background, crop to the object and pad to 256 x 256.

    Args:
        interface: Callable taking a list of PIL images and returning a list
            of results (e.g. the object from ``create_carvekit_interface``).
        input_im: Input PIL image.

    Returns:
        ``numpy`` array of shape (256, 256, 3) holding the object on a white
        background. Values are the 0-255 pixel values of the RGB image; no
        rescaling to [0, 1] is done here.
    """
    image = input_im.convert('RGB')

    image_without_background = interface([image])[0]
    image_without_background = np.array(image_without_background)
    est_seg = image_without_background > 127
    image = np.array(image)
    # The last channel is used as the mask -- presumably the alpha channel of
    # the cut-out; confirm against the interface's output format.
    foreground = est_seg[:, :, -1].astype(np.bool_)
    image[~foreground] = [255., 255., 255.]  # paint the background white

    # Crop to the bounding box of the foreground mask.
    x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
    image = image[y:y + h, x:x + w, :]
    image = PIL.Image.fromarray(np.array(image))

    # Shrink in place so the object fits within 200 x 200, then center it on
    # a white 256 x 256 canvas.
    image.thumbnail([200, 200], Image.Resampling.LANCZOS)
    image = add_margin(image, (255, 255, 255), size=256)
    image = np.array(image)

    return image


def log_txt_as_img(wh, xc, size=10):
    """Render captions as images for logging.

    Args:
        wh: Tuple ``(width, height)`` of each rendered image.
        xc: Sequence of caption strings to draw.
        size: Font size passed to ``ImageFont.truetype``.

    Returns:
        Tensor stacking one (3, height, width) array per caption, with pixel
        values mapped from [0, 255] to [-1, 1].
    """
    # Loop-invariant work: load the font and compute the wrap width once.
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    # Characters per line scale with the width (40 at width 256); clamp to 1
    # so a very narrow canvas cannot produce a zero slice step.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        # NOTE(review): the constructor call was truncated in the reviewed
        # text; reconstructed as Image.new -- confirm.
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank canvas for this caption.
            print("Cant encode string for logging. Skipping.")

        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)

    return torch.tensor(np.stack(txts))


def ismap(x):
    """Return True for a 4-D tensor whose second dimension is larger than 3."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    """Return True for a 4-D tensor whose second dimension is 3 or 1."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    """Return True when ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` unless it is None, otherwise fall back to ``d``.

    ``d`` is called to produce the fallback only when ``inspect.isfunction``
    reports it as a function; any other object is returned as-is.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """Take the mean over all non-batch dimensions."""
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    When ``verbose`` is true, also print the count in millions.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    """Instantiate the object described by ``config``.

    Args:
        config: Mapping with a ``target`` dotted path and optional ``params``
            keyword arguments, or one of the sentinel strings
            ``'__is_first_stage__'`` / ``"__is_unconditional__"``.

    Returns:
        ``target(**params)``, or None for the sentinel strings.

    Raises:
        KeyError: If ``config`` has no ``target`` and is not a sentinel.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the object.

    Args:
        string: Dotted path; everything before the last dot is the module.
        reload: When true, reload the module before looking up the attribute.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class AdamWwithEMAandWings(optim.Optimizer):
    """AdamW that also keeps an exponential moving average of each parameter.

    The EMA copy is stored in the optimizer state under ``'param_exp_avg'``.
    """

    # Credit: adapted from an external implementation (the source link is
    # missing from this copy of the file).
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters.

        Raises:
            ValueError: If ``lr``, ``eps``, ``betas``, ``weight_decay`` or
                ``ema_decay`` is outside its valid range.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError(f"Invalid ema_decay value: {ema_decay}")
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore state, defaulting ``amsgrad`` to False where it is absent."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            # Nothing to update in this group. Skipping also avoids reading a
            # step count that is unbound (first group) or left over from a
            # previous group when no parameter here has a gradient.
            if not params_with_grad:
                continue

            # NOTE(review): private torch API; its signature has differed
            # between torch releases -- confirm against the pinned version.
            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            # Warm-up: the effective decay grows with the step count of the
            # last updated parameter and is capped at ``ema_decay``.
            cur_ema_decay = min(ema_decay, 1 - state_steps[-1] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss


def get_state_dict(d):
    """Return ``d['state_dict']`` when that key exists, otherwise ``d``."""
    return d.get('state_dict', d)


def load_state_dict(ckpt_path, location='cpu'):
    """Load a state dict from a ``.safetensors`` or ``torch.save`` checkpoint.

    Args:
        ckpt_path: Path to the checkpoint; the extension selects the loader.
        location: Device passed to the loader (``device`` / ``map_location``).

    Returns:
        The loaded mapping, unwrapped from a ``'state_dict'`` key if present.
    """
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        # Imported lazily so safetensors is only required for this format.
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images, indexed (H, W, C).
        bgr2rgb (bool): Whether to change bgr to rgb (3-channel images only).
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: A list of (C, H, W) tensors when ``imgs`` is a
        list (even of length one), otherwise a single (C, H, W) tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                # Downcast float64 to float32 before the color conversion.
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)