# Bi-Noising.Diffusion / diffusion.py
# Uploaded by MKFMIKU ("Upload 10 files"), commit 6eb1be9.
import numpy as np
import PIL
from PIL import Image
import torch
from diffusion_arch import ILVRUNetModel, ConditionalUNetModel
from guided_diffusion.script_util import create_gaussian_diffusion
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
def preprocess_image(image):
    """Convert a PIL image into a float tensor scaled to [-1, 1].

    The image is converted to RGB first, so grayscale, palette and RGBA inputs
    all produce exactly three channels. Each side is then resized down to the
    largest multiple of 32 that does not exceed it, using Lanczos resampling.

    Args:
        image: a PIL image of any mode.

    Returns:
        A float32 tensor of shape (1, 3, H, W) with values in [-1, 1], where
        H and W are the multiple-of-32 height and width.
    """
    # Without this, "L"/"P" images become 2-D arrays and the CHW transpose
    # below raises, while RGBA would leak a fourth channel to the model.
    image = image.convert("RGB")
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0  # HWC in [0, 1]
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0)  # -> 1CHW
    return 2.0 * image - 1.0
def preprocess_mask(mask):
    """Turn a PIL mask into a binary float tensor of shape (1, 3, H, W).

    The mask is reduced to one luminance ("L") channel. Each side is resized
    down to the largest multiple of 32 not exceeding it, using
    nearest-neighbour resampling so no new grey levels are introduced. The
    result is binarized: any non-zero pixel becomes 1.0 and zero pixels stay
    0.0. The single channel is replicated three times.
    """
    gray = mask.convert("L")
    width, height = (side - side % 32 for side in gray.size)
    gray = gray.resize((width, height), resample=PIL.Image.NEAREST)

    values = np.array(gray).astype(np.float32) / 255.0
    three_channels = np.repeat(values[None, ...], 3, axis=0)
    result = torch.from_numpy(three_channels).unsqueeze(0)

    # Binarize: everything that is not exactly zero counts as masked.
    result[result > 0] = 1
    return result
class DiffusionPipeline:
    """Bi-noising diffusion restoration pipeline.

    Holds two UNets on ``device``, both in eval mode:

    * ``diffusion_model``: an ``ILVRUNetModel`` loaded from ``diffusion_ckpt``.
      It is sampled without any conditioning kwargs.
    * ``diffusion_restoration_model``: a ``ConditionalUNetModel`` loaded from
      ``restoration_ckpt``. It is conditioned on the low-quality image through
      the ``lq`` keyword argument.

    Calling the pipeline returns a generator that runs both reverse-diffusion
    chains side by side and yields preview images (see ``__call__``).
    """

    @staticmethod
    def _unet_kwargs():
        """Return a fresh dict of the UNet hyper-parameters shared by both models."""
        # Built anew on every call so the two models never share the mutable list.
        return dict(
            in_channels=3,
            model_channels=128,
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False,
        )

    def __init__(self, device, diffusion_ckpt='./ffhq_10m.pt',
                 restoration_ckpt='./net_g_250000.pth'):
        """Build both UNets, move them to ``device`` in eval mode, and load their weights.

        Args:
            device: device passed to ``Module.to`` and used for every tensor
                created during sampling.
            diffusion_ckpt: path to the state dict for ``ILVRUNetModel``.
            restoration_ckpt: path to a checkpoint whose ``'params'`` entry is
                the state dict for ``ConditionalUNetModel``.
        """
        super().__init__()
        self.device = device

        # NOTE: torch.load unpickles arbitrary objects, so only point these
        # paths at trusted checkpoint files.
        diffusion_model = ILVRUNetModel(**self._unet_kwargs())
        diffusion_model = diffusion_model.to(device).eval()
        diffusion_model.load_state_dict(torch.load(diffusion_ckpt, map_location='cpu'))
        self.diffusion_model = diffusion_model

        restoration_model = ConditionalUNetModel(dropout=0.0, **self._unet_kwargs())
        restoration_model = restoration_model.to(device).eval()
        state_dict = torch.load(restoration_ckpt, map_location='cpu')
        restoration_model.load_state_dict(state_dict['params'])
        self.diffusion_restoration_model = restoration_model

    @staticmethod
    def _to_pil(batch):
        """Tile ``batch`` with ``make_grid``, map [-1, 1] to [0, 255], and return a PIL image."""
        grid = (make_grid(batch) / 2 + 0.5).clamp(0, 1)
        return Image.fromarray(np.uint8(grid.cpu().numpy().transpose(1, 2, 0) * 255.))

    def _sample(self, out, nonzero_mask, like):
        """Draw x_{t-1} from the Gaussian described by ``out["mean"]`` and ``out["log_variance"]``.

        ``nonzero_mask`` is zero where t == 0, so no noise is added on the final step.
        """
        noise = torch.randn_like(like, device=self.device)
        return out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise

    @torch.no_grad()
    def __call__(self, lq, diffusion_step, binoising_step, grid_size):
        """Run the restoration and bi-noising chains and stream preview images.

        Args:
            lq: low-quality PIL input. It is converted to RGB and resized to
                256x256 (presumably the checkpoints' training resolution; TODO
                confirm).
            diffusion_step: number of respaced sampling steps (passed through ``int``).
            binoising_step: bi-noising runs on every step whose respaced index
                ``i`` satisfies ``i < binoising_step``. Because indices count
                down to 0, these are the last ``binoising_step`` steps.
            grid_size: number of samples drawn in parallel. Each yielded image
                is a ``make_grid`` tiling of the whole batch.

        Yields:
            ``[restoration_image, binoised_image]`` lists of PIL images. The
            predicted x_0 of both chains is yielded on every step with an even
            index. The final samples of both chains are yielded after the loop.
        """
        lq = lq.convert("RGB").resize((256, 256), resample=Image.LANCZOS)
        diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing=str(int(diffusion_step)),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
        )

        # Conditioning image, repeated so every sample in the grid sees it.
        lq_img_th = preprocess_image(lq).to(self.device).repeat([grid_size, 1, 1, 1])
        batch_size, channels = lq_img_th.shape[:2]

        # Independent starting noise for each chain. The draw order is kept
        # fixed so results stay reproducible under a fixed seed.
        img = torch.randn_like(lq_img_th, device=self.device)    # bi-noising chain
        s_img = torch.randn_like(lq_img_th, device=self.device)  # restoration chain

        # Loop-invariant: maps respaced timesteps before calling the raw model.
        restoration_eps_model = diffusion._wrap_model(self.diffusion_restoration_model)

        for i in reversed(range(diffusion.num_timesteps)):
            t = torch.tensor([i] * batch_size, device=self.device)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (lq_img_th.dim() - 1)))

            # Chain 1: a conditional reverse step with the restoration model.
            out = diffusion.p_mean_variance(
                self.diffusion_restoration_model, s_img, t, model_kwargs={'lq': lq_img_th})
            s_img = self._sample(out, nonzero_mask, img)
            s_img_pred = out["pred_xstart"]

            # Chain 2, bi-noising: treat the restoration model's output on the
            # current sample as eps, and turn it into a clamped x_0 estimate.
            # Re-noise that estimate with q_sample at step t, then take the
            # ordinary unconditional reverse step.
            if i < binoising_step:
                eps, _ = torch.split(
                    restoration_eps_model(img, t, lq=lq_img_th), channels, dim=1)
                pred_xstart = diffusion._predict_xstart_from_eps(img, t, eps).clamp(-1, 1)
                img = diffusion.q_sample(pred_xstart, t)

            out = diffusion.p_mean_variance(self.diffusion_model, img, t)
            img = self._sample(out, nonzero_mask, img)
            img_pred = out["pred_xstart"]

            if i % 2 == 0:
                yield [self._to_pil(s_img_pred), self._to_pil(img_pred)]

        yield [self._to_pil(s_img), self._to_pil(img)]