| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Decoder blocks for WorldEngine modular pipeline.""" |
| |
|
| | from typing import List, Union |
| |
|
| | import numpy as np |
| | import PIL.Image |
| | import torch |
| |
|
| | from diffusers import AutoModel |
| | from diffusers.configuration_utils import FrozenDict |
| | from diffusers.image_processor import VaeImageProcessor |
| | from diffusers.utils import logging |
| | from diffusers.modular_pipelines import ( |
| | ModularPipelineBlocks, |
| | ModularPipeline, |
| | PipelineState, |
| | ) |
| | from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| | ComponentSpec, |
| | InputParam, |
| | OutputParam, |
| | ) |
| |
|
# Module-level logger, named after this module per diffusers convention.
logger = logging.get_logger(__name__)
| |
|
| |
|
class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to an RGB image using the VAE.

    Final stage of the WorldEngine modular pipeline: takes the denoised
    latent tensor produced by the denoise loop and converts it to the
    requested output format (``pil``, ``np``, ``pt``, or passthrough
    ``latent``).
    """

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs: the VAE decoder and an image processor
        for tensor -> PIL/numpy postprocessing."""
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        # Spatial downsampling factor of the VAE; used by the
                        # processor for size bookkeeping.
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": True,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Denoised latent tensor [1, 1, C, H, W]",
            ),
            InputParam(
                "output_type",
                default="pil",
                description="The output format for the generated images (pil, latent, pt, or np)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "images",
                type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
                description="Decoded RGB image in requested output format",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Decode ``latents`` from the pipeline state into ``images``.

        Args:
            components: Modular pipeline providing ``vae`` and
                ``image_processor``.
            state: Pipeline state carrying ``latents`` and ``output_type``.

        Returns:
            ``(components, state)`` with ``state.images`` populated and
            ``state.latents`` cleared.
        """
        block_state = self.get_block_state(state)
        latents = block_state.latents
        output_type = block_state.output_type or "pil"

        if output_type == "latent":
            # Caller wants raw latents; skip VAE decoding entirely.
            block_state.images = latents
        else:
            # Drop the singleton frame dimension: [1, 1, C, H, W] -> [1, C, H, W].
            # NOTE(review): assumes vae.decode returns a plain tensor (the
            # original code indexed it as one) — confirm against the VAE class.
            image = components.vae.decode(latents.squeeze(1))

            # Bug fix: use the declared image_processor component for output
            # conversion. The previous manual path called PIL.Image.fromarray
            # on a float CHW batched array (which raises for the default
            # "pil" output) and returned an unscaled, unpermuted array for
            # "np". VaeImageProcessor.postprocess handles pt/np/pil uniformly.
            block_state.images = components.image_processor.postprocess(
                image, output_type=output_type
            )

        # Release the latents reference now that decoding is done, so the
        # large tensor can be freed before any downstream allocation.
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state
| |
|