import math
import sys
import time
from typing import Callable, List, Optional, Union

import torch
from einops import rearrange, repeat
from torch import Tensor

from ..models.model import Flux
from ..modules.conditioner import HFEmbedder
from ..modules.image_embedders import ReduxImageEncoder

# Prompt prefix marking the target scene; prepended to every prompt in the prepare functions
TGT_PREFIX = "[TARGET-SCENE]"

# -------------------------------------------------------------------------
# Progress bar
# -------------------------------------------------------------------------

def print_progress_bar(iteration, total, prefix='', suffix='', length=30, fill='█'):
    """
    Simple progress bar for console output, with elapsed and estimated remaining time.

    Args:
        iteration: Current iteration (Int)
        total: Total iterations (Int)
        prefix: Prefix string (Str)
        suffix: Suffix string (Str)
        length: Bar length (Int)
        fill: Bar fill character (Str)
    """
    # Static variable to store start time (reset when iteration == 0)
    if not hasattr(print_progress_bar, "_start_time") or iteration == 0:
        print_progress_bar._start_time = time.time()

    percent = f"{100 * (iteration / float(total)):.1f}%"
    filled_length = int(length * iteration // total)
    bar = fill * filled_length + '-' * (length - filled_length)

    elapsed = time.time() - print_progress_bar._start_time
    elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))
    if iteration > 0:
        avg_time_per_iter = elapsed / iteration
        remaining = avg_time_per_iter * (total - iteration)
    else:
        remaining = 0
    remaining_str = time.strftime("%H:%M:%S", time.gmtime(remaining))
    time_info = f"Elapsed: {elapsed_str} | ETA: {remaining_str}"

    sys.stdout.write(f'\r{prefix} |{bar}| {percent} {suffix} {time_info}')
    sys.stdout.flush()
    if iteration == total:
        sys.stdout.write('\n')
        sys.stdout.flush()

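# Usage sketch (illustrative step count only): the bar is driven by calling the
# function once per step with a fixed total, exactly as `denoise` does below:
#
#     for step in range(num_steps):
#         ...  # one denoising step
#         print_progress_bar(step, num_steps, prefix='Denoising:', suffix=f'Step {step+1}/{num_steps}')
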
# -------------------------------------------------------------------------
# 1) sampling func
# -------------------------------------------------------------------------

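# `unpack` reverses the 2x2 patch packing used throughout this file: a token
# sequence of shape (b, h/16 * w/16, 64) is rearranged back into a latent
# image of shape (b, 16, h/8, w/8), where `height` and `width` are pixel sizes
# (16 channels matches the 64-dim tokens produced by `get_noise`).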
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )

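# The next three helpers build the resolution-dependent timestep schedule for
# the Flux model: `get_lin_function` linearly maps the packed token count to a
# shift parameter mu, and `time_shift` warps uniform timesteps toward 1 so that
# larger images spend more steps at high noise. Worked example (values follow
# directly from the defaults below): a 1024x1024 image gives
# image_seq_len = (1024 / 16) * (1024 / 16) = 4096 tokens, so mu = max_shift = 1.15
# and the schedule is warped by a factor of exp(1.15) ~= 3.16.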
def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
):
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()

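# Usage sketch (assumed step count; the defaults above are the ones this file
# ships with): for a 1024x1024 generation the packed sequence length is
# (1024 // 16) * (1024 // 16) = 4096, and
#
#     timesteps = get_schedule(num_steps=28, image_seq_len=4096)
#
# returns a list of 29 floats running from 1.0 down to 0.0, denser near 1.0
# because of the shift.
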
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    noise = torch.cat(
        [
            torch.randn(
                1,
                16,
                # allow for packing
                2 * math.ceil(height / 16),
                2 * math.ceil(width / 16),
                device=device,
                dtype=dtype,
                generator=torch.Generator(device=device).manual_seed(seed + i),
            )
            for i in range(num_samples)
        ],
        dim=0,
    )
    return noise

# -------------------------------------------------------------------------
# prepare input func
# -------------------------------------------------------------------------

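# All prepare* helpers below share the packing convention reversed by unpack():
# the 16-channel VAE latent is pixel-shuffled into non-overlapping 2x2 patches,
# giving a token sequence of shape (b, h/16 * w/16, 64), plus a 3-channel
# positional id per token (0, row, col).
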
def _get_batch_size_and_prompt(prompt, img_shape):
    """
    Helper to unpack the latent image shape and flag whether a prompt was given.
    """
    bs, c, h, w = img_shape
    is_prompt_none = prompt is None
    return bs, prompt, is_prompt_none, h, w

def _make_img_ids(bs, h, w, device=None, dtype=None):
    """
    Helper to create per-token positional ids of shape (bs, h/2 * w/2, 3):
    channel 0 is zero, channels 1 and 2 hold the patch row and column index.
    """
    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] = torch.arange(h // 2)[:, None]
    img_ids[..., 2] = torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    return img_ids

def prepare(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    prompt: Union[str, List[str], None],
    num_images_per_prompt: int = 1,
):
    """
    Prepare the regular input for the Diffusion Transformer.
    """
    img_bs, prompt, is_prompt_none, h, w = _get_batch_size_and_prompt(prompt, img.shape)
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img_ids = _make_img_ids(img_bs, h, w, device=img.device, dtype=img.dtype)

    if isinstance(prompt, str):
        prompt = [prompt]
    # When no prompt is given, fall back to one null embedding per final image
    # (img_bs // num_images_per_prompt before the repeat below); len(None) would raise.
    txt_bs = len(prompt) if not is_prompt_none else max(img_bs // num_images_per_prompt, 1)

    if not is_prompt_none:
        prompt = [TGT_PREFIX + p for p in prompt]
        txt = t5(prompt)
        txt_ids = torch.zeros(txt_bs, txt.shape[1], 3, device=img.device, dtype=img.dtype)
        txt_vec = clip(prompt)
    else:
        # Null conditioning: zero tensors with the expected T5/CLIP embedding widths
        txt = torch.zeros(txt_bs, 512, 4096, device=img.device, dtype=img.dtype)
        txt_ids = torch.zeros(txt_bs, 512, 3, device=img.device, dtype=img.dtype)
        txt_vec = torch.zeros(txt_bs, 768, device=img.device, dtype=img.dtype)

    if num_images_per_prompt > 1:
        txt = txt.repeat_interleave(num_images_per_prompt, dim=0)
        txt_ids = txt_ids.repeat_interleave(num_images_per_prompt, dim=0)
        txt_vec = txt_vec.repeat_interleave(num_images_per_prompt, dim=0)

    return {
        "img": img.to(img.device),
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "txt_vec": txt_vec.to(img.device),
    }

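# Usage sketch (hypothetical embedder handles `t5` and `clip`; shapes assume a
# 1024x1024 image and the 16-channel latent produced by get_noise):
#
#     noise = get_noise(1, 1024, 1024, device, torch.bfloat16, seed=0)   # (1, 16, 128, 128)
#     inputs = prepare(t5, clip, noise, prompt="a red chair")
#     # inputs["img"]:     (1, 4096, 64)  packed latent tokens
#     # inputs["img_ids"]: (1, 4096, 3)   positional ids
#     # inputs["txt"], inputs["txt_ids"], inputs["txt_vec"]: T5 tokens, ids, CLIP vector
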
def prepare_with_redux(
    t5: HFEmbedder,
    clip: HFEmbedder,
    image_encoder: ReduxImageEncoder,
    img: Tensor,
    img_ip: Tensor,
    prompt: Union[str, List[str], None],
    num_images_per_prompt: int = 1,
):
    """
    Same as `prepare`, but appends Redux image embeddings of `img_ip` to the T5 text sequence.
    """
    img_bs, prompt, is_prompt_none, h, w = _get_batch_size_and_prompt(prompt, img.shape)
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img_ids = _make_img_ids(img_bs, h, w, device=img.device, dtype=img.dtype)

    if isinstance(prompt, str):
        prompt = [prompt]
    # When no prompt is given, fall back to one null embedding per final image
    # (img_bs // num_images_per_prompt before the repeat below); len(None) would raise.
    txt_bs = len(prompt) if not is_prompt_none else max(img_bs // num_images_per_prompt, 1)

    if not is_prompt_none:
        prompt = [TGT_PREFIX + p for p in prompt]
        txt = torch.cat((t5(prompt), image_encoder(img_ip)), dim=1)
        txt_ids = torch.zeros(txt_bs, txt.shape[1], 3, device=img.device, dtype=img.dtype)
        txt_vec = clip(prompt)
    else:
        txt = torch.zeros(txt_bs, 512, 4096, device=img.device, dtype=img.dtype)
        txt_ids = torch.zeros(txt_bs, 512, 3, device=img.device, dtype=img.dtype)
        txt_vec = torch.zeros(txt_bs, 768, device=img.device, dtype=img.dtype)

    if num_images_per_prompt > 1:
        txt = txt.repeat_interleave(num_images_per_prompt, dim=0)
        txt_ids = txt_ids.repeat_interleave(num_images_per_prompt, dim=0)
        txt_vec = txt_vec.repeat_interleave(num_images_per_prompt, dim=0)

    return {
        "img": img.to(img.device),
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "txt_vec": txt_vec.to(img.device),
    }

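# `prepare_image_cond` builds the spatial conditioning for reference-guided
# inpainting: the reference image and the masked target image are encoded,
# concatenated side by side along width, and packed into tokens; the binary
# mask is packed both at 16x16 pixel resolution (appended to `img_cond`) and
# at per-token resolution (`mask_cond`, used for background preservation in
# `denoise`).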
def prepare_image_cond(
    ae,
    img_ref,
    img_target,
    mask_target,
    dtype,
    device,
    num_images_per_prompt: int = 1,
):
    batch_size, _, _, _ = img_target.shape

    # Apply mask to target image
    mask_targeted_img = img_target * mask_target
    if mask_target.shape[1] == 3:
        mask_target = mask_target[:, 0:1, :, :]

    with torch.no_grad():
        autoencoder_dtype = next(ae.parameters()).dtype
        # Encode masked target image to latent space
        mask_targeted_latent = ae.encode(mask_targeted_img.to(autoencoder_dtype)).to(dtype)
        # Encode reference image to latent space
        reference_latent = ae.encode(img_ref.to(autoencoder_dtype)).to(dtype)

    # Repeat reference latent if batch size > 1
    if reference_latent.shape[0] == 1 and batch_size > 1:
        reference_latent = repeat(reference_latent, "1 ... -> bs ...", bs=batch_size)

    # Concatenate reference and target latents along width
    latent_concat = torch.cat((reference_latent, mask_targeted_latent), dim=-1)
    # Pack latents into 2x2 patches
    latent_packed = rearrange(latent_concat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Create reference mask (all ones)
    reference_mask = torch.ones_like(img_ref)
    if reference_mask.shape[1] == 3:
        reference_mask = reference_mask[:, 0:1, :, :]
    # Concatenate reference and target masks along width
    mask_concat = torch.cat((reference_mask, mask_target), dim=-1)

    # Pack masks into 16x16 patches for image conditioning
    mask_16x16 = rearrange(mask_concat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=16, pw=16)
    # Downsample masks to the token grid (half the latent resolution)
    mask_latent = torch.nn.functional.interpolate(
        mask_concat,
        size=(latent_concat.shape[2] // 2, latent_concat.shape[3] // 2),
        mode='nearest',
    )
    # Flatten the downsampled masks into one value per token for mask conditioning
    mask_cond = rearrange(mask_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=1, pw=1)

    # Combine packed latents and masks for image conditioning
    img_cond = torch.cat((latent_packed, mask_16x16), dim=-1)

    if num_images_per_prompt > 1:
        img_cond = img_cond.repeat_interleave(num_images_per_prompt, dim=0)
        mask_cond = mask_cond.repeat_interleave(num_images_per_prompt, dim=0)
        latent_packed = latent_packed.repeat_interleave(num_images_per_prompt, dim=0)

    return {
        "img_cond": img_cond.to(device).to(dtype),
        "mask_cond": mask_cond.to(device).to(dtype),
        "img_latent": latent_packed.to(device).to(dtype),
    }

# -------------------------------------------------------------------------
# 2) denoise func
# -------------------------------------------------------------------------

def is_even_step(step: int) -> bool:
    """Check whether the current step index is even."""
    return step % 2 == 0

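# `denoise` runs a simple Euler integration of the predicted velocity field:
# at each step the model predicts v and the latent is updated with
# img <- img + (t_prev - t_curr) * v. Two optional pieces are layered on top:
#   * true classifier-free guidance: pred = neg_pred + true_gs * (pred - neg_pred),
#     applied once true_gs > 1, a negative prompt is given, and
#     step >= timestep_to_start_cfg;
#   * background preservation: in regions where mask_cond == 1 the predicted
#     velocity is replaced by the ground-truth velocity v_gt = img_start - img_latent
#     during the first background_blend_threshold fraction of steps (on every
#     such step, or only on even steps when the "progressive" variant is enabled).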
def denoise(
    model,
    img,
    img_ids,
    txt,
    txt_ids,
    txt_vec,
    timesteps,
    guidance: float = 4.0,
    img_cond: Tensor = None,
    mask_cond: Tensor = None,
    img_latent: Tensor = None,
    cond_w_regions: Optional[Union[List[int], int]] = None,
    mask_type_ids: Optional[Union[Tensor, int]] = None,
    height: int = 1024,
    width: int = 1024,
    use_background_preservation: bool = False,
    use_progressive_background_preservation: bool = True,
    background_blend_threshold: float = 0.8,
    true_gs: float = 1,
    timestep_to_start_cfg: int = 0,
    neg_txt: Tensor = None,
    neg_txt_ids: Tensor = None,
    neg_txt_vec: Tensor = None,
    show_progress: bool = False,
    use_flash_attention: bool = False,
    gradio_progress=None,
):
    do_true_cfg = true_gs > 1 and neg_txt is not None
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    # Ground-truth velocity between the starting (noisy) latent and the clean latent
    v_gt = img - img_latent if img_latent is not None else None
    num_steps = len(timesteps[:-1])
    model_dtype = next(model.parameters()).dtype

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        if show_progress:
            print_progress_bar(step, num_steps, prefix='Denoising:', suffix=f'Step {step+1}/{num_steps}')

        # Update Gradio progress if available
        if gradio_progress is not None:
            progress_value = step / num_steps
            gradio_progress(progress_value, desc=f"Denoising step {step+1}/{num_steps}")

        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        pred = model(
            img=torch.cat((img.to(model_dtype), img_cond.to(model_dtype)), dim=-1) if img_cond is not None else img.to(model_dtype),
            img_ids=img_ids.to(model_dtype),
            txt=txt.to(model_dtype),
            txt_ids=txt_ids.to(model_dtype),
            txt_vec=txt_vec.to(model_dtype),
            timesteps=t_vec.to(model_dtype),
            guidance=guidance_vec.to(model_dtype),
            cond_w_regions=cond_w_regions,
            mask_type_ids=mask_type_ids,
            height=height,
            width=width,
            use_flash_attention=use_flash_attention,
        )

        if do_true_cfg and step >= timestep_to_start_cfg:
            neg_pred = model(
                img=torch.cat((img.to(model_dtype), img_cond.to(model_dtype)), dim=-1) if img_cond is not None else img.to(model_dtype),
                img_ids=img_ids.to(model_dtype),
                txt=neg_txt.to(model_dtype),
                txt_ids=neg_txt_ids.to(model_dtype),
                txt_vec=neg_txt_vec.to(model_dtype),
                timesteps=t_vec.to(model_dtype),
                guidance=guidance_vec.to(model_dtype),
                cond_w_regions=cond_w_regions,
                mask_type_ids=mask_type_ids,
                height=height,
                width=width,
                use_flash_attention=use_flash_attention,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        if use_background_preservation:
            is_early_phase = step <= num_steps * background_blend_threshold
            if is_early_phase:
                if use_progressive_background_preservation:
                    if is_even_step(step):
                        # Blend in the ground-truth velocity on even steps of the early phase
                        masked_latent = pred * (1 - mask_cond) + v_gt * mask_cond
                    else:
                        # Use the prediction directly on odd steps
                        masked_latent = pred
                else:
                    masked_latent = pred * (1 - mask_cond) + v_gt * mask_cond
            else:
                # Use the prediction directly in the late phase
                masked_latent = pred
            img = img + (t_prev - t_curr) * masked_latent
        else:
            img = img + (t_prev - t_curr) * pred

    if show_progress:
        print_progress_bar(num_steps, num_steps, prefix='Denoising:', suffix='Complete')
    return img

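# End-to-end sketch of how these pieces fit together (hypothetical model,
# autoencoder, and embedder handles `model`, `ae`, `t5`, `clip`, plus image
# tensors `img_ref`, `img_target`, `mask_target`; shapes assume a 1024x1024
# target placed next to a 1024x1024 reference, i.e. a 1024x2048 latent canvas):
#
#     noise = get_noise(1, 1024, 2 * 1024, device, dtype, seed=0)  # reference + target side by side
#     inputs = prepare(t5, clip, noise, prompt="replace the sofa")
#     cond = prepare_image_cond(ae, img_ref, img_target, mask_target, dtype, device)
#     timesteps = get_schedule(28, inputs["img"].shape[1])
#     latent = denoise(model, inputs["img"], inputs["img_ids"], inputs["txt"],
#                      inputs["txt_ids"], inputs["txt_vec"], timesteps,
#                      img_cond=cond["img_cond"], mask_cond=cond["mask_cond"],
#                      img_latent=cond["img_latent"],
#                      use_background_preservation=True)
#     image = ae.decode(unpack(latent.float(), 1024, 2 * 1024))  # assumes the autoencoder exposes `decode`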