"""Image-to-image processor for LightDiffusion-Next.
This processor handles image-to-image generation and upscaling
using the Ultimate SD Upscale approach.
"""
import logging
from typing import TYPE_CHECKING, Any, Optional, Callable
import numpy as np
import torch
from PIL import Image
if TYPE_CHECKING:
    from src.Core.PipelineContext import PipelineContext
    from src.Core.AbstractModel import AbstractModel


class Img2Img:
    """Image-to-image generation and upscaling processor.

    Uses Ultimate SD Upscale for high-quality image transformation
    and super-resolution.
    """

    # Default settings
    DEFAULT_UPSCALE_BY = 2
    DEFAULT_STEPS = 8
    DEFAULT_CFG = 6
    DEFAULT_DENOISE = 0.3
    DEFAULT_SCHEDULER = "karras"
    DEFAULT_TILE_WIDTH = 512
    DEFAULT_TILE_HEIGHT = 512
    DEFAULT_MASK_BLUR = 16
    DEFAULT_TILE_PADDING = 32
    DEFAULT_UPSCALER = "RealESRGAN_x4plus.pth"

    @classmethod
    def apply(
        cls,
        ctx: "PipelineContext",
        model: "AbstractModel",
        positive: Any,
        negative: Any,
        image_path: Optional[str] = None,
        image_tensor: Optional[torch.Tensor] = None,
        upscale_by: Optional[float] = None,
        denoise: Optional[float] = None,
        callback: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Apply image-to-image transformation.

        Args:
            ctx: Pipeline context with configuration
            model: The loaded model instance
            positive: Positive conditioning
            negative: Negative conditioning
            image_path: Path to input image (used if image_tensor not provided)
            image_tensor: Input image tensor [B, H, W, C] or [H, W, C]
            upscale_by: Upscale factor (default: 2)
            denoise: Denoising strength (default: 0.3)
            callback: Optional callback for live previews

        Returns:
            Processed image tensor
        """
        logger = logging.getLogger(__name__)

        # Determine input source
        if image_tensor is None:
            source_path = image_path or ctx.features.img2img_image
            if source_path is None:
                raise ValueError("No input image provided for img2img")
            # Load image from path
            image_tensor = cls._load_image(source_path)

        # Normalize to [B, H, W, C] so the width lookup below is correct
        # even when the caller passes an unbatched [H, W, C] tensor
        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)

        # Determine upscale factor from context dimensions if not provided
        if upscale_by is None:
            input_w = image_tensor.shape[2]  # [B, H, W, C]
            target_w = ctx.generation.width
            if target_w and target_w != input_w:
                upscale_by = target_w / input_w
                logger.info(
                    f"Img2Img: calculated upscale_by={upscale_by:.2f} "
                    f"from target width {target_w}"
                )
            else:
                upscale_by = cls.DEFAULT_UPSCALE_BY

        # Explicit None check: `denoise or DEFAULT` would silently discard
        # a legitimate denoise of 0.0
        if denoise is None:
            denoise = cls.DEFAULT_DENOISE

        # Determine model flags
        is_flux = getattr(model.capabilities, "is_flux", False)
        is_flux2 = getattr(model.capabilities, "is_flux2", False)

        # Adjust CFG for Flux models
        img2img_cfg = cls.DEFAULT_CFG
        if is_flux or is_flux2:
            img2img_cfg = 1.0
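            # Flux-family checkpoints are typically guidance-distilled, so
            # classifier-free guidance is disabled by pinning CFG to 1.0
            # (a general property of these models, not verified against
            # this specific loader)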

        try:
            # Import required modules
            from src.UltimateSDUpscale import UltimateSDUpscale, USDU_upscaler

            # Load upscaler model
            upscale_loader = USDU_upscaler.UpscaleModelLoader()
            upscale_model = upscale_loader.load_model(cls.DEFAULT_UPSCALER)[0]

            # Initialize Ultimate SD Upscale
            upscaler = UltimateSDUpscale.UltimateSDUpscale()

            # Get current seed from context
            current_seed = ctx.seed

            logger.info(
                f"Img2Img: processing with {upscale_by}x upscale, denoise={denoise}"
            )
            # Run upscaling
            result = upscaler.upscale(
                upscale_by=upscale_by,
                seed=current_seed,
                steps=cls.DEFAULT_STEPS,
                cfg=img2img_cfg,
                sampler_name=ctx.sampling.sampler,
                scheduler=cls.DEFAULT_SCHEDULER,
                denoise=denoise,
                mode_type="Linear",
                tile_width=cls.DEFAULT_TILE_WIDTH,
                tile_height=cls.DEFAULT_TILE_HEIGHT,
                mask_blur=cls.DEFAULT_MASK_BLUR,
                tile_padding=cls.DEFAULT_TILE_PADDING,
                seam_fix_mode="Half Tile",
                seam_fix_denoise=0.2,
                seam_fix_width=64,
                seam_fix_mask_blur=16,
                seam_fix_padding=32,
                force_uniform_tiles="enable",
                image=image_tensor,
                model=model.model,
                positive=positive,
                negative=negative,
                vae=model.vae,
                upscale_model=upscale_model,
                pipeline=True,
                callback=callback or ctx.callback,
            )
logger.info("Img2Img: completed successfully")
return result[0]
except Exception as e:
logger.exception(f"Img2Img failed: {e}")
# Return original image on failure
return image_tensor
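
    # Illustrative call site (a sketch, not pipeline code from this repo:
    # `ctx`, `model`, and the conditioning objects are assumed to be
    # produced by the surrounding pipeline):
    #
    #     result = Img2Img.apply(
    #         ctx=ctx,
    #         model=model,
    #         positive=positive_cond,   # hypothetical names
    #         negative=negative_cond,
    #         image_path="input.png",   # hypothetical path
    #         denoise=0.3,
    #     )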

    @classmethod
    def _load_image(cls, path: str) -> torch.Tensor:
        """Load an image from disk and convert to tensor.

        Args:
            path: Path to the image file

        Returns:
            Image tensor in [B, H, W, C] format, normalized to [0, 1]
        """
        # Force RGB so grayscale or RGBA inputs yield a 3-channel tensor
        img = Image.open(path).convert("RGB")
        img_array = np.array(img)
        # torch.from_numpy already produces a CPU tensor
        img_tensor = torch.from_numpy(img_array).float() / 255.0
        # Add batch dimension
        if img_tensor.dim() == 3:
            img_tensor = img_tensor.unsqueeze(0)
        return img_tensor
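
    # Example (illustrative): a 768x512 RGB file loads as a float tensor of
    # shape [1, 512, 768, 3] (PIL reports size as WxH; numpy flips to HxW):
    #
    #     t = Img2Img._load_image("photo.png")  # hypothetical path
    #     # t.shape -> torch.Size([1, 512, 768, 3]); values in [0, 1]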

    @classmethod
    def simple_img2img(
        cls,
        ctx: "PipelineContext",
        model: "AbstractModel",
        positive: Any,
        negative: Any,
        image_tensor: torch.Tensor,
        denoise: float = 0.75,
        last_step: Optional[int] = None,
        callback: Optional[Callable] = None,
    ) -> dict:
        """Simple image-to-image without upscaling.

        Encodes the input image to latents and runs diffusion with
        the specified denoising strength.

        Args:
            ctx: Pipeline context
            model: The loaded model
            positive: Positive conditioning
            negative: Negative conditioning
            image_tensor: Input image tensor
            denoise: Denoising strength (0.0 = no change, 1.0 = full generation)
            last_step: Optional step to stop at (for refiner handoff)
            callback: Optional callback for live previews

        Returns:
            Dictionary with 'samples' key containing generated latents
        """
        logger = logging.getLogger(__name__)

        try:
            from src.AutoEncoders import VariationalAE
            from src.hidiffusion import msw_msa_attention
            from src.sample import sampling

            # Determine model flags
            is_flux = getattr(model.capabilities, "is_flux", False)
            is_flux2 = getattr(model.capabilities, "is_flux2", False)

            # Encode image to latents (pass flux flag for correct encoding)
            vae_encode = VariationalAE.VAEEncode()
            latents = vae_encode.encode(
                vae=model.vae,
                pixels=image_tensor,
                flux=is_flux or is_flux2,
            )[0]

            # Apply HiDiffusion optimizer (not for Flux)
            if not is_flux:
                try:
                    hidiff = msw_msa_attention.ApplyMSWMSAAttentionSimple()
                    optimized_model = hidiff.go(model_type="auto", model=model.model)[0]
                except Exception:
                    # Fall back to the unmodified model if patching fails
                    optimized_model = model.model
            else:
                optimized_model = model.model

            # Run sampling with denoise < 1.0
            ksampler = sampling.KSampler()
            result = ksampler.sample(
                seed=ctx.seed,
                steps=ctx.sampling.steps,
                # Mirror apply(): both Flux variants run without CFG, and
                # multiscale sampling is likewise disabled for them
                cfg=1.0 if (is_flux or is_flux2) else ctx.sampling.cfg,
                sampler_name=ctx.sampling.sampler,
                scheduler=ctx.sampling.scheduler,
                denoise=denoise,
                model=optimized_model,
                positive=positive,
                negative=negative,
                latent_image=latents,
                pipeline=True,
                flux=is_flux,
                flux2=is_flux2,
                enable_multiscale=False if (is_flux or is_flux2) else ctx.sampling.enable_multiscale,
                cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
                last_step=last_step,
                callback=callback or ctx.callback,  # Enable live previews during sampling
            )
            return result[0]
        except Exception as e:
            logger.exception(f"Simple img2img failed: {e}")
            raise
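
    # Denoise intuition (the usual latent-diffusion img2img convention,
    # not verified against this KSampler's internals): with denoise d over
    # N scheduled steps, sampling effectively starts about N * (1 - d)
    # steps in, so denoise=0.75 at steps=20 skips ~5 steps and keeps only
    # coarse structure from the source image.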

    @classmethod
    def is_enabled(cls, ctx: "PipelineContext") -> bool:
        """Check if Img2Img mode is enabled.

        Args:
            ctx: Pipeline context

        Returns:
            True if img2img mode is enabled
        """
        return ctx.features.img2img
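
# Illustrative dispatch (a sketch of how a caller might gate on the feature
# flag; the real pipeline wiring lives outside this module):
#
#     if Img2Img.is_enabled(ctx):
#         image = Img2Img.apply(ctx, model, positive, negative)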