from __future__ import annotations

import numbers
from typing import Sequence, Union

import torch

from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput


class BitDanceImageNetPipeline(DiffusionPipeline):
    """Class-conditional image generation pipeline around a BitDance transformer.

    The heavy lifting is delegated to ``transformer.sample``; this pipeline only
    builds the batch of class ids and converts the sampled tensor into PIL
    images or a NumPy array.

    Args:
        transformer: Model exposing ``sample(class_ids=..., sample_steps=...,
            cfg_scale=..., chunk_size=...)`` that returns an image tensor.
        autoencoder: Optional module registered alongside the transformer.
            It is not used by ``__call__``.
    """

    # Only the transformer is listed for CPU offload sequencing; the
    # autoencoder is never invoked by ``__call__``.
    model_cpu_offload_seq = "transformer"

    def __init__(self, transformer, autoencoder=None):
        super().__init__()
        self.register_modules(transformer=transformer, autoencoder=autoencoder)

    @staticmethod
    def _build_class_ids(
        class_labels: Union[int, Sequence[int], torch.Tensor],
        num_images_per_label: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return a flat ``torch.long`` tensor of class ids on ``device``.

        Each label is repeated ``num_images_per_label`` times, with copies of
        the same label kept adjacent (``[a, b] -> [a, a, b, b]``).

        Raises:
            ValueError: If ``num_images_per_label < 1`` or no labels are given.
        """
        if num_images_per_label < 1:
            raise ValueError(
                f"num_images_per_label must be >= 1, got {num_images_per_label}"
            )

        # numbers.Integral covers Python ints as well as NumPy integer scalars,
        # which are not instances of ``int`` and would otherwise hit list().
        if isinstance(class_labels, numbers.Integral):
            labels = [int(class_labels)]
        elif isinstance(class_labels, torch.Tensor):
            # Flatten so 0-d and multi-dimensional tensors are both accepted.
            labels = class_labels.reshape(-1).tolist()
        else:
            labels = list(class_labels)

        if not labels:
            raise ValueError("class_labels must contain at least one class id")

        class_ids = torch.tensor(labels, device=device, dtype=torch.long)
        if num_images_per_label > 1:
            class_ids = class_ids.repeat_interleave(num_images_per_label)
        return class_ids

    def _to_output_images(self, images: torch.Tensor, output_type: str):
        """Convert the sampled tensor to PIL images or a float32 NumPy array.

        NOTE(review): assumes ``transformer.sample`` returns a 4-D
        channel-first ``(B, C, H, W)`` tensor with values in ``[-1, 1]``;
        confirm against the transformer implementation.
        """
        # Map [-1, 1] -> [0, 1] and clip any overshoot.
        images = (images / 2 + 0.5).clamp(0, 1)
        # Channel-last float32 array, the layout ``numpy_to_pil`` consumes.
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "pil":
            images = self.numpy_to_pil(images)
        return images

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, Sequence[int]] = 0,
        num_images_per_label: int = 1,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, tuple]:
        """Generate images for the given class labels.

        Args:
            class_labels: A single class id or a sequence of class ids. An
                integer tensor (any shape) is also accepted and flattened.
            num_images_per_label: Number of images to generate per label;
                must be at least 1.
            sample_steps: Forwarded unchanged to ``transformer.sample``.
            cfg_scale: Forwarded unchanged to ``transformer.sample``.
            chunk_size: Forwarded unchanged to ``transformer.sample``.
            output_type: ``"pil"`` returns a list of PIL images; any other
                value returns a float32 ``(B, H, W, C)`` NumPy array in
                ``[0, 1]``.
            return_dict: If True, wrap the images in ``ImagePipelineOutput``;
                otherwise return a one-element tuple.

        Returns:
            ``ImagePipelineOutput`` or ``(images,)``, depending on
            ``return_dict``.

        Raises:
            ValueError: If ``num_images_per_label < 1`` or ``class_labels``
                is empty.
        """
        device = self._execution_device

        class_ids = self._build_class_ids(class_labels, num_images_per_label, device)

        images = self.transformer.sample(
            class_ids=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

        images = self._to_output_images(images, output_type)

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)