import torch
import numpy as np
import PIL
import torchvision.transforms as T
import torch.nn.functional as F

from KandiSuperRes.model.unet import UNet
from KandiSuperRes.model.unet_sr import UNet as UNet_sr
from KandiSuperRes.movq import MoVQ
from KandiSuperRes.model.diffusion_sr import DPMSolver
from KandiSuperRes.model.diffusion_refine import BaseDiffusion, get_named_beta_schedule
from KandiSuperRes.model.diffusion_sr_turbo import BaseDiffusion as BaseDiffusion_turbo
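

# KandiSuperResPipeline wraps two inference paths:
#   * flash=True  -- few-step turbo super-resolution (BaseDiffusion_turbo.p_sample_loop),
#                    optionally followed by a MoVQ-latent refiner UNet;
#   * flash=False -- DPM-Solver sampling via generate_panorama, which processes
#                    the image in batches of views (view_batch_size at a time).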
class KandiSuperResPipeline:

    def __init__(
        self,
        scale: int,
        device: str,
        dtype: str,
        flash: bool,
        sr_model: UNet_sr,
        movq: MoVQ = None,
        refiner: UNet = None,
    ):
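        # In flash mode `device` and `dtype` are expected to be dicts keyed by
        # submodule name ('sr_model', 'movq', 'refiner'); in the non-flash path
        # they are used directly as a single device/dtype.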
        self.device = device
        self.dtype = dtype
        self.scale = scale
        self.flash = flash
        self.to_pil = T.ToPILImage()
        # Map a PIL image to a float CHW tensor scaled to [-1, 1].
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Lambda(lambda img: 2. * img - 1.),
        ])
        self.sr_model = sr_model
        self.movq = movq
        self.refiner = refiner

    def __call__(
        self,
        pil_image: PIL.Image.Image = None,
        steps: int = 5,
        view_batch_size: int = 15,
        seed: int = 0,
        refine: bool = True,
    ) -> PIL.Image.Image:
        if self.flash:
            betas_turbo = get_named_beta_schedule('linear', 1000)
            base_diffusion_sr = BaseDiffusion_turbo(betas_turbo)
            old_height = pil_image.size[1]
            old_width = pil_image.size[0]
            # Crop both sides down to the nearest multiple of 32 (the resolution
            # granularity the SR UNet expects); the output is resized back to
            # old_size * scale at the end.
            height = int(old_height - np.mod(old_height, 32))
            width = int(old_width - np.mod(old_width, 32))
            pil_image = pil_image.resize((width, height))
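            # Prepare the low-res conditioning image as a 1x3xHxW tensor in [-1, 1].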
            lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device['sr_model'])
            sr_image = base_diffusion_sr.p_sample_loop(
                self.sr_model, (1, 3, height * self.scale, width * self.scale),
                self.device['sr_model'], self.dtype['sr_model'], lowres_img=lr_image,
            )
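            # `sr_image` is an image-space tensor in [-1, 1] at `scale`x the cropped size.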
            if refine:
                betas = get_named_beta_schedule('cosine', 1000)
                base_diffusion = BaseDiffusion(betas, 0.99)
                # Encode the turbo SR output into MoVQ latent space for refinement.
                with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                    lr_image_latent = self.movq.encode(sr_image)
                # Precomputed conditioning tensors shipped with the weights
                # (presumably a fixed text-embedding context for the refiner UNet).
                context = torch.load('weights/context.pt').to(self.dtype['refiner'])
                context_mask = torch.load('weights/context_mask.pt').to(self.dtype['refiner'])
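                # Refine in latent space, then decode back to pixels with MoVQ.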
                with torch.no_grad():
                    with torch.cuda.amp.autocast(dtype=self.dtype['refiner']):
                        refiner_image = base_diffusion.refine_tiled(self.refiner, lr_image_latent, context, context_mask)
                    with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                        refiner_image = self.movq.decode(refiner_image)
                refiner_image = torch.clip((refiner_image + 1.) / 2., 0., 1.)
                # Resize back to exactly old_size * scale if the crop changed the shape.
                if old_height * self.scale != refiner_image.shape[2] or old_width * self.scale != refiner_image.shape[3]:
                    refiner_image = F.interpolate(refiner_image, [old_height * self.scale, old_width * self.scale], mode='bilinear', align_corners=True)
                refined_pil_image = self.to_pil(refiner_image[0])
                return refined_pil_image
            sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
            if old_height * self.scale != sr_image.shape[2] or old_width * self.scale != sr_image.shape[3]:
                sr_image = F.interpolate(sr_image, [old_height * self.scale, old_width * self.scale], mode='bilinear', align_corners=True)
            pil_sr_image = self.to_pil(sr_image[0])
            return pil_sr_image
        else:
            base_diffusion = DPMSolver(steps)
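            # Tiled ("panorama") sampling: the image is denoised in batches of
            # `view_batch_size` views per step, so arbitrary sizes fit in memory.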
            lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device)
            old_height = pil_image.size[1]
            old_width = pil_image.size[0]
            # Round odd dimensions up to even before scaling; any mismatch with
            # the exact target size is fixed by interpolation below.
            height = int(old_height + np.mod(old_height, 2)) * self.scale
            width = int(old_width + np.mod(old_width, 2)) * self.scale
            sr_image = base_diffusion.generate_panorama(
                height, width, self.device, self.dtype, steps,
                self.sr_model, lowres_img=lr_image,
                view_batch_size=view_batch_size, eta=0.0, seed=seed,
            )
            sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
            if old_height * self.scale != height or old_width * self.scale != width:
                sr_image = F.interpolate(sr_image, [old_height * self.scale, old_width * self.scale], mode='bilinear', align_corners=True)
            pil_sr_image = self.to_pil(sr_image[0])
            return pil_sr_image
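

# Minimal usage sketch. `load_sr_model`, `load_movq`, and `load_refiner` are
# hypothetical stand-ins for the project's weight-loading code (the repo's
# loader helper normally constructs this pipeline for you); adjust devices,
# dtypes, and paths to your setup.
#
# if __name__ == '__main__':
#     from PIL import Image
#
#     device = {'sr_model': 'cuda', 'movq': 'cuda', 'refiner': 'cuda'}
#     dtype = {'sr_model': torch.float16, 'movq': torch.float16, 'refiner': torch.float16}
#     pipe = KandiSuperResPipeline(
#         scale=2, device=device, dtype=dtype, flash=True,
#         sr_model=load_sr_model(),    # hypothetical loader
#         movq=load_movq(),            # hypothetical loader
#         refiner=load_refiner(),      # hypothetical loader
#     )
#     sr = pipe(Image.open('input.png'), refine=True)  # 2x super-resolved PIL image
#     sr.save('output.png')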