# Spaces: Runtime error  (page-status text captured with the file; not code)
from __future__ import annotations | |
import logging | |
import os | |
import random | |
import sys | |
import tempfile | |
import gradio as gr | |
import imageio | |
import numpy as np | |
import PIL.Image | |
import torch | |
import tqdm.auto | |
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, | |
DiffusionPipeline, PNDMPipeline, PNDMScheduler) | |
# Access token forwarded as ``use_auth_token`` to every ``from_pretrained`` /
# ``from_config`` call in this file.  Read with ``.get`` so that a missing
# variable no longer aborts the import with a bare ``KeyError``; ``None`` is
# passed through instead.
# NOTE(review): if the model repositories are private, loading will still
# fail later without a token -- confirm whether the token is mandatory.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Module logger: INFO and above go to stdout with a timestamped format.
formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Do not hand records to ancestor loggers; this handler is the only sink, so
# each message is printed once.
logger.propagate = False
logger.addHandler(stream_handler)

if HF_TOKEN is None:
    logger.warning('HF_TOKEN is not set; '
                   'model downloads will be attempted without a token')
class Model:
    """Wrapper around the anime-face diffusion pipelines used by the demo.

    Holds the currently selected pipeline (model name + scheduler type), a
    private random generator for seed selection, and a handle to a remote
    Real-ESRGAN Space used for super-resolution.
    """

    # Suffixes appended to ``hysts/diffusers-anime-faces-`` to form repo ids.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        """Load the default pipeline (first model, DDIM) for ``device``."""
        self.device = torch.device(device)
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)
        self.rng = random.Random()
        self.real_esrgan = gr.Interface.load('spaces/hysts/Real-ESRGAN-anime')

    # Defined without ``self`` but always called as ``self._load_pipeline(a,
    # b)``; it must therefore be a static method or the call raises TypeError.
    @staticmethod
    def _load_pipeline(model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        """Build the pipeline for ``model_name`` with the given scheduler.

        Args:
            model_name: One of ``MODEL_NAMES``.
            scheduler_type: ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``.

        Returns:
            The loaded pipeline.  For DDIM and PNDM the scheduler is replaced
            by one built from the repo's ``scheduler`` subfolder.

        Raises:
            ValueError: If ``scheduler_type`` is not one of the three above.
        """
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            pipeline = DDPMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
        elif scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = DDIMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        elif scheduler_type == 'PNDM':
            pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = PNDMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        else:
            raise ValueError(f'Unknown scheduler type: {scheduler_type!r}')
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch to the requested model/scheduler, reloading only on change."""
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')

        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        # Load before touching any state: if loading raises, the recorded
        # names must keep describing the pipeline that is actually held,
        # otherwise a retry with the same arguments would be skipped above.
        pipeline = self._load_pipeline(model_name, scheduler_type)
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = pipeline
        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        """Load every model once with the DDPM pipeline and discard it.

        NOTE(review): presumably this warms the local download cache --
        confirm against the hub client's caching behaviour.
        """
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self,
                 seed: int,
                 num_steps: int,
                 num_images: int = 1) -> list[PIL.Image.Image]:
        """Run the current pipeline and return its ``'sample'`` output.

        ``num_steps`` is forwarded only for DDIM/PNDM; the DDPM call does not
        receive it.

        Raises:
            ValueError: If the current scheduler type is not recognised.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')

        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device)['sample']
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample']
        else:
            raise ValueError(
                f'Unknown scheduler type: {self.scheduler_type!r}')

        logger.info('--- done ---')
        return res

    # Called as ``self.postprocess(sample)`` but defined without ``self``.
    @staticmethod
    def postprocess(sample: torch.Tensor) -> np.ndarray:
        """Convert a sample tensor to a ``uint8`` array.

        Computes ``sample / 2 + 0.5``, clamps to ``[0, 1]``, scales by 255 and
        casts to ``uint8``; axes are then reordered ``(0, 2, 3, 1)``
        (presumably NCHW -> NHWC -- confirm with the model's output layout).
        """
        res = (sample / 2 + 0.5).clamp(0, 1)
        res = (res * 255).to(torch.uint8)
        res = res.cpu().permute(0, 2, 3, 1).numpy()
        return res

    def generate_with_video(self, seed: int,
                            num_steps: int) -> tuple[PIL.Image.Image, str]:
        """Denoise one sample step by step, recording every step as a frame.

        For the DDPM scheduler ``num_steps`` is overridden to 1000 and the
        video is written at 100 fps; otherwise the given ``num_steps`` is
        used at 10 fps.

        Returns:
            ``(final_image, video_path)`` where ``video_path`` is an ``.mp4``
            temporary file that is not deleted automatically.

        Raises:
            RuntimeError: If the scheduler yields no timesteps, so no frame
                was produced.
        """
        logger.info('--- generate_with_video ---')
        if self.scheduler_type == 'DDPM':
            num_steps = 1000
            fps = 100
        else:
            fps = 10
        logger.info(f'{seed=}, {num_steps=}')

        model = self.pipeline.unet.to(self.device)
        scheduler = self.pipeline.scheduler
        scheduler.set_timesteps(num_inference_steps=num_steps)
        input_shape = (1, model.config.in_channels, model.config.sample_size,
                       model.config.sample_size)
        torch.manual_seed(seed)

        # Only the path is needed; close the handle right away so it is not
        # leaked (the file itself survives because delete=False).
        with tempfile.NamedTemporaryFile(suffix='.mp4',
                                         delete=False) as out_file:
            video_path = out_file.name

        writer = imageio.get_writer(video_path, fps=fps)
        frame = None
        try:
            sample = torch.randn(input_shape).to(self.device)
            # Inference only: without no_grad each step would extend the
            # autograd graph through ``sample`` and memory would grow with
            # the number of steps.
            with torch.no_grad():
                for t in tqdm.auto.tqdm(scheduler.timesteps):
                    out = model(sample, t)['sample']
                    sample = scheduler.step(out, t, sample)['prev_sample']
                    frame = self.postprocess(sample)[0]
                    writer.append_data(frame)
        finally:
            # Always finalise the video file, even if a step fails.
            writer.close()

        if frame is None:
            raise RuntimeError('Scheduler produced no timesteps')

        logger.info('--- done ---')
        return PIL.Image.fromarray(frame), video_path

    def superresolve(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Upscale ``image`` via the loaded Real-ESRGAN interface.

        The image is written to a temporary PNG whose path is passed to
        ``self.real_esrgan``; the result is opened with ``PIL.Image.open``.
        """
        logger.info('--- superresolve ---')
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            image.save(f.name)
            out_file = self.real_esrgan(f.name)
        logger.info('--- done ---')
        return PIL.Image.open(out_file)

    def run(self, model_name: str, scheduler_type: str, num_steps: int,
            randomize_seed: bool, seed: int,
            superresolve: bool) -> tuple[PIL.Image.Image, int, str]:
        """Full demo entry point: select pipeline, generate, optionally upscale.

        Returns:
            ``(image, seed_used, video_path)``.
        """
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            # Restrict PNDM to between 4 and 100 steps.
            num_steps = max(4, min(num_steps, 100))
        if randomize_seed:
            seed = self.rng.randint(0, 100000)
        res, filename = self.generate_with_video(seed, num_steps)
        if superresolve:
            res = self.superresolve(res)
        return res, seed, filename

    # Called as ``self.to_grid(images, 2)`` but defined without ``self``.
    @staticmethod
    def to_grid(images: list[PIL.Image.Image],
                ncols: int = 2) -> PIL.Image.Image:
        """Tile equally sized 3-channel images into a grid with ``ncols`` columns.

        Missing cells in the last row are filled with white (255) tiles.
        """
        images = [np.asarray(image) for image in images]
        nrows = (len(images) + ncols - 1) // ncols
        h, w = images[0].shape[:2]
        if (d := nrows * ncols - len(images)) > 0:
            images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d
        # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> one flat image.
        grid = np.asarray(images).reshape(nrows, ncols, h, w, 3).transpose(
            0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3)
        return PIL.Image.fromarray(grid)

    def run_simple(self) -> PIL.Image.Image:
        """Generate four upscaled images with the default model and DDIM.

        Uses a random seed and 10 inference steps, and returns the images as
        a 2-column grid.
        """
        self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
        seed = self.rng.randint(0, 1000000)
        images = self.generate(seed, num_steps=10, num_images=4)
        images = [self.superresolve(image) for image in images]
        return self.to_grid(images, 2)