Spaces:
Build error
Build error
# Copyright © Alibaba, Inc. and its affiliates. | |
import random | |
from typing import Any, Dict | |
import numpy as np | |
import torch | |
from diffusers import (ControlNetModel, DiffusionPipeline, | |
EulerAncestralDiscreteScheduler, | |
UniPCMultistepScheduler) | |
from PIL import Image | |
from RealESRGAN import RealESRGAN | |
from .pipeline_i2p import StableDiffusionImage2PanoPipeline | |
from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline | |
import py360convert | |
class LazyRealESRGAN:
    """Real-ESRGAN upsampler whose network is only built on first use.

    Constructing the wrapper is cheap: the ``RealESRGAN`` model is created and
    its weights are read from ``model_path`` the first time :meth:`predict`
    (or :meth:`load_model`) is called.

    Args:
        device: device passed through to ``RealESRGAN``.
        scale: upscaling factor passed through to ``RealESRGAN``.
        model_path: path of the weight file. Optional for backward
            compatibility — callers may instead assign ``model_path`` after
            construction — but it must be set before the first prediction.
    """

    def __init__(self, device, scale, model_path=None):
        self.device = device
        self.scale = scale
        self.model = None  # created lazily by load_model()
        self.model_path = model_path

    def load_model(self):
        """Create the model and load its weights, if not already done.

        Raises:
            ValueError: if ``model_path`` has not been set.
        """
        if self.model is not None:
            return
        if self.model_path is None:
            raise ValueError(
                'LazyRealESRGAN.model_path must be set before loading the model')
        model = RealESRGAN(self.device, scale=self.scale)
        # download=False presumably disables fetching, so the weights are
        # expected to already exist at model_path.
        model.load_weights(self.model_path, download=False)
        # Cache only after a successful load: a failed load is retried on the
        # next call instead of leaving an un-weighted network in self.model.
        self.model = model

    def predict(self, img):
        """Upscale ``img`` with the (lazily loaded) model and return the result."""
        self.load_model()
        return self.model.predict(img)
class Image2360PanoramaImagePipeline(DiffusionPipeline):
    """ Stable Diffusion for 360 Panorama Image Generation Pipeline.

    Extends an input image into a 1024x512 equirectangular panorama with a
    ControlNet image-to-panorama pipeline. Optionally, it then upscales the
    result with a ControlNet super-resolution pipeline followed by Real-ESRGAN
    and a second super-resolution (refinement) pass.

    Example:

    >>> import torch
    >>> from PIL import Image
    >>> image = Image.open('input.png').convert('RGB')
    >>> mask = Image.open('mask.png').convert('RGB')
    >>> inputs = {'prompt': 'The living room', 'image': image,
    ...           'mask': mask, 'upscale': False}
    >>> model_id = 'models/'
    >>> img2panoimg = Image2360PanoramaImagePipeline(model_id, torch_dtype=torch.float16)
    >>> output = img2panoimg(inputs)
    >>> output.save('result.png')
    """

    def __init__(self, model: str, device: str = 'cuda', **kwargs):
        """Load the image-to-panorama and super-resolution pipelines.

        Args:
            model: root path under which the ``sd-i2p``, ``sr-base`` and
                ``sr-control`` sub-folders and ``RealESRGAN_x2plus.pth`` live.
            device: target device. ``None`` selects CUDA when available and
                CPU otherwise; ``'gpu'`` is accepted as an alias for ``'cuda'``.
            **kwargs: ``torch_dtype`` (default ``torch.float16``) applied to
                every loaded model, and
                ``enable_xformers_memory_efficient_attention`` (default ``True``).
        """
        super().__init__()
        if device is None:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'gpu':
            device = torch.device('cuda')
        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        enable_xformers = kwargs.get(
            'enable_xformers_memory_efficient_attention', True)
        base_model_path = model + '/sr-base'

        # Image-to-panorama pipeline.
        i2p_controlnet = ControlNetModel.from_pretrained(
            model + '/sd-i2p', torch_dtype=torch_dtype)
        self.pipe = StableDiffusionImage2PanoPipeline.from_pretrained(
            base_model_path, controlnet=i2p_controlnet,
            torch_dtype=torch_dtype).to(device)
        self.pipe.vae.enable_tiling()
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config)
        self._try_enable_xformers(self.pipe, enable_xformers)

        # ControlNet super-resolution pipeline.
        sr_controlnet = ControlNetModel.from_pretrained(
            model + '/sr-control', torch_dtype=torch_dtype)
        self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
            base_model_path, controlnet=sr_controlnet,
            torch_dtype=torch_dtype).to(device)
        # Derive UniPC from this pipeline's own scheduler config (the i2p
        # pipeline's scheduler has already been replaced above).
        self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe_sr.scheduler.config)
        self.pipe_sr.vae.enable_tiling()
        self._try_enable_xformers(self.pipe_sr, enable_xformers)

        # Real-ESRGAN x2 upsampler on the same device as the pipelines;
        # weights are loaded lazily on first use.
        self.upsampler = LazyRealESRGAN(device=torch.device(device), scale=2)
        self.upsampler.model_path = model + '/RealESRGAN_x2plus.pth'

    @staticmethod
    def _try_enable_xformers(pipe, enabled):
        """Best-effort switch of ``pipe`` to xFormers attention.

        Failures (e.g. xformers not installed) are printed, not raised.
        """
        if not enabled:
            return
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(e)

    @staticmethod
    def process_control_image(image, mask):
        """Build the 4-channel ControlNet condition for the i2p pipeline.

        ``image`` is placed on face 0 of an otherwise black list-format cube
        map and projected to a 1024x512 equirectangular image with
        ``py360convert.c2e``. That projection is blended with ``image``
        stretched to 1024x512: where ``mask`` is 1 the projection is used,
        where it is 0 the stretched image. The mask's first channel is
        prepended.

        Args:
            image: PIL image used as a cube face; presumably square and
                3-channel RGB — TODO confirm against callers.
            mask: PIL image, presumably 3-channel like ``image``.

        Returns:
            float tensor ``(1, 1 + C, 512, 1024)`` with values in ``[0, 1]``:
            channel 0 is the mask, the remaining ``C`` channels are the
            blended image (``C == 3`` for RGB inputs).
        """

        def to_tensor(img: Image.Image, batch_size: int, width=1024, height=512):
            # Resize, scale to [0, 1] and reorder HWC -> (batch, C, H, W).
            img = img.resize((width, height), resample=Image.BICUBIC)
            arr = np.array(img).astype(np.float32) / 255.0
            arr = np.vstack([arr[None].transpose(0, 3, 1, 2)] * batch_size)
            return torch.from_numpy(arr)

        face = np.array(image)
        blank = np.zeros_like(face)
        cube_faces = [face] + [blank] * 5
        equirect = py360convert.c2e(cube_faces, 512, 1024, cube_format='list')

        bk_image = to_tensor(image, batch_size=1)
        projected = to_tensor(
            Image.fromarray(equirect.astype(np.uint8)), batch_size=1)
        mask_image = to_tensor(mask, batch_size=1)
        control_image = (1 - mask_image) * bk_image + mask_image * projected
        return torch.cat([mask_image[:, :1, :, :], control_image], dim=1)

    @staticmethod
    def blend_h(a, b, blend_extent):
        """Cross-fade the last columns of ``a`` into the first columns of ``b``.

        For ``0 <= x < blend_extent``, column ``x`` of the result is
        ``a[:, -blend_extent + x] * (1 - x / blend_extent)
        + b[:, x] * (x / blend_extent)``. Other columns are copied from ``b``.

        Args:
            a: array-like indexed as ``(rows, columns, channels)``.
            b: array-like indexed as ``(rows, columns, channels)``.
            blend_extent: requested blend width; clamped to both widths.

        Returns:
            A new array with ``b``'s dtype. The inputs are copied, not modified.
        """
        a = np.array(a)
        b = np.array(b)
        blend_extent = min(a.shape[1], b.shape[1], blend_extent)
        for x in range(blend_extent):
            b[:, x, :] = a[:, -blend_extent + x, :] * (
                1 - x / blend_extent) + b[:, x, :] * (x / blend_extent)
        return b

    def _run_sr(self, prompt, negative_prompt, image, sr_scale,
                guidance_scale, generator):
        """Run one ControlNet super-resolution pass on a PIL image.

        ``image`` is resized to ``(1536 * sr_scale, 768 * sr_scale)`` and used
        both as the img2img input and as the ControlNet condition.
        """
        target_size = (int(1536 * sr_scale), int(768 * sr_scale))
        resized = image.resize(target_size)
        return self.pipe_sr(
            prompt,
            negative_prompt=negative_prompt,
            image=resized,
            num_inference_steps=7,
            generator=generator,
            control_image=resized,
            strength=0.8,
            controlnet_conditioning_scale=1.0,
            guidance_scale=guidance_scale,
        ).images[0]

    def __call__(self, inputs: Dict[str, Any],
                 **forward_params) -> Image.Image:
        """Generate a 360 panorama from ``inputs['image']`` and ``inputs['mask']``.

        Args:
            inputs: dict with required keys ``'image'`` and ``'mask'`` (see
                :meth:`process_control_image`) and optional keys
                ``'prompt'``, ``'add_prompt'``, ``'negative_prompt'``,
                ``'num_inference_steps'`` (20), ``'guidance_scale'`` (7.0),
                ``'seed'`` (-1 picks a random seed), ``'upscale'`` (True),
                ``'refinement'`` (True, only used when upscaling),
                ``'guidance_scale_sr_step1'`` (15) and
                ``'guidance_scale_sr_step2'`` (17).
            **forward_params: ``'prompt'`` is read from here when it is absent
                from ``inputs`` (default ``'the living room'``).

        Returns:
            The panorama as a PIL image. It is requested at 1024x512. With
            ``upscale`` it is cropped to 6144x3072 after Real-ESRGAN, and
            refinement re-renders it at 6144x3072 (presumably the SR pipeline
            returns the requested size).

        Raises:
            ValueError: if ``inputs`` is not a dict.
            KeyError: if ``'image'`` or ``'mask'`` is missing.
        """
        if not isinstance(inputs, dict):
            raise ValueError(
                f'Expected the input to be a dictionary, but got {type(inputs)}'
            )
        num_inference_steps = inputs.get('num_inference_steps', 20)
        guidance_scale = inputs.get('guidance_scale', 7.0)
        preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
        add_prompt = inputs.get('add_prompt', preset_a_prompt)
        preset_n_prompt = ('persons, complex texture, small objects, sheltered, blur, worst quality, '
                           'low quality, zombie, logo, text, watermark, username, monochrome, '
                           'complex lighting')
        negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
        seed = inputs.get('seed', -1)
        upscale = inputs.get('upscale', True)
        refinement = inputs.get('refinement', True)
        guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
        guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step2', 17)

        control_image = self.process_control_image(
            inputs['image'], inputs['mask'])

        if 'prompt' in inputs:
            prompt = inputs['prompt']
        else:
            # for demo_service
            prompt = forward_params.get('prompt', 'the living room')
        print(f'Test with prompt: {prompt}')

        if seed == -1:
            seed = random.randint(0, 65535)
        print(f'global seed: {seed}')
        generator = torch.manual_seed(seed)

        prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
        output_img = self.pipe(
            prompt,
            # Channels 1.. of the condition are the blended image; map
            # [0, 1] -> [-1, 1] for the img2img input.
            image=(control_image[:, 1:, :, :] / 0.5 - 1.0),
            control_image=control_image,
            controlnet_conditioning_scale=1.0,
            strength=1.0,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=512,
            width=1024,
            guidance_scale=guidance_scale,
            generator=generator).images[0]

        if not upscale:
            print('finished')
            return output_img

        print('inputs: upscale=True, running upscaler.')
        # The '<360panorama>' trigger token is stripped for the SR passes.
        sr_prompt = prompt.replace('<360panorama>, ', '')

        print('running upscaler step1. Initial super-resolution')
        output_img = self._run_sr(sr_prompt, negative_prompt, output_img, 2.0,
                                  guidance_scale_sr_step1, generator)

        print('running upscaler step2. Super-resolution with Real-ESRGAN')
        output_img = output_img.resize((1536 * 2, 768 * 2))
        w = output_img.size[0]
        blend_extend = 10
        outscale = 2  # must match the upsampler's scale (2)
        output_img = np.array(output_img)
        # Append a copy of the first columns on the right. After upsampling,
        # blend_h cross-fades that copy into the left edge and the crop below
        # drops it, so the left and right borders join smoothly.
        output_img = np.concatenate(
            [output_img, output_img[:, :blend_extend, :]], axis=1)
        output_img = self.upsampler.predict(output_img)
        output_img = self.blend_h(output_img, output_img,
                                  blend_extend * outscale)
        output_img = Image.fromarray(output_img[:, :w * outscale, :])

        if refinement:
            print(
                'inputs: refinement=True, running refinement. This is a bit time-consuming.'
            )
            output_img = self._run_sr(sr_prompt, negative_prompt, output_img,
                                      4, guidance_scale_sr_step2, generator)
        print('finished')
        return output_img