Pana360gen / img2panoimg /image_to_360panorama_image_pipeline.py
banyapon's picture
Build
2f19003
raw
history blame
9.98 kB
# Copyright © Alibaba, Inc. and its affiliates.
import random
from typing import Any, Dict
import numpy as np
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UniPCMultistepScheduler)
from PIL import Image
from RealESRGAN import RealESRGAN
from .pipeline_i2p import StableDiffusionImage2PanoPipeline
from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
import py360convert
class LazyRealESRGAN:
    """Real-ESRGAN upsampler whose network and weights are created on first use.

    Args:
        device: device passed through to `RealESRGAN`.
        scale: upscaling factor passed through to `RealESRGAN`.
        model_path: path to a local weights file. Optional here so it can also
            be assigned later through the `model_path` attribute; it must be
            set before the first `load_model`/`predict` call.
    """

    def __init__(self, device, scale, model_path=None):
        self.device = device
        self.scale = scale
        self.model = None  # created by the first successful load_model()
        self.model_path = model_path

    def load_model(self):
        """Create the network and load its weights, unless already done.

        Raises:
            ValueError: if `model_path` has not been set.
        """
        if self.model is not None:
            return
        if self.model_path is None:
            raise ValueError(
                'LazyRealESRGAN.model_path must be set before loading the model')
        model = RealESRGAN(self.device, scale=self.scale)
        # download=False: presumably disables fetching, so the weights are
        # expected to already exist at model_path.
        model.load_weights(self.model_path, download=False)
        # Cache only after the weights loaded successfully; otherwise a failed
        # load would leave an unweighted network that later calls reuse.
        self.model = model

    def predict(self, img):
        """Upscale `img` with the (lazily loaded) model.

        Returns whatever `RealESRGAN.predict` returns.
        """
        self.load_model()
        return self.model.predict(img)
class Image2360PanoramaImagePipeline(DiffusionPipeline):
    """ Stable Diffusion for 360 Panorama Image Generation Pipeline.

    Generates a 1024x512 panorama from an input image and a mask, then
    optionally super-resolves it (ControlNet img2img + Real-ESRGAN) and refines
    it with a second ControlNet img2img pass.

    Example:
        >>> import torch
        >>> from PIL import Image
        >>> from img2panoimg.image_to_360panorama_image_pipeline import (
        ...     Image2360PanoramaImagePipeline)
        >>> image = Image.open('image.png')
        >>> mask = Image.open('mask.png')
        >>> input = {'prompt': 'The mountains', 'image': image,
        ...          'mask': mask, 'upscale': True}
        >>> model_id = 'models/'
        >>> img2panoimg = Image2360PanoramaImagePipeline(
        ...     model_id, torch_dtype=torch.float16)
        >>> output = img2panoimg(input)
        >>> output.save('result.png')
    """

    def __init__(self, model: str, device: str = 'cuda', **kwargs):
        """
        Use `model` to create a stable diffusion pipeline for 360 panorama image generation.

        Args:
            model: path of the model directory; sub-folders `sd-i2p`,
                `sr-base` and `sr-control` and the file
                `RealESRGAN_x2plus.pth` are loaded from it.
            device: target device. `None` selects CUDA when available, else
                CPU; 'gpu' is accepted as an alias for CUDA.
            **kwargs:
                torch_dtype: dtype of all diffusion models (default
                    torch.float16).
                enable_xformers_memory_efficient_attention: try to enable
                    xformers attention (default True); failures are printed
                    and otherwise ignored.
        """
        super().__init__()
        if device is None:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'gpu':
            device = torch.device('cuda')
        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        enable_xformers = kwargs.get(
            'enable_xformers_memory_efficient_attention', True)

        # init i2p model: sr-base weights with the image-to-panorama ControlNet.
        controlnet = ControlNetModel.from_pretrained(
            model + '/sd-i2p', torch_dtype=torch_dtype)
        self.pipe = StableDiffusionImage2PanoPipeline.from_pretrained(
            model + '/sr-base/', controlnet=controlnet,
            torch_dtype=torch_dtype).to(device)
        self.pipe.vae.enable_tiling()
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config)
        if enable_xformers:
            self._try_enable_xformers(self.pipe)

        # init controlnet-sr model
        controlnet = ControlNetModel.from_pretrained(
            model + '/sr-control', torch_dtype=torch_dtype)
        self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
            model + '/sr-base', controlnet=controlnet,
            torch_dtype=torch_dtype).to(device)
        # Build UniPC from this pipeline's own scheduler config rather than
        # from the i2p pipeline's Euler-ancestral one.
        self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe_sr.scheduler.config)
        self.pipe_sr.vae.enable_tiling()
        if enable_xformers:
            self._try_enable_xformers(self.pipe_sr)

        # Real-ESRGAN runs on the same device as the diffusion pipelines so a
        # CPU-only setup does not try to put it on CUDA. Weights load lazily.
        self.upsampler = LazyRealESRGAN(device=torch.device(device), scale=2)
        self.upsampler.model_path = model + '/RealESRGAN_x2plus.pth'

    @staticmethod
    def _try_enable_xformers(pipe):
        """Best-effort enabling of xformers attention on `pipe`.

        Any exception (e.g. xformers not installed) is printed and ignored.
        """
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(e)

    @staticmethod
    def process_control_image(image, mask):
        """Build the ControlNet conditioning tensor for the i2p pipeline.

        `image` is used as face 0 of a cube map (the other five faces are
        black) and projected with py360convert to a 512x1024 equirectangular
        map. That projection is blended with `image` stretched to 1024x512,
        weighted by `mask` (mask 1 -> projection, mask 0 -> stretched image).
        The first mask channel is prepended to the blended image.

        Args:
            image: PIL image; presumably square RGB, since it becomes a cube
                face — TODO confirm with callers.
            mask: PIL image; assumed to have the same channel count as
                `image` so the blend broadcasts — TODO confirm.

        Returns:
            float32 tensor of shape (1, 1 + C, 512, 1024) with values in
            [0, 1]: channel 0 is the mask, the remaining C channels are the
            blended image.
        """
        def to_tensor(img, batch_size: int, width=1024, height=512):
            # Resize, scale to [0, 1] and convert HWC -> NCHW.
            img = img.resize((width, height), resample=Image.BICUBIC)
            arr = np.array(img).astype(np.float32) / 255.0
            arr = np.vstack([arr[None].transpose(0, 3, 1, 2)] * batch_size)
            return torch.from_numpy(arr)

        face = np.array(image)
        zeros = np.zeros_like(face)
        cube_faces = [face] + [zeros] * 5
        output_image = py360convert.c2e(
            cube_faces, 512, 1024, cube_format='list')
        bk_image = to_tensor(image, batch_size=1)
        control_image = to_tensor(
            Image.fromarray(output_image.astype(np.uint8)), batch_size=1)
        mask_image = to_tensor(mask, batch_size=1)
        control_image = (1 - mask_image) * bk_image + mask_image * control_image
        return torch.cat([mask_image[:, :1, :, :], control_image], dim=1)

    @staticmethod
    def blend_h(a, b, blend_extent):
        """Cross-fade the last columns of `a` into the first columns of `b`.

        For 0 <= x < blend_extent (clamped to both widths), column x of the
        result is ``a[:, -blend_extent + x] * (1 - x / blend_extent)
        + b[:, x] * (x / blend_extent)``. Inputs are indexed as
        (rows, columns, channels) and copied with np.array, so neither is
        modified.

        Returns:
            the blended copy of `b`; blended values are cast back to `b`'s
            dtype on assignment.
        """
        a = np.array(a)
        b = np.array(b)
        blend_extent = min(a.shape[1], b.shape[1], blend_extent)
        for x in range(blend_extent):
            weight = x / blend_extent
            b[:, x, :] = (a[:, -blend_extent + x, :] * (1 - weight)
                          + b[:, x, :] * weight)
        return b

    def _run_sr(self, image, prompt, negative_prompt, sr_scale, generator,
                guidance_scale):
        """Run the ControlNet img2img SR pipeline once.

        `image` is resized to (1536 * sr_scale, 768 * sr_scale) and the resized
        image serves as both the img2img input and the control image.
        """
        size = (int(1536 * sr_scale), int(768 * sr_scale))
        resized = image.resize(size)
        return self.pipe_sr(
            prompt,
            negative_prompt=negative_prompt,
            image=resized,
            num_inference_steps=7,
            generator=generator,
            control_image=resized,
            strength=0.8,
            controlnet_conditioning_scale=1.0,
            guidance_scale=guidance_scale,
        ).images[0]

    def _upscale_realesrgan(self, image, blend_extent=10, outscale=2):
        """Upscale `image` with Real-ESRGAN, smoothing the 360 wrap-around seam.

        The first `blend_extent` columns are appended to the right edge before
        upscaling. The upscaled copy of them is then cross-faded into the left
        edge with `blend_h`, and the result is cropped back to
        `width * outscale` columns. `outscale` must match the upsampler's
        scale (2, as configured in __init__).
        """
        w = image.size[0]
        arr = np.array(image)
        arr = np.concatenate([arr, arr[:, :blend_extent, :]], axis=1)
        arr = self.upsampler.predict(arr)
        arr = self.blend_h(arr, arr, blend_extent * outscale)
        return Image.fromarray(arr[:, :w * outscale, :])

    def __call__(self, inputs: Dict[str, Any],
                 **forward_params) -> Image.Image:
        """Generate a 360 panorama.

        Args:
            inputs: dict with required keys 'image' and 'mask' (see
                `process_control_image`) and optional keys 'prompt',
                'add_prompt', 'negative_prompt', 'num_inference_steps' (20),
                'guidance_scale' (7.0), 'seed' (-1 = random), 'upscale' (True),
                'refinement' (True, only used when upscaling),
                'guidance_scale_sr_step1' (15) and
                'guidance_scale_sr_step2' (17).
            **forward_params: only 'prompt' is read, as a fallback when
                `inputs` has none (default 'the living room').

        Returns:
            PIL image of the panorama, generated at 1024x512 and enlarged by
            the super-resolution stages when 'upscale' is enabled.

        Raises:
            ValueError: if `inputs` is not a dict.
            KeyError: if 'image' or 'mask' is missing.
        """
        if not isinstance(inputs, dict):
            raise ValueError(
                f'Expected the input to be a dictionary, but got {type(inputs)}'
            )
        num_inference_steps = inputs.get('num_inference_steps', 20)
        guidance_scale = inputs.get('guidance_scale', 7.0)
        preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
        add_prompt = inputs.get('add_prompt', preset_a_prompt)
        preset_n_prompt = (
            'persons, complex texture, small objects, sheltered, blur, worst quality, '
            'low quality, zombie, logo, text, watermark, username, monochrome, '
            'complex lighting')
        negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
        seed = inputs.get('seed', -1)
        upscale = inputs.get('upscale', True)
        refinement = inputs.get('refinement', True)
        guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
        guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step2', 17)

        image = inputs['image']
        mask = inputs['mask']
        control_image = self.process_control_image(image, mask)

        if 'prompt' in inputs:
            prompt = inputs['prompt']
        else:
            # for demo_service
            prompt = forward_params.get('prompt', 'the living room')
        print(f'Test with prompt: {prompt}')

        if seed == -1:
            seed = random.randint(0, 65535)
        print(f'global seed: {seed}')
        # torch.manual_seed seeds the global RNG and returns the default
        # generator, which is shared by every stage below.
        generator = torch.manual_seed(seed)

        prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
        output_img = self.pipe(
            prompt,
            # Channels 1: are the blended image in [0, 1]; map to [-1, 1].
            image=(control_image[:, 1:, :, :] / 0.5 - 1.0),
            control_image=control_image,
            controlnet_conditioning_scale=1.0,
            strength=1.0,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=512,
            width=1024,
            guidance_scale=guidance_scale,
            generator=generator).images[0]

        if not upscale:
            print('finished')
            return output_img

        print('inputs: upscale=True, running upscaler.')
        # The SR pipeline is prompted without the panorama trigger token.
        sr_prompt = prompt.replace('<360panorama>, ', '')

        print('running upscaler step1. Initial super-resolution')
        output_img = self._run_sr(output_img, sr_prompt, negative_prompt,
                                  2.0, generator, guidance_scale_sr_step1)

        print('running upscaler step2. Super-resolution with Real-ESRGAN')
        output_img = self._upscale_realesrgan(
            output_img.resize((1536 * 2, 768 * 2)))

        if refinement:
            print(
                'inputs: refinement=True, running refinement. This is a bit time-consuming.'
            )
            output_img = self._run_sr(output_img, sr_prompt, negative_prompt,
                                      4, generator, guidance_scale_sr_step2)
        print('finished')
        return output_img