File size: 1,037 Bytes
e8aa256 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput, is_flax_available
@dataclass
class StableDiffusionXLPipelineOutput(BaseOutput):
    """
    Result container returned by Stable Diffusion XL pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            The denoised images produced by the diffusion pipeline, given either as a list of PIL images of length
            `batch_size` or as a numpy array of shape `(batch_size, height, width, num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
if is_flax_available():
    # Only import flax (and define the Flax output class) when the availability check passes.
    import flax

    @flax.struct.dataclass
    class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
        """
        Result container returned by Flax Stable Diffusion XL pipelines.

        Args:
            images (`np.ndarray`):
                Images from the diffusion pipeline, as an array of shape `(batch_size, height, width, num_channels)`.
        """

        images: np.ndarray
|