from typing import List, Optional

import torch
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    UMT5EncoderModel,
)

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.image_processor import PipelineImageInput
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipeline,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from diffusers.modular_pipelines.wan.encoders import WanTextEncoderStep
from diffusers.video_processor import VideoProcessor


class ChronoEditImageEncoderStep(ModularPipelineBlocks):
    """Encodes the input image with a CLIP vision model to condition the denoiser."""

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("image_processor", CLIPImageProcessor),
            ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image", type_hint=PipelineImageInput)]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_embeds",
                type_hint=torch.Tensor,
                description="Image embeddings to use as conditions during the denoising process.",
            )
        ]

    @staticmethod
    def encode_image(components, image: PipelineImageInput, device: Optional[torch.device] = None):
        device = device or components.image_encoder.device
        image = components.image_processor(images=image, return_tensors="pt").to(device)
        image_embeds = components.image_encoder(**image, output_hidden_states=True)
        # Use the penultimate hidden state as the conditioning signal rather than
        # the pooled/projected output.
        return image_embeds.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.image_embeds = self.encode_image(components, block_state.image, components._execution_device)
        self.set_block_state(state, block_state)
        return components, state


class ChronoEditProcessImageStep(ModularPipelineBlocks):
    """Preprocesses the input image to the target resolution and broadcasts the
    optional image embeddings across the batch."""

    model_name = "chronoedit"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", type_hint=PipelineImageInput),
            InputParam("image_embeds", type_hint=torch.Tensor, required=False),
            InputParam("batch_size", type_hint=int, required=False),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("processed_image", type_hint=PipelineImageInput),
            OutputParam("image_embeds", type_hint=torch.Tensor),
        ]

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            )
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        image = block_state.image
        device = components._execution_device

        # Resize and normalize the image, then move it to the execution device in
        # the transformer's working dtype.
        block_state.processed_image = components.video_processor.preprocess(
            image, height=block_state.height, width=block_state.width
        ).to(device, dtype=torch.bfloat16)

        if block_state.image_embeds is not None:
            image_embeds = block_state.image_embeds
            # `batch_size` is an optional input; default to 1 when it is not provided.
            batch_size = block_state.batch_size or 1
            block_state.image_embeds = image_embeds.repeat(batch_size, 1, 1).to(torch.bfloat16)

        self.set_block_state(state, block_state)
        return components, state


class ChronoEditTextEncoderStep(WanTextEncoderStep):
    """Text-encoding step; reuses Wan's UMT5 prompt encoding with a classifier-free
    guidance guider."""

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        # Unconditional embeddings are only needed when the guider runs more than
        # one condition (i.e. classifier-free guidance is active).
        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device
        block_state.negative_prompt_embeds = None

        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
        )

        self.set_block_state(state, block_state)
        return components, state
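

# Usage sketch (illustrative): the three blocks above can be composed into a
# single runnable unit. This is a minimal sketch, assuming diffusers'
# `SequentialPipelineBlocks.from_blocks_dict` API; a full ChronoEdit pipeline
# would also need prepare-latents, denoise, and decode blocks, which are not
# defined in this module.
if __name__ == "__main__":
    from diffusers.modular_pipelines import SequentialPipelineBlocks

    blocks = SequentialPipelineBlocks.from_blocks_dict(
        {
            "text_encoder": ChronoEditTextEncoderStep(),
            "image_encoder": ChronoEditImageEncoderStep(),
            "process_image": ChronoEditProcessImageStep(),
        }
    )
    # Printing a composed block lists its expected components, inputs, and
    # intermediate outputs, which helps when wiring up the remaining steps.
    print(blocks)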