import os
import imageio
import numpy as np
import glob
import sys
from typing import Any
sys.path.insert(1, '.')
import argparse
from pytorch_lightning import seed_everything
from PIL import Image
import torch
from operators import GaussialBlurOperator
from utils import get_rank
from torchvision.ops import masks_to_boxes
from matfusion import MateralDiffusion
from loguru import logger

__MAX_BATCH__ = 4  # max tiles per diffusion forward pass in generation() (4 for A10)


def init_model(ckpt_path, ddim, gpu_id):
    """Build the MatFusion diffusion model from a training-run directory.

    Expects ``{ckpt_path}/configs`` to contain a file whose name includes
    ``project.yaml``; the weights are read from
    ``{ckpt_path}/checkpoints/last.ckpt``.

    Args:
        ckpt_path: Root directory of the training run.
        ddim: Forwarded to ``MateralDiffusion`` as ``use_ddim``.
        gpu_id: Forwarded to ``MateralDiffusion`` as ``device``.

    Returns:
        The constructed ``MateralDiffusion`` instance.

    Raises:
        FileNotFoundError: If no ``*project.yaml`` config is present.
    """
    config_dir = f'{ckpt_path}/configs'
    # os.listdir order is arbitrary; sort so the chosen config is deterministic.
    matches = sorted(name for name in os.listdir(config_dir) if "project.yaml" in name)
    if not matches:
        raise FileNotFoundError(f"no '*project.yaml' config found in {config_dir}")
    return MateralDiffusion(
        device=gpu_id,
        fp16=True,
        config=f'{config_dir}/{matches[0]}',
        ckpt=f'{ckpt_path}/checkpoints/last.ckpt',
        vram_O=False,
        t_range=[0.001, 0.02],
        opt=None,
        use_ddim=ddim,
    )


def images_spliter(image, seg_h, seg_w, padding_pixel, padding_val, overlaps=1):
    """Split an (H, W, C) image tensor into a grid of (optionally overlapping) tiles.

    The image is truncated so height/width are divisible by ``seg_h * overlaps``
    / ``seg_w * overlaps``, then padded by ``padding_pixel`` on every side with
    ``padding_val``. Each tile covers ``h // seg_h`` x ``w // seg_w`` pixels plus
    the padding border, and consecutive tiles advance by ``tile_size / overlaps``,
    so ``overlaps > 1`` produces overlapping tiles.

    The padded canvas always has 3 channels (default torch dtype); a
    single-channel input is broadcast into it.

    Args:
        image: Tensor of shape (H, W, C).
        seg_h, seg_w: Number of non-overlapping tiles along height / width.
        padding_pixel: Border (pixels) added around the image and each tile.
        padding_val: Fill value for the border.
        overlaps: Overlap factor; 1 means non-overlapping tiles.

    Returns:
        ``(tiles, positions, n_h, n_w)``: ``tiles`` is
        (n_h * n_w, tile_h + 2p, tile_w + 2p, 3); ``positions`` is
        (n_h * n_w, 2) float32 holding each tile's top-left (row, col) in
        unpadded-image coordinates (negative when padding is used); ``n_h`` /
        ``n_w`` are the tile counts after overlapping.

    Raises:
        ValueError: If the image is too small to yield at least one pixel per tile.
    """
    h, w, _ = image.shape
    h = h - (h % (seg_h * overlaps))
    w = w - (w % (seg_w * overlaps))
    h_crop = h // seg_h
    w_crop = w // seg_w
    if h_crop == 0 or w_crop == 0:
        raise ValueError(
            f"image of size {image.shape[0]}x{image.shape[1]} is too small to split into "
            f"{seg_h}x{seg_w} tiles with overlap {overlaps}"
        )

    img_padded = torch.zeros(h + padding_pixel * 2, w + padding_pixel * 2, 3, device=image.device) + padding_val
    img_padded[padding_pixel:h + padding_pixel, padding_pixel:w + padding_pixel, :] = image[:h, :w]

    # Overlapped sampling: tiles keep their size but start every tile_size / overlaps pixels.
    seg_h = np.round((h - h_crop) / h_crop * overlaps).astype(int) + 1
    seg_w = np.round((w - w_crop) / w_crop * overlaps).astype(int) + 1
    h_step = np.round(h_crop / overlaps).astype(int)
    w_step = np.round(w_crop / overlaps).astype(int)

    tile_h = h_crop + padding_pixel * 2
    tile_w = w_crop + padding_pixel * 2
    images = []
    positions = []
    for ind_i in range(seg_h):
        i = ind_i * h_step
        for ind_j in range(seg_w):
            j = ind_j * w_step
            images.append(img_padded[i:i + tile_h, j:j + tile_w, :])
            positions.append(torch.tensor([float(i - padding_pixel), float(j - padding_pixel)], dtype=torch.float32))
    return torch.stack(images, dim=0), torch.stack(positions, dim=0), seg_h, seg_w


def _tiles_to_uint8(tiles):
    """Convert a float tensor with values in [0, 1] to a uint8 numpy array in [0, 255].

    Values are clipped first so out-of-range outputs saturate instead of
    wrapping around when cast to uint8.
    """
    arr = tiles.detach().cpu().numpy()
    return (np.clip(arr, 0, 1) * 255).astype(np.uint8)


class InferenceModel():
    """Tiled MatFusion inference: crop to the mask, split into tiles, diffuse, re-assemble."""

    def __init__(self, ckpt_path, use_ddim, gpu_id=0):
        self.model = init_model(ckpt_path, use_ddim, gpu_id=gpu_id)
        self.gpu_id = gpu_id
        self.split_hw = [1, 1]          # requested (rows, cols) tile grid
        self.split_overlap = 1          # tile overlap factor; overwritten by generation()
        self.padding = 0                # border images_spliter adds around every tile
        self.padding_crop = 0           # border trimmed from each tile when re-assembling
        self.results_list = None        # output tiles buffered until a full grid is available
        self.position_indexes = None    # tile positions matching results_list
        self.results_output_list = []
        self.image_sizes_list = []      # per-tile cropped-region size matching results_list

    def parse_item(self, img_ori, mask_img_ori, guid_images):
        """Crop an image to its mask's bounding box and split it into tiles.

        Sets ``self.ori_hw``, ``self.min_uv``, ``self.max_uv`` and
        ``self.split_hw_overlapped``, which the ``write_batch_*`` /
        ``assemble_results`` methods read later.

        Args:
            img_ori: Image tensor (presumably (H, W, 3) in [0, 1] — confirm with
                caller). Background pixels are whitened on a copy; the caller's
                tensor is not modified.
            mask_img_ori: Mask tensor (H, W, C). Channel 0 selects the foreground
                kept in the image; the last channel defines the crop box.
            guid_images: Optional guidance image aligned with ``img_ori``, or
                ``None`` for an all-zero guide.

        Returns:
            ``(guid_tiles, img_tiles, mask_tiles, image_size, positions)`` where
            ``mask_tiles`` keeps a single channel and ``image_size`` is the
            shape of the cropped region.
        """
        # Work on a copy so the background whitening does not leak into the caller's tensor.
        img_ori = img_ori.clone()
        # Ensure background is white, same as training data.
        img_ori[~(mask_img_ori[..., 0] > 0.5)] = 1
        use_true_mask = (self.split_hw[0] * self.split_hw[1]) <= 1
        self.ori_hw = list(img_ori.shape)

        # masks_to_boxes gives (x_min, y_min, x_max, y_max) with inclusive max;
        # reorder to (row, col) and make the max corner exclusive.
        min_max_uv = masks_to_boxes(mask_img_ori[None, ..., -1] > 0.5).long()
        self.min_uv = min_max_uv[0, ..., [1, 0]]
        self.max_uv = min_max_uv[0, ..., [3, 2]] + 1
        if not use_true_mask:
            # Shrink the box so its size divides evenly into split_hw * split_overlap tiles.
            self.max_uv[0] = self.max_uv[0] - ((self.max_uv[0] - self.min_uv[0]) % (self.split_hw[0] * self.split_overlap))
            self.max_uv[1] = self.max_uv[1] - ((self.max_uv[1] - self.min_uv[1]) % (self.split_hw[1] * self.split_overlap))

        rows = slice(self.min_uv[0], self.max_uv[0])
        cols = slice(self.min_uv[1], self.max_uv[1])
        mask_img = mask_img_ori[rows, cols]
        img = img_ori[rows, cols]
        image_size = list(img.shape)

        if not use_true_mask:
            # When tiling, every tile is treated as fully inside the mask.
            mask_img = torch.ones_like(mask_img)
        mask_img, _ = images_spliter(mask_img[..., [0, 0, 0]], self.split_hw[0], self.split_hw[1],
                                     self.padding, not use_true_mask, self.split_overlap)[:2]
        img, position_indexes, seg_h, seg_w = images_spliter(img, self.split_hw[0], self.split_hw[1],
                                                             self.padding, 1, self.split_overlap)
        self.split_hw_overlapped = [seg_h, seg_w]
        logger.info(f"Spliting Size: {image_size}, splits: {self.split_hw}, Overlapped: {self.split_hw_overlapped}")

        if guid_images is None:
            guid_images = torch.zeros_like(img)
        else:
            guid_images, _ = images_spliter(guid_images[rows, cols], self.split_hw[0], self.split_hw[1],
                                            self.padding, 1, self.split_overlap)[:2]
        return guid_images, img, mask_img[..., :1], image_size, position_indexes

    def prepare_batch(self, guid_img, img_ori, mask_img_ori, batch_size):
        """Tile the inputs ``batch_size`` times and move the stacked tiles to ``self.gpu_id``.

        Returns:
            ``(input_img, cond_img, mask_img, image_size, position_indexes)`` where
            ``input_img`` holds the guidance tiles, ``cond_img`` the image tiles,
            and ``image_size`` is a per-tile list of cropped-region shapes.
        """
        input_img, cond_img, mask_img, position_indexes = [], [], [], []
        image_size = []
        for _ in range(batch_size):
            _input_img, _cond_img, _mask_img, _image_size, _position_indexes = \
                self.parse_item(img_ori, mask_img_ori, guid_img)
            input_img.append(_input_img)
            cond_img.append(_cond_img)
            mask_img.append(_mask_img)
            position_indexes.append(_position_indexes)
            image_size += [_image_size] * _input_img.shape[0]
        return (
            torch.cat(input_img, dim=0).to(self.gpu_id),
            torch.cat(cond_img, dim=0).to(self.gpu_id),
            torch.cat(mask_img, dim=0).to(self.gpu_id),
            image_size,
            torch.cat(position_indexes, dim=0).to(self.gpu_id),
        )

    def assemble_results(self, img_out, img_hw=None, position_index=None, default_val=1):
        """Blend tiles back into a full-size image, averaging overlapping pixels.

        Args:
            img_out: Array of tiles (N, th, tw, C) with values in 0..255; C may be
                3 or 1 (broadcast to 3).
            img_hw: Shape of the cropped region the tiles cover (height, width, ...).
            position_index: (N, 2) integer array of tile top-left (row, col)
                positions as produced by ``images_spliter``. Not modified.
            default_val: Fill value (scaled by 255) outside the mask bounding box.

        Returns:
            float64 array of shape (ori_h, ori_w, 3) in the 0..255 range.
            Pixels inside the box that no tile reached are 255.
        """
        height, width = img_hw[0], img_hw[1]
        results_img = np.zeros((height, width, 3))
        weight_img = np.zeros((height, width, 3)) + 1e-5
        pc = self.padding_crop
        for i in range(position_index.shape[0]):
            tile_h, tile_w = img_out[i].shape[:2]
            # Trim the padding border and shift the tile position to match.
            patch = img_out[i][pc:tile_h - pc, pc:tile_w - pc]
            top = position_index[i][0] + pc
            left = position_index[i][1] + pc
            patch_h, patch_w = patch.shape[:2]
            # Clip the patch footprint to the canvas: tiles may start at negative
            # coordinates (padding) or extend past the bottom/right edge. The same
            # region is used for accumulation and weighting.
            y0, y1 = max(top, 0), min(top + patch_h, height)
            x0, x1 = max(left, 0), min(left + patch_w, width)
            if y1 <= y0 or x1 <= x0:
                continue
            results_img[y0:y1, x0:x1] += patch[y0 - top:y1 - top, x0 - left:x1 - left]
            weight_img[y0:y1, x0:x1] += 1

        img_out = results_img / weight_img
        img_out[weight_img[:, :, 0] < 1] = 255
        canvas = np.full((self.ori_hw[0], self.ori_hw[1], 3), default_val * 255.0)
        canvas[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]] = img_out
        return canvas

    def write_batch_img(self, imgs, image_sizes, position_indexes):
        """Buffer output tiles and assemble every complete tile grid into an image.

        Tiles accumulate across calls until ``split_hw_overlapped[0] *
        split_hw_overlapped[1]`` of them are available; leftovers stay buffered.

        Returns:
            List of uint8 arrays of shape (ori_h, ori_w, 3), one per completed grid.
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        if self.results_list is None or self.results_list.shape[0] == 0:
            self.results_list = imgs
            self.position_indexes = position_indexes
        else:
            self.results_list = torch.cat([self.results_list, imgs], dim=0)
            self.position_indexes = torch.cat([self.position_indexes, position_indexes], dim=0)
        self.image_sizes_list += image_sizes

        n_buffered = self.results_list.shape[0]
        valid_len = n_buffered - (n_buffered % tiles_per_image)
        out_images = []
        for start in range(0, valid_len, tiles_per_image):
            stop = start + tiles_per_image
            img_out = self.assemble_results(
                _tiles_to_uint8(self.results_list[start:stop]),
                self.image_sizes_list[start],
                self.position_indexes[start:stop].detach().cpu().numpy().astype(int),
            )
            out_images.append(img_out.astype(np.uint8))
        self.results_list = self.results_list[valid_len:]
        self.position_indexes = self.position_indexes[valid_len:]
        self.image_sizes_list = self.image_sizes_list[valid_len:]
        return out_images

    def write_batch_input(self, imgs, image_sizes, position_indexes, default_val=1):
        """Re-assemble input tiles (image or mask) into full-size uint8 images.

        ``imgs`` is consumed in groups of one tile grid; each group uses its own
        slice of ``position_indexes``.
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        positions = position_indexes.detach().cpu().numpy().astype(int)
        images = []
        for start in range(0, imgs.shape[0], tiles_per_image):
            stop = start + tiles_per_image
            img_out = self.assemble_results(
                _tiles_to_uint8(imgs[start:stop]),
                image_sizes[start],
                positions[start:stop],
                default_val,
            )
            images.append(img_out.astype(np.uint8))
        return images

    def generation(self, split_hw, split_overlap, guid_img, img_ori, mask_img_ori, dps_scale, uc_score,
                   ddim_steps, batch_size=1, n_samples=1):
        """Run tiled diffusion on one image and return re-assembled inputs and outputs.

        Args:
            split_hw: (rows, cols) tile grid.
            split_overlap: Tile overlap factor (1 = no overlap).
            guid_img: Optional guidance image, or ``None``.
            img_ori: Input image tensor (not modified).
            mask_img_ori: Mask tensor selecting the region to process.
            dps_scale, uc_score, ddim_steps: Forwarded to the diffusion model
                (``uc_score`` as ``guidance_scale``).
            batch_size: Must be 1; kept for interface compatibility.
            n_samples: Number of independent samples to draw.

        Returns:
            Dict with ``"input_image"`` and ``"input_maskes"`` (lists of uint8
            arrays) and ``"out_images"`` (list of uint8 arrays, one per sample).

        Raises:
            ValueError: If ``batch_size`` is not 1.
        """
        if batch_size != 1:
            raise ValueError(f"only batch_size=1 is supported, got {batch_size}")
        max_batch = __MAX_BATCH__
        blur_operator = GaussialBlurOperator(61, 3.0, self.gpu_id)
        self.split_resolution = None
        self.split_overlap = split_overlap
        self.split_hw = split_hw

        input_img, cond_img, mask_img, image_sizes, position_indexes = \
            self.prepare_batch(guid_img, img_ori, mask_img_ori, 1)
        input_masked = self.write_batch_input(cond_img, image_sizes, position_indexes)
        input_maskes = self.write_batch_input(mask_img, image_sizes, position_indexes, 0)

        results_all = []
        for _ in range(n_samples):
            for batch_id in range(0, input_img.shape[0], max_batch):
                window = slice(batch_id, batch_id + max_batch)
                embeddings = {"cond_img": cond_img[window]}
                if (mask_img[window] > 0.5).sum() == 0:
                    # No masked pixel in these tiles: skip the model and emit white.
                    results = torch.ones_like(cond_img[window])
                else:
                    results = self.model(embeddings, input_img[window], mask_img[window],
                                         ddim_steps=ddim_steps, guidance_scale=uc_score, dps_scale=dps_scale,
                                         as_latent=False, grad_scale=1, operator=blur_operator)
                results_all += self.write_batch_img(results, image_sizes[window], position_indexes[window])

        return {
            "input_image": input_masked,
            "input_maskes": input_maskes,
            "out_images": results_all,
        }