from __future__ import annotations import logging import os import random import sys import tempfile import imageio import numpy as np import PIL.Image import torch import tqdm.auto from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, DiffusionPipeline, PNDMPipeline, PNDMScheduler) HF_TOKEN = os.environ['HF_TOKEN'] formatter = logging.Formatter( '[%(asctime)s] %(name)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') stream_handler = logging.StreamHandler(stream=sys.stdout) stream_handler.setLevel(logging.INFO) stream_handler.setFormatter(formatter) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.propagate = False logger.addHandler(stream_handler) class Model: MODEL_NAMES = [ 'ddpm-128-exp000', ] def __init__(self, device: str | torch.device): self.device = torch.device(device) self._download_all_models() self.model_name = self.MODEL_NAMES[0] self.scheduler_type = 'DDIM' self.pipeline = self._load_pipeline(self.model_name, self.scheduler_type) self.rng = random.Random() @staticmethod def _load_pipeline(model_name: str, scheduler_type: str) -> DiffusionPipeline: repo_id = f'hysts/diffusers-anime-faces-{model_name}' if scheduler_type == 'DDPM': pipeline = DDPMPipeline.from_pretrained(repo_id, use_auth_token=HF_TOKEN) elif scheduler_type == 'DDIM': pipeline = DDIMPipeline.from_pretrained(repo_id, use_auth_token=HF_TOKEN) pipeline.scheduler = DDIMScheduler.from_config( repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN) elif scheduler_type == 'PNDM': pipeline = PNDMPipeline.from_pretrained(repo_id, use_auth_token=HF_TOKEN) pipeline.scheduler = PNDMScheduler.from_config( repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN) else: raise ValueError return pipeline def set_pipeline(self, model_name: str, scheduler_type: str) -> None: logger.info('--- set_pipeline ---') logger.info(f'{model_name=}, {scheduler_type=}') if model_name == self.model_name and scheduler_type == self.scheduler_type: logger.info('Skipping') logger.info('--- done ---') return self.model_name = model_name self.scheduler_type = scheduler_type self.pipeline = self._load_pipeline(model_name, scheduler_type) logger.info('--- done ---') def _download_all_models(self) -> None: for name in self.MODEL_NAMES: self._load_pipeline(name, 'DDPM') def generate(self, seed: int, num_steps: int, num_images: int = 1) -> list[PIL.Image.Image]: logger.info('--- generate ---') logger.info(f'{seed=}, {num_steps=}') torch.manual_seed(seed) if self.scheduler_type == 'DDPM': res = self.pipeline(batch_size=num_images, torch_device=self.device)['sample'] elif self.scheduler_type in ['DDIM', 'PNDM']: res = self.pipeline(batch_size=num_images, torch_device=self.device, num_inference_steps=num_steps)['sample'] else: raise ValueError logger.info('--- done ---') return res @staticmethod def postprocess(sample: torch.Tensor) -> np.ndarray: res = (sample / 2 + 0.5).clamp(0, 1) res = (res * 255).to(torch.uint8) res = res.cpu().permute(0, 2, 3, 1).numpy() return res @torch.inference_mode() def generate_with_video(self, seed: int, num_steps: int) -> tuple[PIL.Image.Image, str]: logger.info('--- generate_with_video ---') if self.scheduler_type == 'DDPM': num_steps = 1000 fps = 100 else: fps = 10 logger.info(f'{seed=}, {num_steps=}') model = self.pipeline.unet.to(self.device) scheduler = self.pipeline.scheduler scheduler.set_timesteps(num_inference_steps=num_steps) input_shape = (1, model.config.in_channels, model.config.sample_size, model.config.sample_size) torch.manual_seed(seed) out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) writer = imageio.get_writer(out_file.name, fps=fps) sample = torch.randn(input_shape).to(self.device) for t in tqdm.auto.tqdm(scheduler.timesteps): out = model(sample, t)['sample'] sample = scheduler.step(out, t, sample)['prev_sample'] res = self.postprocess(sample)[0] writer.append_data(res) writer.close() logger.info('--- done ---') return res, out_file.name def run(self, model_name: str, scheduler_type: str, num_steps: int, randomize_seed: bool, seed: int, visualize_denoising: bool ) -> tuple[PIL.Image.Image, int, str | None]: self.set_pipeline(model_name, scheduler_type) if scheduler_type == 'PNDM': num_steps = max(4, min(num_steps, 100)) if randomize_seed: seed = self.rng.randint(0, 100000) if not visualize_denoising: return self.generate(seed, num_steps)[0], seed, None else: res, filename = self.generate_with_video(seed, num_steps) return res, seed, filename @staticmethod def to_grid(images: list[PIL.Image.Image], ncols: int = 2) -> PIL.Image.Image: images = [np.asarray(image) for image in images] nrows = (len(images) + ncols - 1) // ncols h, w = images[0].shape[:2] if (d := nrows * ncols - len(images)) > 0: images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d grid = np.asarray(images).reshape(nrows, ncols, h, w, 3).transpose( 0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3) return PIL.Image.fromarray(grid) def run_simple(self) -> PIL.Image.Image: self.set_pipeline(self.MODEL_NAMES[0], 'DDIM') seed = self.rng.randint(0, 1000000) images = self.generate(seed, num_steps=10, num_images=4) return self.to_grid(images, 2)