Spaces:
Running
on
A10G
Running
on
A10G
import os | |
from functools import partial | |
from glob import glob | |
from pathlib import Path as PythonPath | |
import cv2 | |
import torchvision.transforms.functional as TvF | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from inspect import isfunction | |
from PIL import Image | |
from lib import smplfusion | |
from lib.smplfusion import share, router, attentionpatch, transformerpatch | |
from lib.utils.iimage import IImage | |
from lib.utils import poisson_blend | |
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v | |
def refine_mask(hr_image, hr_mask, lr_image, sam_predictor): | |
lr_mask = hr_mask.resize(512) | |
x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0]) | |
x_min = max(x_min - 1, 0) | |
y_min = max(y_min - 1, 0) | |
x_max = x_min + rect_w + 1 | |
y_max = y_min + rect_h + 1 | |
input_box = np.array([x_min, y_min, x_max, y_max]) | |
sam_predictor.set_image(hr_image.resize(512).data[0]) | |
masks, _, _ = sam_predictor.predict( | |
point_coords=None, | |
point_labels=None, | |
box=input_box[None, :], | |
multimask_output=True, | |
) | |
dilation_kernel = np.ones((13, 13)) | |
original_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8) | |
original_object_mask = cv2.dilate(original_object_mask, dilation_kernel) | |
sam_predictor.set_image(lr_image.resize(512).data[0]) | |
masks, _, _ = sam_predictor.predict( | |
point_coords=None, | |
point_labels=None, | |
box=input_box[None, :], | |
multimask_output=True, | |
) | |
dilation_kernel = np.ones((3, 3)) | |
inpainted_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8) | |
inpainted_object_mask = cv2.dilate(inpainted_object_mask, dilation_kernel) | |
lr_mask_masking = ((original_object_mask + inpainted_object_mask ) > 0).astype(np.uint8) | |
new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis] | |
new_mask = IImage(new_mask).resize(2048, resample = Image.BICUBIC) | |
return new_mask | |
def run( | |
ddim, | |
sam_predictor, | |
lr_image, | |
hr_image, | |
hr_mask, | |
prompt = 'high resolution professional photo', | |
noise_level=20, | |
blend_output = True, | |
blend_trick = True, | |
dt = 50, | |
seed = 1, | |
guidance_scale = 7.5, | |
negative_prompt = '', | |
use_sam_mask = False | |
): | |
hr_image_info = hr_image.info | |
lr_image = IImage(lr_image) | |
hr_image = IImage(hr_image).resize(2048) | |
hr_mask = IImage(hr_mask).resize(2048) | |
torch.manual_seed(seed) | |
dtype = ddim.vae.encoder.conv_in.weight.dtype | |
device = ddim.vae.encoder.conv_in.weight.device | |
router.attention_forward = attentionpatch.default.forward_xformers | |
router.basic_transformer_forward = transformerpatch.default.forward | |
hr_image_orig = hr_image | |
hr_mask_orig = hr_mask | |
if use_sam_mask: | |
with torch.no_grad(): | |
hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor) | |
orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3] | |
hr_image = hr_image.padx(256, padding_mode='reflect') | |
hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19) | |
lr_image = lr_image.padx(64, padding_mode='reflect').torch() | |
lr_mask = hr_mask.resize((lr_image.shape[2:]), resample = Image.BICUBIC) | |
lr_mask = lr_mask.alpha().torch(vmin=0).to(device) | |
lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19) | |
# encode hr image | |
with torch.no_grad(): | |
hr_image = hr_image.torch().to(dtype=dtype, device=device) | |
hr_z0 = ddim.vae.encode(hr_image).mean * ddim.config.scale_factor | |
assert hr_z0.shape[2] == lr_image.shape[2] | |
assert hr_z0.shape[3] == lr_image.shape[3] | |
with torch.no_grad(): | |
context = ddim.encoder.encode([negative_prompt, prompt]) | |
noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long() | |
unet_condition = lr_image.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) | |
unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level) | |
with torch.autocast('cuda'), torch.no_grad(): | |
zt = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])) | |
zt = zt.cuda().to(dtype=dtype, device=device) | |
for index,t in enumerate(range(999, 0, -dt)): | |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) | |
eps_uncond, eps = ddim.unet( | |
torch.cat([_zt, _zt]).to(dtype=dtype, device=device), | |
timesteps = torch.tensor([t, t]).to(device=device), | |
context = context, | |
y=torch.cat([noise_level]*2) | |
).chunk(2) | |
ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long) | |
model_output = (eps_uncond + guidance_scale * (eps - eps_uncond)) | |
eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype) | |
z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype) | |
if blend_trick: | |
z0 = z0 * lr_mask + hr_z0 * (1-lr_mask) | |
zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps | |
with torch.no_grad(): | |
hr_result = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor) | |
# postprocess | |
hr_result = (255 * ((hr_result + 1) / 2).clip(0, 1)).to(torch.uint8) | |
hr_result = hr_result.cpu().permute(0, 2, 3, 1)[0].numpy() | |
hr_result = hr_result[:orig_h, :orig_w, :] | |
if blend_output: | |
hr_result = poisson_blend( | |
orig_img=hr_image_orig.data[0], | |
fake_img=hr_result, | |
mask=hr_mask_orig.alpha().data[0] | |
) | |
hr_result = Image.fromarray(hr_result) | |
hr_result.info = hr_image_info # save metadata | |
return hr_result | |