""" |
|
|
Colorize model wrapper replicating the behaviour of the |
|
|
`fffiloni/text-guided-image-colorization` Space. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from diffusers import ( |
|
|
AutoencoderKL, |
|
|
ControlNetModel, |
|
|
StableDiffusionXLControlNetPipeline, |
|
|
UNet2DConditionModel, |
|
|
) |
|
|
from huggingface_hub import hf_hub_download |
|
|
from safetensors.torch import load_file |
|
|
from transformers import BlipForConditionalGeneration, BlipProcessor |
|
|
|
|
|
from app.config import settings |
|
|
|
|
|
logger = logging.getLogger(__name__) |


def _ensure_cache_dir() -> str:
    """Pick a writable Hugging Face cache directory and export it for every HF library."""
    cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache"
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as exc:
        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
    # Different versions of transformers and huggingface_hub read different
    # environment variables, so point all of them at the same location.
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    return cache_dir
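
# Deployment note: HF_HOME can be pointed at a persistent volume (for example
# HF_HOME=/data/hf_cache, a hypothetical path) before the process starts; the
# /tmp fallback above only guarantees a writable location, not persistence
# across container restarts.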


def _apply_lab_merge(original_luminance: Image.Image, color_map: Image.Image) -> Image.Image:
    """Keep the L channel of the source image and take only the a/b channels from the generated one."""
    base_lab = original_luminance.convert("LAB")
    color_lab = color_map.convert("LAB")
    l_channel, _, _ = base_lab.split()
    _, a_channel, b_channel = color_lab.split()
    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
    return merged.convert("RGB")
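
# In effect the diffusion model only supplies chroma (the a/b channels); the
# luminance channel is taken from the resized original photograph, so structural
# detail is preserved and the output stays faithful to the input.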


def _clean_caption(prompt: str) -> str:
    """Strip terms from the BLIP caption that would steer SDXL back towards a monochrome look."""
    remove_terms = [
        "black and white", "black & white", "monochrome", "bw photo",
        "historical", "restored", "low contrast", "desaturated", "overcast",
    ]
    cleaned = prompt
    for term in remove_terms:
        cleaned = cleaned.replace(term, "")
    return cleaned.strip(" ,")
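
# Illustrative example of the helper above:
#   _clean_caption("a black and white photo of a man, historical")
#   returns "a  photo of a man" (plain substring removal, so a leftover double
#   space is possible; it is harmless once the prompt is tokenized).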


class ColorizeModel:
    """Colorization model that runs the SDXL + ControlNet pipeline locally."""

    def __init__(self, model_id: str | None = None) -> None:
        self.cache_dir = _ensure_cache_dir()
        # Any of the common token variables is accepted; only gated repositories need one.
        self.hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HUGGINGFACE_API_TOKEN")
        )
        if not self.hf_token:
            logger.warning("HF token not provided – attempting to download public models only.")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        self.controlnet_id = model_id or settings.MODEL_ID
        self.base_model_id = settings.BASE_MODEL_ID
        self.lightning_repo = settings.LIGHTNING_REPO
        self.lightning_weights = settings.LIGHTNING_WEIGHTS
        self.caption_model_id = settings.CAPTION_MODEL_ID

        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.controlnet_scale = settings.CONTROLNET_SCALE
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED

        self._load_caption_model()
        self._load_pipeline()

    def _load_caption_model(self) -> None:
        logger.info("Loading BLIP captioning model: %s", self.caption_model_id)
        self.caption_processor = BlipProcessor.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
            torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
        ).to(self.device)

    def _load_pipeline(self) -> None:
        logger.info("Loading ControlNet model: %s", self.controlnet_id)
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_id,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )

        logger.info("Loading SDXL base model components: %s", self.base_model_id)
        vae = AutoencoderKL.from_pretrained(
            self.base_model_id,
            subfolder="vae",
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        # Build the UNet from the base model's config, then load the Lightning
        # checkpoint so the pipeline can run with a small number of denoising steps.
        unet_config = UNet2DConditionModel.load_config(
            self.base_model_id,
            subfolder="unet",
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        unet = UNet2DConditionModel.from_config(unet_config)
        lightning_path = hf_hub_download(
            repo_id=self.lightning_repo,
            filename=self.lightning_weights,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        unet.load_state_dict(load_file(lightning_path))

        # SDXL pipelines ship without a safety checker, so none is configured here.
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.base_model_id,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(self.device, dtype=self.dtype)
        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as exc:
                logger.warning("Could not enable xFormers optimizations: %s", exc)

        logger.info("Colorization pipeline ready.")

    def caption_image(self, image: Image.Image) -> str:
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device)

        if self.device.type == "cuda":
            # The BLIP weights are loaded in half precision on GPU, so cast the
            # floating-point inputs (pixel values) to the same dtype to avoid a
            # Half/Float mismatch during generation.
            inputs = {
                k: v.to(self.dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
                for k, v in inputs.items()
            }

        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        original_size = image.size
        # The conditioning image is the grayscale photo replicated to RGB and
        # resized to the 512x512 resolution the pipeline runs at.
        control_image = image.convert("L").convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)

        caption = self.caption_image(image)
        prompt_components = [self.positive_prompt, caption]
        prompt = ", ".join([p for p in prompt_components if p])
        steps = num_inference_steps or self.num_inference_steps
        generator = torch.Generator(device=self.device).manual_seed(self.seed)

        logger.info("Running ControlNet pipeline with prompt: %s", prompt)
        result = self.pipe(
            prompt=prompt,
            negative_prompt=self.negative_prompt or None,
            # StableDiffusionXLControlNetPipeline takes the ControlNet conditioning
            # image via the `image` argument.
            image=control_image,
            num_inference_steps=steps,
            guidance_scale=self.guidance_scale,
            controlnet_conditioning_scale=self.controlnet_scale,
            generator=generator,
        )

        generated = result.images[0]
        colorized = _apply_lab_merge(control_image, generated)
        if colorized.size != original_size:
            colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

        return colorized, caption
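

# Minimal usage sketch (illustrative only): assumes this module is importable with the
# surrounding `app.config.settings` populated and that a grayscale photo exists at the
# hypothetical path below.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model = ColorizeModel()
    source = Image.open("samples/old_photo.jpg")  # hypothetical input path
    colorized_image, used_caption = model.colorize(source)
    colorized_image.save("samples/old_photo_colorized.png")
    print(f"Guidance caption: {used_caption}")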