LogicGoInfotechSpaces's picture
Restore local ControlNet colorization pipeline
8d0a1ae
raw
history blame
7.68 kB
"""
Colorize model wrapper replicating the behaviour of the
`fffiloni/text-guided-image-colorization` Space.
"""
from __future__ import annotations

import logging
import os
import re
from typing import Tuple

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from transformers import BlipForConditionalGeneration, BlipProcessor

from app.config import settings
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
def _ensure_cache_dir() -> str:
cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache"
try:
os.makedirs(cache_dir, exist_ok=True)
except Exception as exc: # pragma: no cover
logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
return cache_dir
def _apply_lab_merge(original_luminance: Image.Image, color_map: Image.Image) -> Image.Image:
    """Combine the lightness of one image with the colour of another.

    Both images are converted to LAB. The L channel comes from
    ``original_luminance`` and the a/b channels come from ``color_map``.

    Args:
        original_luminance: Image whose lightness (L channel) is kept.
        color_map: Image supplying the colour (a/b channels). If its size
            differs from ``original_luminance``, it is Lanczos-resized to match,
            because ``Image.merge`` requires all bands to have the same size.

    Returns:
        The merged image converted back to RGB, sized like ``original_luminance``.
    """
    if color_map.size != original_luminance.size:
        color_map = color_map.resize(original_luminance.size, Image.Resampling.LANCZOS)
    l_channel = original_luminance.convert("LAB").split()[0]
    _, a_channel, b_channel = color_map.convert("LAB").split()
    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
    return merged.convert("RGB")
def _clean_caption(prompt: str) -> str:
remove_terms = [
"black and white", "black & white", "monochrome", "bw photo",
"historical", "restored", "low contrast", "desaturated", "overcast",
]
cleaned = prompt
for term in remove_terms:
cleaned = cleaned.replace(term, "")
return cleaned.strip(" ,")
class ColorizeModel:
    """Colorization model that runs the SDXL + ControlNet pipeline locally.

    Construction loads everything eagerly:
    - a BLIP captioning model;
    - a ControlNet;
    - the SDXL base VAE and remaining pipeline components;
    - the SDXL-Lightning UNet weights.

    All models are moved to ``self.device``. Model ids and sampling parameters
    are read from ``app.config.settings``.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Resolve configuration and load the captioner and the pipeline.

        Args:
            model_id: ControlNet repository id; defaults to ``settings.MODEL_ID``.
        """
        self.cache_dir = _ensure_cache_dir()
        # First non-empty token among the common Hugging Face env variable names.
        self.hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HUGGINGFACE_API_TOKEN")
        )
        if not self.hf_token:
            logger.warning("HF token not provided – attempting to download public models only.")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on CUDA only; CPU runs everything in float32.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        # NOTE(review): torch is already imported here, so this default may not
        # change its thread pool size — confirm it is still needed.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        self.controlnet_id = model_id or settings.MODEL_ID
        self.base_model_id = settings.BASE_MODEL_ID
        self.lightning_repo = settings.LIGHTNING_REPO
        self.lightning_weights = settings.LIGHTNING_WEIGHTS
        self.caption_model_id = settings.CAPTION_MODEL_ID
        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.controlnet_scale = settings.CONTROLNET_SCALE
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED
        self._load_caption_model()
        self._load_pipeline()

    def _load_caption_model(self) -> None:
        """Load the BLIP processor and captioning model onto ``self.device``."""
        logger.info("Loading BLIP captioning model: %s", self.caption_model_id)
        self.caption_processor = BlipProcessor.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        # self.dtype is already float32 off-GPU, so no extra device check is needed.
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
            torch_dtype=self.dtype,
        ).to(self.device)

    def _load_pipeline(self) -> None:
        """Assemble ``self.pipe``: ControlNet plus SDXL base with the Lightning UNet."""
        hub_kwargs = {"cache_dir": self.cache_dir, "token": self.hf_token}

        logger.info("Loading ControlNet model: %s", self.controlnet_id)
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_id,
            torch_dtype=self.dtype,
            **hub_kwargs,
        )

        logger.info("Loading SDXL base model components: %s", self.base_model_id)
        vae = AutoencoderKL.from_pretrained(
            self.base_model_id,
            subfolder="vae",
            torch_dtype=self.dtype,
            **hub_kwargs,
        )

        # Build an empty UNet from the base model's config; its weights come from the
        # Lightning checkpoint below. Passing a repo id straight to ``from_config``
        # is deprecated in diffusers, so the config dict is loaded explicitly.
        unet_config = UNet2DConditionModel.load_config(
            self.base_model_id,
            subfolder="unet",
            **hub_kwargs,
        )
        # Casting before loading avoids holding a float32 copy of the UNet on GPU runs.
        unet = UNet2DConditionModel.from_config(unet_config).to(dtype=self.dtype)
        lightning_path = hf_hub_download(
            repo_id=self.lightning_repo,
            filename=self.lightning_weights,
            **hub_kwargs,
        )
        unet.load_state_dict(load_file(lightning_path))

        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.base_model_id,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            **hub_kwargs,
        )
        # SDXL-Lightning checkpoints are distilled for "trailing" timestep spacing;
        # the base model's default scheduler config does not use it.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(self.device, dtype=self.dtype)
        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as exc:  # pragma: no cover - optional acceleration
                logger.warning("Could not enable xFormers optimizations: %s", exc)
        logger.info("Colorization pipeline ready.")

    def caption_image(self, image: Image.Image) -> str:
        """Caption *image* with BLIP, conditioned on ``self.caption_prefix``.

        Returns:
            The decoded caption after ``_clean_caption`` filtering.
        """
        encoded = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        )
        model_dtype = self.caption_model.dtype

        def _to_model(value):
            # Move tensors to the model's device. Cast only floating tensors
            # (pixel values) to the model dtype; token ids and attention masks
            # must stay integer for the embedding lookup.
            if not isinstance(value, torch.Tensor):
                return value
            if torch.is_floating_point(value):
                return value.to(self.device, dtype=model_dtype)
            return value.to(self.device)

        inputs = {key: _to_model(value) for key, value in encoded.items()}
        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """Colorize *image* and return ``(colorized_image, caption)``.

        Args:
            image: Input picture in any PIL mode; only its grayscale luminance is
                kept in the output.
            num_inference_steps: Override for the configured step count; ``None``
                or ``0`` falls back to ``settings.NUM_INFERENCE_STEPS``.

        Returns:
            The colorized RGB image at the input's original size, and the cleaned
            BLIP caption that was appended to the prompt.
        """
        original_size = image.size
        # Full-resolution grayscale copy; its luminance is kept in the final result.
        grayscale = image.convert("L").convert("RGB")
        # The ControlNet is conditioned on a 512x512 version of the same image.
        control_image = grayscale.resize((512, 512), Image.Resampling.LANCZOS)

        caption = self.caption_image(image)
        prompt = ", ".join(part for part in (self.positive_prompt, caption) if part)
        steps = num_inference_steps or self.num_inference_steps
        generator = torch.Generator(device=self.device).manual_seed(self.seed)

        logger.info("Running ControlNet pipeline with prompt: %s", prompt)
        result = self.pipe(
            prompt=prompt,
            negative_prompt=self.negative_prompt or None,
            image=control_image,
            num_inference_steps=steps,
            guidance_scale=self.guidance_scale,
            controlnet_conditioning_scale=self.controlnet_scale,
            generator=generator,
        )
        generated = result.images[0]

        # Upscale only the generated colour map and merge it with the full-resolution
        # luminance. Merging at 512x512 and upscaling afterwards would blur the
        # input's detail.
        color_map = generated.resize(original_size, Image.Resampling.LANCZOS)
        colorized = _apply_lab_merge(grayscale, color_map)
        return colorized, caption