# Author: hysts (HF staff)
# Commit af2a8f5: "Add an option to show denoising process"
# File size: 6.6 kB (virus scan: no virus found)
from __future__ import annotations
import logging
import os
import random
import sys
import tempfile
import imageio
import numpy as np
import PIL.Image
import torch
import tqdm.auto
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
DiffusionPipeline, PNDMPipeline, PNDMScheduler)
# Hub access token, read eagerly: importing this module raises KeyError when
# the HF_TOKEN environment variable is not set.
HF_TOKEN = os.environ['HF_TOKEN']

# Module logger: INFO and above are written to stdout with a timestamped
# format. propagate=False keeps records from also reaching ancestor loggers'
# handlers (so they are not printed twice if the root logger is configured).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False

formatter = logging.Formatter(
    fmt='[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
stream_handler.setLevel(logging.INFO)
logger.addHandler(stream_handler)
class Model:
    """Anime-face sampler built on ``hysts/diffusers-anime-faces-*`` pipelines.

    Keeps exactly one loaded pipeline (``self.pipeline``) together with the
    model name and scheduler type it was built from. Entry points:
    ``run`` / ``run_simple`` for plain sampling, and ``generate_with_video``
    for sampling while recording every denoising step to an ``.mp4`` file.
    """

    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        """Fetch every model, then load the first one with a DDIM scheduler.

        Args:
            device: Device the UNet / pipeline runs on (anything accepted by
                ``torch.device``).
        """
        self.device = torch.device(device)
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)
        # Private RNG for seed randomization, independent of the global
        # `random` module state.
        self.rng = random.Random()

    @staticmethod
    def _load_pipeline(model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        """Load ``hysts/diffusers-anime-faces-{model_name}`` for a sampler.

        Args:
            model_name: Repo-id suffix, normally one of ``MODEL_NAMES``.
            scheduler_type: ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``. For DDIM and
                PNDM, the pipeline's scheduler is replaced by one built from
                the repo's ``scheduler`` subfolder config.

        Returns:
            The loaded pipeline.

        Raises:
            ValueError: If ``scheduler_type`` is not one of the above.
        """
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            pipeline = DDPMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
        elif scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = DDIMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        elif scheduler_type == 'PNDM':
            pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = PNDMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        else:
            raise ValueError(f'Unknown scheduler type: {scheduler_type!r}')
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch to ``model_name`` / ``scheduler_type``; no-op if unchanged.

        The new pipeline is loaded *before* the recorded name/type are
        updated. If loading raises, the previous pipeline and its matching
        name/type therefore stay in place. Recording first would leave state
        describing a pipeline that was never loaded, and a retry with the same
        arguments would be skipped.
        """
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')
        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        pipeline = self._load_pipeline(model_name, scheduler_type)
        self.pipeline = pipeline
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        """Load (and discard) every model in ``MODEL_NAMES`` once.

        The loaded pipelines are not kept. Presumably this exists to populate
        the local Hub cache at startup (TODO confirm).
        """
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self,
                 seed: int,
                 num_steps: int,
                 num_images: int = 1) -> list[PIL.Image.Image]:
        """Sample a batch with the current pipeline.

        Args:
            seed: Passed to ``torch.manual_seed`` before sampling.
            num_steps: Inference steps for DDIM/PNDM. Ignored for DDPM, whose
                pipeline is called without it.
            num_images: Batch size.

        Returns:
            The pipeline output's ``'sample'`` entry (annotated as a list of
            PIL images).

        Raises:
            ValueError: If ``self.scheduler_type`` is not DDPM/DDIM/PNDM.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')
        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device)['sample']
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample']
        else:
            raise ValueError(
                f'Unknown scheduler type: {self.scheduler_type!r}')
        logger.info('--- done ---')
        return res

    @staticmethod
    def postprocess(sample: torch.Tensor) -> np.ndarray:
        """Convert a batch in ``[-1, 1]`` to uint8 channel-last images.

        Maps ``[-1, 1]`` to ``[0, 255]``, clamping out-of-range values. The
        result is moved to the CPU and permuted from ``(N, C, H, W)`` to
        ``(N, H, W, C)``. Values are rounded rather than truncated, so
        intensities are not biased downward by up to one level.
        """
        res = (sample / 2 + 0.5).clamp(0, 1)
        res = (res * 255).round().to(torch.uint8)
        res = res.cpu().permute(0, 2, 3, 1).numpy()
        return res

    @torch.inference_mode()
    def generate_with_video(self, seed: int,
                            num_steps: int) -> tuple[np.ndarray, str]:
        """Run the denoising loop step by step, writing each step as a frame.

        Args:
            seed: Seeds the global torch RNG before the initial noise is drawn.
            num_steps: Number of scheduler timesteps. Forced to 1000 for DDPM.

        Returns:
            ``(last_frame, video_path)``. ``last_frame`` is the final denoised
            image as an ``(H, W, C)`` uint8 array. ``video_path`` is an
            ``.mp4`` temp file created with ``delete=False``, so the caller
            owns its cleanup.
        """
        logger.info('--- generate_with_video ---')
        if self.scheduler_type == 'DDPM':
            # DDPM always runs 1000 steps; 1000 frames at 100 fps gives a 10 s
            # video (other samplers use 10 fps).
            num_steps = 1000
            fps = 100
        else:
            fps = 10
        logger.info(f'{seed=}, {num_steps=}')

        # `.to` moves the pipeline's own UNet module and returns it.
        model = self.pipeline.unet.to(self.device)
        scheduler = self.pipeline.scheduler
        scheduler.set_timesteps(num_inference_steps=num_steps)
        input_shape = (1, model.config.in_channels, model.config.sample_size,
                       model.config.sample_size)

        torch.manual_seed(seed)
        # Initial noise is created before being moved to self.device.
        sample = torch.randn(input_shape).to(self.device)

        # Only reserve a unique path. Closing the handle right away avoids
        # leaking a file descriptor per call and lets imageio open the path.
        with tempfile.NamedTemporaryFile(suffix='.mp4',
                                         delete=False) as out_file:
            video_path = out_file.name

        writer = imageio.get_writer(video_path, fps=fps)
        try:
            for t in tqdm.auto.tqdm(scheduler.timesteps):
                out = model(sample, t)['sample']
                sample = scheduler.step(out, t, sample)['prev_sample']
                res = self.postprocess(sample)[0]
                writer.append_data(res)
        finally:
            # Finalize the video even if a denoising step raises.
            writer.close()
        logger.info('--- done ---')
        return res, video_path

    def run(self, model_name: str, scheduler_type: str, num_steps: int,
            randomize_seed: bool, seed: int, visualize_denoising: bool
            ) -> tuple[PIL.Image.Image | np.ndarray, int, str | None]:
        """Configure the pipeline, choose a seed and generate one image.

        PNDM step counts are clamped to ``[4, 100]``. If ``randomize_seed`` is
        true, ``seed`` is replaced with a draw from ``[0, 100000]``.

        Returns:
            ``(image, seed_used, video_path)``. ``video_path`` is ``None``
            unless ``visualize_denoising`` is true. In that case ``image`` is
            the final video frame as a uint8 array instead of the pipeline's
            own output.
        """
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            num_steps = max(4, min(num_steps, 100))
        if randomize_seed:
            seed = self.rng.randint(0, 100000)
        if not visualize_denoising:
            return self.generate(seed, num_steps)[0], seed, None
        res, filename = self.generate_with_video(seed, num_steps)
        return res, seed, filename

    @staticmethod
    def to_grid(images: list[PIL.Image.Image],
                ncols: int = 2) -> PIL.Image.Image:
        """Tile equally sized images row-major into a single grid image.

        Missing cells in the last row are filled with 255 (white for uint8
        images). Any trailing channel layout is supported: RGB, RGBA, or 2-D
        grayscale arrays.

        Raises:
            ValueError: If ``images`` is empty.
        """
        arrays = [np.asarray(image) for image in images]
        if not arrays:
            raise ValueError('to_grid() requires at least one image')
        h, w = arrays[0].shape[:2]
        channel_shape = arrays[0].shape[2:]
        nrows = (len(arrays) + ncols - 1) // ncols
        if (d := nrows * ncols - len(arrays)) > 0:
            filler = np.full((h, w, *channel_shape), 255,
                             dtype=arrays[0].dtype)
            arrays.extend([filler] * d)
        # (rows, cols, h, w, ...) -> (rows, h, cols, w, ...) so that
        # flattening places images side by side within each row.
        grid = np.stack(arrays).reshape(nrows, ncols, h, w, *channel_shape)
        grid = grid.swapaxes(1, 2).reshape(nrows * h, ncols * w,
                                           *channel_shape)
        return PIL.Image.fromarray(grid)

    def run_simple(self) -> PIL.Image.Image:
        """Generate a 2x2 grid of four images with the default model + DDIM.

        Uses 10 inference steps and a fresh seed from ``[0, 1000000]``.
        """
        self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
        seed = self.rng.randint(0, 1000000)
        images = self.generate(seed, num_steps=10, num_images=4)
        return self.to_grid(images, 2)