| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Text and controller encoder blocks for WorldEngine modular pipeline.""" |
|
|
| import html |
| from typing import List, Set, Tuple, Union |
|
|
| import regex as re |
| import torch |
| from transformers import AutoTokenizer, UMT5EncoderModel |
|
|
| from diffusers.utils import is_ftfy_available, logging |
| from diffusers.modular_pipelines import ( |
| ModularPipelineBlocks, |
| ModularPipeline, |
| PipelineState, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| ComponentSpec, |
| ConfigSpec, |
| InputParam, |
| OutputParam, |
| ) |
|
|
| if is_ftfy_available(): |
| import ftfy |
|
|
|
|
# Module-level logger, following the diffusers convention of one logger per module.
logger = logging.get_logger(__name__)
|
|
|
|
def basic_clean(text):
    """Normalize mojibake and HTML entities in *text*, then strip surrounding whitespace.

    ``ftfy`` is an optional dependency (see the guarded import at the top of
    this file), so only use it when it is actually available — previously this
    raised ``NameError`` when ftfy was not installed.
    """
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    # Unescape twice to handle doubly-encoded entities such as "&amp;amp;".
    text = html.unescape(html.unescape(text))
    return text.strip()
|
|
|
|
def whitespace_clean(text):
    """Collapse every run of whitespace in *text* to a single space and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()
|
|
|
|
def prompt_clean(text):
    """Apply the full prompt-cleaning pipeline: encoding/HTML fixes, then whitespace normalization."""
    return whitespace_clean(basic_clean(text))
|
|
|
|
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
    """Encodes text prompts using UMT5-XL for conditioning.

    Consumes a raw ``prompt`` (string or list of strings) and produces
    ``prompt_embeds`` together with a boolean ``prompt_pad_mask`` that is
    ``True`` at padded token positions. Callers may instead supply
    pre-computed ``prompt_embeds`` to skip encoding entirely.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Text Encoder step that generates text embeddings to guide frame generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # The tokenizer converts raw text to ids/masks; the UMT5 encoder
        # turns those into the embedding sequence used for conditioning.
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the frame generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-computed text embeddings",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide frame generation",
            ),
            OutputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Padding mask for prompt embeddings",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Raise ``ValueError`` when ``prompt`` is set but is neither a str nor a list."""
        if block_state.prompt is not None and not isinstance(
            block_state.prompt, (str, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
            )

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_sequence_length: int = 512,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize and encode ``prompt`` with the UMT5 text encoder.

        Args:
            components: Pipeline components exposing ``tokenizer`` and ``text_encoder``.
            prompt: A single prompt string or a list of prompt strings.
            device: Device on which encoding runs.
            max_sequence_length: Fixed tokenizer padding/truncation length.

        Returns:
            ``(prompt_embeds, prompt_pad_mask)`` — embeddings with padded
            positions zeroed, and a boolean mask that is ``True`` where the
            tokenizer padded.
        """
        dtype = components.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = components.text_encoder(
            text_input_ids, attention_mask
        ).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # Zero embeddings at padded positions so downstream consumers cannot
        # pick up signal from padding tokens.
        prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
            prompt_embeds
        )

        # Boolean padding mask: True exactly where attention_mask == 0.
        prompt_pad_mask = attention_mask.eq(0)

        return prompt_embeds, prompt_pad_mask

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device
        if block_state.prompt_embeds is None:
            # Fall back to a generic default when no (or an empty) prompt is given.
            block_state.prompt = block_state.prompt or "An explorable world"
            (
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
            ) = self.encode_prompt(components, block_state.prompt, device)
            block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        if block_state.prompt_pad_mask is None:
            # Pre-computed embeddings arrived without a mask: treat every
            # position as a real (non-padded) token.
            block_state.prompt_pad_mask = torch.zeros(
                block_state.prompt_embeds.shape[:2],
                dtype=torch.bool,
                device=device,
            )

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
    """Encodes controller inputs (mouse + buttons + scroll) for conditioning.

    Produces three single-step tensors of shape ``(1, 1, C)``: a one-hot
    button vector (``C = n_buttons``), a mouse-velocity pair (``C = 2``),
    and a scroll-direction sign (``C = 1``). When the caller already
    supplies these tensors they are reused as buffers and updated in place.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        # Annotation fixed: this property returns ConfigSpec entries,
        # not ComponentSpec.
        return [ConfigSpec("n_buttons", 256)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "button",
                type_hint=Set[int],
                default=set(),
                description="Set of pressed button IDs",
            ),
            InputParam(
                "mouse",
                type_hint=Tuple[float, float],
                default=(0.0, 0.0),
                description="Mouse velocity (x, y)",
            ),
            InputParam(
                "scroll",
                type_hint=int,
                default=0,
                description="Scroll wheel direction (-1, 0, 1)",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            OutputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            OutputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        # NOTE(review): `transformer` is not listed in expected_components —
        # this assumes the assembled pipeline provides it; confirm upstream.
        dtype = components.transformer.dtype

        n_buttons = components.config.n_buttons

        # Allocate the one-hot button buffer on first use.
        if block_state.button_tensor is None:
            block_state.button_tensor = torch.zeros(
                (1, 1, n_buttons), device=device, dtype=dtype
            )

        # Rebuild the one-hot encoding in place; out-of-range IDs are ignored.
        block_state.button_tensor.zero_()
        if block_state.button:
            for btn_id in block_state.button:
                if 0 <= btn_id < n_buttons:
                    block_state.button_tensor[0, 0, btn_id] = 1.0

        # Allocate the mouse-velocity buffer on first use.
        if block_state.mouse_tensor is None:
            block_state.mouse_tensor = torch.zeros(
                (1, 1, 2), device=device, dtype=dtype
            )

        # Write the (x, y) velocity, defaulting to a stationary mouse.
        mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
        block_state.mouse_tensor[0, 0, 0] = mouse[0]
        block_state.mouse_tensor[0, 0, 1] = mouse[1]

        # Allocate the scroll buffer on first use.
        if block_state.scroll_tensor is None:
            block_state.scroll_tensor = torch.zeros(
                (1, 1, 1), device=device, dtype=dtype
            )

        # Store only the scroll sign (-1.0, 0.0, or 1.0), not its magnitude.
        scroll = block_state.scroll if block_state.scroll is not None else 0
        block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)

        self.set_block_state(state, block_state)
        return components, state
|
|