# IntrinsicAnything / inference.py
# Last commit d72c37e by burningdust: "Initial commit" (11.6 kB).
import os
import imageio
import numpy as np
import glob
import sys
from typing import Any
sys.path.insert(1, '.')
import argparse
from pytorch_lightning import seed_everything
from PIL import Image
import torch
from operators import GaussialBlurOperator
from utils import get_rank
from torchvision.ops import masks_to_boxes
from matfusion import MateralDiffusion
from loguru import logger
# Upper bound on the number of tiles passed to the diffusion model in a single call
# (used by InferenceModel.generation to chunk the tile batch).
__MAX_BATCH__ = 4  # 4 for A10


def _find_project_config(ckpt_path):
    """Return the name of the project config file inside ``<ckpt_path>/configs``.

    A config qualifies when its file name contains ``project.yaml``. ``os.listdir``
    returns entries in arbitrary order, so candidates are sorted to make the choice
    deterministic when more than one file matches.

    Raises:
        FileNotFoundError: if no qualifying file exists (``os.listdir`` raises the
            same exception type when the ``configs`` directory itself is missing).
    """
    candidates = sorted(
        name for name in os.listdir(f'{ckpt_path}/configs') if "project.yaml" in name
    )
    if not candidates:
        raise FileNotFoundError(f"no '*project.yaml' config found in {ckpt_path}/configs")
    return candidates[0]


def init_model(ckpt_path, ddim, gpu_id):
    """Construct the ``MateralDiffusion`` model stored under ``ckpt_path``.

    Args:
        ckpt_path: directory containing a ``configs/`` folder with a ``*project.yaml``
            file and ``checkpoints/last.ckpt``.
        ddim: forwarded to the model as ``use_ddim``.
        gpu_id: forwarded to the model as ``device``.

    Returns:
        The ``MateralDiffusion`` instance, built with ``fp16=True``, ``vram_O=False``
        and ``t_range=[0.001, 0.02]``.

    Raises:
        FileNotFoundError: if no project config can be found.
    """
    model_config = _find_project_config(ckpt_path)
    return MateralDiffusion(device=gpu_id, fp16=True,
                            config=f'{ckpt_path}/configs/{model_config}',
                            ckpt=f'{ckpt_path}/checkpoints/last.ckpt', vram_O=False,
                            t_range=[0.001, 0.02], opt=None, use_ddim=ddim)
def images_spliter(image, seg_h, seg_w, padding_pixel, padding_val, overlaps=1):
    """Split an image tensor into a grid of (optionally overlapping) tiles.

    The image is first cropped at the bottom/right so that its height and width are
    divisible by ``seg_h * overlaps`` and ``seg_w * overlaps``. The cropped image is then
    padded by ``padding_pixel`` on every side with ``padding_val``. Each tile covers
    ``h_crop = H // seg_h`` by ``w_crop = W // seg_w`` pixels of the cropped image plus the
    padding border. Consecutive tiles start ``h_crop / overlaps`` (resp. ``w_crop / overlaps``)
    pixels apart, so ``overlaps=1`` yields a plain non-overlapping grid.

    Args:
        image: tensor indexed as ``(height, width, channels)``.
        seg_h: number of non-overlapping tiles along the height.
        seg_w: number of non-overlapping tiles along the width.
        padding_pixel: border width added around the image before tiling.
        padding_val: value written into the padded border.
        overlaps: overlap factor between neighbouring tiles (integer).

    Returns:
        tiles: ``(N, h_crop + 2*padding_pixel, w_crop + 2*padding_pixel, channels)`` tensor
            (torch default float dtype) on ``image.device``.
        positions: ``(N, 2)`` float32 CPU tensor holding each tile's top-left (row, col) in
            cropped-image coordinates; negative when the tile starts inside the padding.
        seg_h, seg_w: number of tiles actually produced along height / width.

    Raises:
        ValueError: if the image is smaller than the requested grid.
    """
    h, w, c = image.shape
    # Drop trailing rows/cols so both the grid and its overlap step divide evenly.
    h = h - (h % (seg_h * overlaps))
    w = w - (w % (seg_w * overlaps))
    h_crop = h // seg_h
    w_crop = w // seg_w
    if h_crop == 0 or w_crop == 0:
        raise ValueError(
            f"image of size {tuple(image.shape[:2])} is too small for a "
            f"{seg_h}x{seg_w} split with overlaps={overlaps}"
        )
    images = []
    positions = []
    # Allocate with the input's own channel count instead of assuming 3 channels.
    img_padded = torch.zeros(h + padding_pixel * 2, w + padding_pixel * 2, c, device=image.device) + padding_val
    img_padded[padding_pixel:h + padding_pixel, padding_pixel:w + padding_pixel, :] = image[:h, :w]
    # Overlapped sampling: for integer overlaps the tiles advance by h_crop / overlaps
    # (resp. w_crop / overlaps), so the last tile ends exactly at the cropped border.
    seg_h = np.round((h - h_crop) / h_crop * overlaps).astype(int) + 1
    seg_w = np.round((w - w_crop) / w_crop * overlaps).astype(int) + 1
    h_step = np.round(h_crop / overlaps).astype(int)
    w_step = np.round(w_crop / overlaps).astype(int)
    tile_h = h_crop + padding_pixel * 2
    tile_w = w_crop + padding_pixel * 2
    for ind_i in range(seg_h):
        i = ind_i * h_step
        for ind_j in range(seg_w):
            j = ind_j * w_step
            images.append(img_padded[i:i + tile_h, j:j + tile_w, :])
            # Convert the padded-buffer corner back to cropped-image coordinates.
            positions.append(torch.tensor([i - padding_pixel, j - padding_pixel], dtype=torch.float32))
    return torch.stack(images, dim=0), torch.stack(positions, dim=0), seg_h, seg_w
class InferenceModel:
    """Tile-based inference wrapper around the ``MateralDiffusion`` model.

    The input image is cropped to the bounding box of its mask and optionally split into a
    grid of (overlapping) tiles. The tiles are run through the model in chunks of at most
    ``__MAX_BATCH__``, and the per-tile outputs are stitched back into an image the size of
    the original input.
    """

    def __init__(self, ckpt_path, use_ddim, gpu_id=0):
        self.model = init_model(ckpt_path, use_ddim, gpu_id=gpu_id)
        self.gpu_id = gpu_id
        self.split_hw = [1, 1]        # tile grid as [rows, cols]
        self.split_overlap = 1        # overlap factor forwarded to images_spliter
        self.padding = 0              # border added around each tile when splitting
        self.padding_crop = 0         # border trimmed from each tile when stitching
        self.results_list = None      # model output tiles not yet stitched into an image
        self.position_indexes = None  # tile positions matching results_list
        self.results_output_list = []
        self.image_sizes_list = []    # per-tile cropped-image sizes matching results_list

    def parse_item(self, img_ori, mask_img_ori, guid_images):
        """Crop one image to its mask's bounding box and split it into tiles.

        Args:
            img_ori: ``(H, W, C)`` image tensor. On a copy (the caller's tensor is left
                untouched), pixels whose mask channel 0 is <= 0.5 are set to 1 (white).
            mask_img_ori: ``(H, W, C)`` mask tensor; its last channel defines the box.
            guid_images: optional guidance image with the same spatial size as
                ``img_ori``, or None (zeros are used then).

        Sets ``self.ori_hw``, ``self.min_uv``, ``self.max_uv`` (box corners as (row, col),
        max exclusive) and ``self.split_hw_overlapped``; ``assemble_results`` reads them.

        Returns:
            ``(guide_tiles, image_tiles, mask_tiles, image_size, position_indexes)``.
            ``mask_tiles`` keeps one channel and ``image_size`` is the cropped image's shape.

        Raises:
            ValueError: if the mask has no pixel above 0.5.
        """
        # Copy first so whitening the background does not leak into the caller's tensor.
        img_ori = img_ori.clone()
        # Ensure the background is white, same as the training data.
        img_ori[~(mask_img_ori[..., 0] > 0.5)] = 1
        use_true_mask = (self.split_hw[0] * self.split_hw[1]) <= 1
        self.ori_hw = list(img_ori.shape)

        foreground = mask_img_ori[None, ..., -1] > 0.5
        if not foreground.any():
            raise ValueError("mask_img_ori has no foreground pixels (> 0.5); cannot crop")
        # masks_to_boxes returns (x1, y1, x2, y2); reorder to (row, col), max exclusive.
        min_max_uv = masks_to_boxes(foreground).long()
        self.min_uv = min_max_uv[0, ..., [1, 0]]
        self.max_uv = min_max_uv[0, ..., [3, 2]] + 1
        if not use_true_mask:
            # Shrink the box so its size is divisible by the grid times the overlap factor.
            for axis in range(2):
                divisor = self.split_hw[axis] * self.split_overlap
                extent = self.max_uv[axis] - self.min_uv[axis]
                self.max_uv[axis] = self.max_uv[axis] - (extent % divisor)

        rows = slice(self.min_uv[0], self.max_uv[0])
        cols = slice(self.min_uv[1], self.max_uv[1])
        mask_img = mask_img_ori[rows, cols]
        img = img_ori[rows, cols]
        image_size = list(img.shape)
        if not use_true_mask:
            # When tiling, the mask handed on is all ones (and its padding is ones too).
            mask_img = torch.ones_like(mask_img)
        mask_img, _ = images_spliter(mask_img[..., [0, 0, 0]], self.split_hw[0], self.split_hw[1],
                                     self.padding, not use_true_mask, self.split_overlap)[:2]
        img, position_indexes, seg_h, seg_w = images_spliter(img, self.split_hw[0], self.split_hw[1],
                                                             self.padding, 1, self.split_overlap)
        self.split_hw_overlapped = [seg_h, seg_w]
        logger.info(f"Spliting Size: {image_size}, splits: {self.split_hw}, Overlapped: {self.split_hw_overlapped}")
        if guid_images is None:
            guid_images = torch.zeros_like(img)
        else:
            guid_images, _ = images_spliter(guid_images[rows, cols], self.split_hw[0], self.split_hw[1],
                                            self.padding, 1, self.split_overlap)[:2]
        return guid_images, img, mask_img[..., :1], image_size, position_indexes

    def prepare_batch(self, guid_img, img_ori, mask_img_ori, batch_size):
        """Call ``parse_item`` ``batch_size`` times and concatenate the tiles on ``self.gpu_id``.

        Returns:
            ``(input_img, cond_img, mask_img, image_size, position_indexes)``.
            ``input_img`` holds the guidance tiles and ``cond_img`` the image tiles.
            ``image_size`` lists one cropped-image shape per tile.
        """
        input_img = []
        cond_img = []
        mask_img = []
        image_size = []
        position_indexes = []
        for _ in range(batch_size):
            _input_img, _cond_img, _mask_img, _image_size, _position_indexes = \
                self.parse_item(img_ori, mask_img_ori, guid_img)
            input_img.append(_input_img)
            cond_img.append(_cond_img)
            mask_img.append(_mask_img)
            position_indexes.append(_position_indexes)
            image_size += [_image_size] * _input_img.shape[0]
        input_img = torch.cat(input_img, dim=0).to(self.gpu_id)
        cond_img = torch.cat(cond_img, dim=0).to(self.gpu_id)
        mask_img = torch.cat(mask_img, dim=0).to(self.gpu_id)
        position_indexes = torch.cat(position_indexes, dim=0).to(self.gpu_id)
        return input_img, cond_img, mask_img, image_size, position_indexes

    @staticmethod
    def _to_uint8(tiles):
        """Convert a float tensor in [0, 1] to a uint8 numpy array, clipping instead of wrapping."""
        return np.clip(tiles.detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)

    def assemble_results(self, img_out, img_hw=None, position_index=None, default_val=1):
        """Stitch tiles into one image and paste it into a frame of size ``self.ori_hw``.

        Args:
            img_out: sequence of ``(h, w, C)`` tiles on a 0-255 scale. Single-channel tiles
                are broadcast to 3 channels.
            img_hw: size ``(H, W, ...)`` of the cropped region the tiles cover.
            position_index: ``(N, 2)`` integer array of tile (row, col) offsets inside that
                region. Offsets may be negative when a tile starts inside the split padding.
            default_val: fill value in [0, 1] (scaled by 255). It is used for pixels no tile
                covers and for everything outside the crop box.

        Returns:
            float64 array of shape ``(self.ori_hw[0], self.ori_hw[1], 3)`` on a 0-255 scale.
        """
        pad = self.padding_crop
        fill = default_val * 255
        results_img = np.zeros((img_hw[0], img_hw[1], 3))
        weight_img = np.zeros((img_hw[0], img_hw[1], 3))
        for i, pos in enumerate(position_index):
            tile = img_out[i]
            tile_h, tile_w = tile.shape[:2]
            # Trim the border added around each tile before stitching.
            tile = tile[pad:tile_h - pad, pad:tile_w - pad]
            top, left = int(pos[0]) + pad, int(pos[1]) + pad
            # Destination corner clamped to the canvas. Tile rows/cols above or left of the
            # canvas are skipped.
            dst_top, dst_left = max(top, 0), max(left, 0)
            skip_top, skip_left = max(-top, 0), max(-left, 0)
            avail_h, avail_w = results_img[dst_top:dst_top + tile.shape[0],
                                           dst_left:dst_left + tile.shape[1]].shape[:2]
            n_rows, n_cols = max(avail_h - skip_top, 0), max(avail_w - skip_left, 0)
            region = (slice(dst_top, dst_top + n_rows), slice(dst_left, dst_left + n_cols))
            results_img[region] += tile[skip_top:skip_top + n_rows, skip_left:skip_left + n_cols]
            # Count contributions over exactly the region that received pixels.
            weight_img[region] += 1
        # Average overlaps by dividing only where a tile contributed. Adding a small epsilon
        # to every weight instead would pull each value slightly below its true level, and
        # the later uint8 cast would truncate it to value - 1.
        canvas = np.full_like(results_img, fill)
        np.divide(results_img, weight_img, out=canvas, where=weight_img > 0)
        full_img = np.full((self.ori_hw[0], self.ori_hw[1], 3), fill, dtype=np.float64)
        full_img[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]] = canvas
        return full_img

    def write_batch_img(self, imgs, image_sizes, position_indexes):
        """Buffer model output tiles and stitch every complete image they now form.

        Tiles are appended to ``self.results_list`` (with matching positions and sizes).
        Each group of ``split_hw_overlapped[0] * split_hw_overlapped[1]`` buffered tiles is
        assembled into one uint8 image. Incomplete trailing groups stay buffered for the
        next call.

        Returns:
            list of uint8 images, one per completed group.
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        if self.results_list is None or self.results_list.shape[0] == 0:
            self.results_list = imgs
            self.position_indexes = position_indexes
        else:
            self.results_list = torch.cat([self.results_list, imgs], dim=0)
            self.position_indexes = torch.cat([self.position_indexes, position_indexes], dim=0)
        self.image_sizes_list += image_sizes
        n_ready = self.results_list.shape[0] - (self.results_list.shape[0] % tiles_per_image)
        out_images = []
        for start in range(0, n_ready, tiles_per_image):
            stop = start + tiles_per_image
            tiles = self._to_uint8(self.results_list[start:stop])
            positions = self.position_indexes[start:stop].detach().cpu().numpy().astype(int)
            img_out = self.assemble_results(tiles, self.image_sizes_list[start], positions)
            out_images.append(img_out.astype(np.uint8))
        self.results_list = self.results_list[n_ready:]
        self.position_indexes = self.position_indexes[n_ready:]
        self.image_sizes_list = self.image_sizes_list[n_ready:]
        return out_images

    def write_batch_input(self, imgs, image_sizes, position_indexes, default_val=1):
        """Stitch input tiles (image or mask) back into full-size uint8 images.

        Each consecutive group of ``split_hw_overlapped[0] * split_hw_overlapped[1]`` tiles
        is assembled with its own slice of ``position_indexes``.
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        positions = position_indexes.detach().cpu().numpy().astype(int)
        images = []
        for start in range(0, imgs.shape[0], tiles_per_image):
            stop = start + tiles_per_image
            img_out = self.assemble_results(self._to_uint8(imgs[start:stop]), image_sizes[start],
                                            positions[start:stop], default_val)
            images.append(img_out.astype(np.uint8))
        return images

    def generation(self, split_hw, split_overlap, guid_img, img_ori, mask_img_ori, dps_scale, uc_score, ddim_steps, batch_size=1, n_samples=1):
        """Run tiled diffusion inference on a single image.

        Args:
            split_hw: tile grid as ``[rows, cols]``. A grid of one tile keeps the real mask;
                larger grids pass an all-ones mask.
            split_overlap: overlap factor between neighbouring tiles.
            guid_img: optional guidance image (None means zeros).
            img_ori: input image tensor.
            mask_img_ori: mask tensor for ``img_ori``.
            dps_scale: forwarded to the model as ``dps_scale``.
            uc_score: forwarded to the model as ``guidance_scale``.
            ddim_steps: forwarded to the model as ``ddim_steps``.
            batch_size: must be 1; only single-image inference is supported.
            n_samples: number of times the model is run over all tiles.

        Returns:
            dict with ``"input_image"`` (stitched masked input), ``"input_maskes"``
            (stitched mask) and ``"out_images"`` (stitched uint8 model outputs).

        Raises:
            ValueError: if ``batch_size`` is not 1.
        """
        if batch_size != 1:
            raise ValueError(f"only batch_size=1 is supported, got {batch_size}")
        max_batch = __MAX_BATCH__
        operator = GaussialBlurOperator(61, 3.0, self.gpu_id)
        self.split_resolution = None
        self.split_overlap = split_overlap
        self.split_hw = split_hw
        # Drop tiles buffered by an earlier call that did not finish.
        self.results_list = None
        self.position_indexes = None
        self.image_sizes_list = []

        input_img, cond_img, mask_img, image_sizes, position_indexes = \
            self.prepare_batch(guid_img, img_ori, mask_img_ori, 1)
        input_masked = self.write_batch_input(cond_img, image_sizes, position_indexes)
        input_maskes = self.write_batch_input(mask_img, image_sizes, position_indexes, 0)
        results_all = []
        for _ in range(n_samples):
            for start in range(0, input_img.shape[0], max_batch):
                stop = start + max_batch
                chunk_cond = cond_img[start:stop]
                chunk_mask = mask_img[start:stop]
                if not (chunk_mask > 0.5).any():
                    # Nothing to infer in these tiles: emit white tiles.
                    results = torch.ones_like(chunk_cond)
                else:
                    embeddings = {"cond_img": chunk_cond}
                    results = self.model(embeddings, input_img[start:stop], chunk_mask, ddim_steps=ddim_steps,
                                         guidance_scale=uc_score, dps_scale=dps_scale, as_latent=False,
                                         grad_scale=1, operator=operator)
                results_all += self.write_batch_img(results, image_sizes[start:stop],
                                                    position_indexes[start:stop])
        return {
            "input_image": input_masked,
            "input_maskes": input_maskes,
            "out_images": results_all,
        }