LogicGoInfotechSpaces committed
Commit 8d0a1ae · 1 Parent(s): d58eb50

Restore local ControlNet colorization pipeline

Files changed (2):
  1. app/colorize_model.py +134 -92
  2. app/config.py +6 -2
app/colorize_model.py CHANGED
@@ -1,17 +1,24 @@
 """
-Colorize model wrapper that forwards requests to the Hugging Face Inference API.
+Colorize model wrapper replicating the behaviour of the
+`fffiloni/text-guided-image-colorization` Space.
 """
 
 from __future__ import annotations
 
-import io
 import logging
 import os
 from typing import Tuple
 
-import requests
 import torch
 from PIL import Image
+from diffusers import (
+    AutoencoderKL,
+    ControlNetModel,
+    StableDiffusionXLControlNetPipeline,
+    UNet2DConditionModel,
+)
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
 from app.config import settings
@@ -20,92 +27,138 @@ logger = logging.getLogger(__name__)
 
 
 def _ensure_cache_dir() -> str:
-    """Ensure we have a writable Hugging Face cache directory."""
-    data_dir = os.getenv("DATA_DIR")
-    candidates = []
-    if data_dir:
-        candidates.append(os.path.join(data_dir, "hf_cache"))
-    candidates.extend(
-        [
-            os.path.join("/tmp", "hf_cache"),
-            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
-        ]
-    )
-
-    for path in candidates:
-        try:
-            os.makedirs(path, exist_ok=True)
-            logger.info("Using HF cache directory: %s", path)
-            os.environ["HF_HOME"] = path
-            os.environ["HUGGINGFACE_HUB_CACHE"] = path
-            os.environ["TRANSFORMERS_CACHE"] = path
-            return path
-        except Exception as exc:
-            logger.warning("Failed to create cache dir %s: %s", path, exc)
-
-    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")
+    cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache"
+    try:
+        os.makedirs(cache_dir, exist_ok=True)
+    except Exception as exc:  # pragma: no cover
+        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
+    os.environ["HF_HOME"] = cache_dir
+    os.environ["TRANSFORMERS_CACHE"] = cache_dir
+    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
+    os.environ["HF_HUB_CACHE"] = cache_dir
+    return cache_dir
+
+
+def _apply_lab_merge(original_luminance: Image.Image, color_map: Image.Image) -> Image.Image:
+    base_lab = original_luminance.convert("LAB")
+    color_lab = color_map.convert("LAB")
+    l_channel, _, _ = base_lab.split()
+    _, a_channel, b_channel = color_lab.split()
+    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
+    return merged.convert("RGB")
 
 
 def _clean_caption(prompt: str) -> str:
-    replacements = [
-        "black and white", "black & white", "monochrome", "monochromatic",
-        "bw photo", "blurry", "grainy", "historical", "restored", "circa",
-        "taken in", "overcast", "desaturated", "low contrast",
+    remove_terms = [
+        "black and white", "black & white", "monochrome", "bw photo",
+        "historical", "restored", "low contrast", "desaturated", "overcast",
     ]
     cleaned = prompt
-    for word in replacements:
-        cleaned = cleaned.replace(word, "")
+    for term in remove_terms:
+        cleaned = cleaned.replace(term, "")
    return cleaned.strip(" ,")
 
 
 class ColorizeModel:
-    """Colorization model that leverages the HF Inference API."""
-
-    CAPTION_MODEL = "Salesforce/blip-image-captioning-large"
+    """Colorization model that runs the SDXL + ControlNet pipeline locally."""
 
     def __init__(self, model_id: str | None = None) -> None:
-        self.model_id = model_id or settings.MODEL_ID
-        self.api_url = f"https://router.huggingface.co/hf-inference/models/{self.model_id}"
-
-        self.api_token = (
-            os.getenv("HUGGINGFACE_API_TOKEN")
+        self.cache_dir = _ensure_cache_dir()
+        self.hf_token = (
+            os.getenv("HF_TOKEN")
             or os.getenv("HUGGINGFACE_HUB_TOKEN")
-            or os.getenv("HF_TOKEN")
+            or os.getenv("HUGGINGFACE_API_TOKEN")
         )
-        if not self.api_token:
-            raise RuntimeError(
-                "HUGGINGFACE_API_TOKEN (or HUGGINGFACE_HUB_TOKEN / HF_TOKEN) is not set. "
-                "Please provide an access token with Inference API permissions."
-            )
+        if not self.hf_token:
+            logger.warning("HF token not provided – attempting to download public models only.")
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
         os.environ.setdefault("OMP_NUM_THREADS", "1")
 
-        self.cache_dir = _ensure_cache_dir()
-        self.positive_prompt = settings.POSITIVE_PROMPT
-        self.negative_prompt = settings.NEGATIVE_PROMPT
+        self.controlnet_id = model_id or settings.MODEL_ID
+        self.base_model_id = settings.BASE_MODEL_ID
+        self.lightning_repo = settings.LIGHTNING_REPO
+        self.lightning_weights = settings.LIGHTNING_WEIGHTS
+        self.caption_model_id = settings.CAPTION_MODEL_ID
+
         self.num_inference_steps = settings.NUM_INFERENCE_STEPS
         self.guidance_scale = settings.GUIDANCE_SCALE
+        self.controlnet_scale = settings.CONTROLNET_SCALE
+        self.positive_prompt = settings.POSITIVE_PROMPT
+        self.negative_prompt = settings.NEGATIVE_PROMPT
         self.caption_prefix = settings.CAPTION_PREFIX
         self.seed = settings.COLORIZE_SEED
-        self.timeout = settings.INFERENCE_TIMEOUT
-        self.provider = settings.INFERENCE_PROVIDER
 
         self._load_caption_model()
+        self._load_pipeline()
 
     def _load_caption_model(self) -> None:
-        logger.info("Loading BLIP captioning model for prompt generation...")
+        logger.info("Loading BLIP captioning model: %s", self.caption_model_id)
         self.caption_processor = BlipProcessor.from_pretrained(
-            self.CAPTION_MODEL,
-            cache_dir=self.cache_dir
+            self.caption_model_id,
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
         )
         self.caption_model = BlipForConditionalGeneration.from_pretrained(
-            self.CAPTION_MODEL,
+            self.caption_model_id,
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
             torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
-            cache_dir=self.cache_dir
        ).to(self.device)
 
+    def _load_pipeline(self) -> None:
+        logger.info("Loading ControlNet model: %s", self.controlnet_id)
+        controlnet = ControlNetModel.from_pretrained(
+            self.controlnet_id,
+            torch_dtype=self.dtype,
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
+        )
+
+        logger.info("Loading SDXL base model components: %s", self.base_model_id)
+        vae = AutoencoderKL.from_pretrained(
+            self.base_model_id,
+            subfolder="vae",
+            torch_dtype=self.dtype,
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
+        )
+        unet = UNet2DConditionModel.from_config(
+            self.base_model_id,
+            subfolder="unet",
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
+        )
+        lightning_path = hf_hub_download(
+            repo_id=self.lightning_repo,
+            filename=self.lightning_weights,
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
+        )
+        unet.load_state_dict(load_file(lightning_path))
+
+        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            self.base_model_id,
+            vae=vae,
+            unet=unet,
+            controlnet=controlnet,
+            torch_dtype=self.dtype,
+            cache_dir=self.cache_dir,
+            token=self.hf_token,
+            safety_checker=None,
+            requires_safety_checker=False,
+        )
+        self.pipe.set_progress_bar_config(disable=True)
+        self.pipe.to(self.device, dtype=self.dtype)
+        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
+            try:
+                self.pipe.enable_xformers_memory_efficient_attention()
+            except Exception as exc:  # pragma: no cover
+                logger.warning("Could not enable xFormers optimizations: %s", exc)
+
+        logger.info("Colorization pipeline ready.")
+
     def caption_image(self, image: Image.Image) -> str:
         inputs = self.caption_processor(
             image,
@@ -121,43 +174,32 @@ class ColorizeModel:
         caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
         return _clean_caption(caption)
 
-    def _build_payload(self, prompt: str) -> dict:
-        payload = {
-            "inputs": prompt,
-            "parameters": {
-                "num_inference_steps": self.num_inference_steps,
-                "guidance_scale": self.guidance_scale,
-                "negative_prompt": self.negative_prompt,
-                "seed": self.seed,
-            },
-        }
-        if self.provider:
-            payload["provider"] = {"name": self.provider}
-        return payload
-
-    def colorize(self, image: Image.Image, _num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
-        caption = self.caption_image(image)
-        prompt_parts = [self.positive_prompt, caption]
-        prompt = ", ".join([p for p in prompt_parts if p])
-
-        headers = {
-            "Authorization": f"Bearer {self.api_token}",
-            "Content-Type": "application/json",
-        }
-        payload = self._build_payload(prompt)
-
-        logger.info("Calling HF Inference API for prompt: %s", prompt)
-        response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout)
-
-        if response.status_code != 200:
-            try:
-                data = response.json()
-            except ValueError:
-                data = response.text
-            logger.error("Inference API error (%s): %s", response.status_code, data)
-            raise RuntimeError(f"Inference API error ({response.status_code}): {data}")
-
-        colorized = Image.open(io.BytesIO(response.content)).convert("RGB")
-        colorized = colorized.resize(image.size, Image.Resampling.LANCZOS)
+    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
+        original_size = image.size
+        control_image = image.convert("L").convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
+
+        caption = self.caption_image(image)
+        prompt_components = [self.positive_prompt, caption]
+        prompt = ", ".join([p for p in prompt_components if p])
+        steps = num_inference_steps or self.num_inference_steps
+        generator = torch.Generator(device=self.device).manual_seed(self.seed)
+
+        logger.info("Running ControlNet pipeline with prompt: %s", prompt)
+        result = self.pipe(
+            prompt=prompt,
+            negative_prompt=self.negative_prompt or None,
+            image=control_image,
+            control_image=control_image,
+            num_inference_steps=steps,
+            guidance_scale=self.guidance_scale,
+            controlnet_conditioning_scale=self.controlnet_scale,
+            generator=generator,
+        )
+
+        generated = result.images[0]
+        colorized = _apply_lab_merge(control_image, generated)
+        if colorized.size != original_size:
+            colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
         return colorized, caption
 
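For context, here is a minimal usage sketch of the restored class. It assumes the repo's `app` package is importable and that the model weights can be downloaded into a writable HF cache; the file names are hypothetical placeholders, not part of the commit.

    from PIL import Image

    from app.colorize_model import ColorizeModel

    # First construction is slow: it downloads the BLIP captioner, the SDXL
    # base components, the SDXL-Lightning UNet weights and the ControlNet
    # checkpoint into the cache directory chosen by _ensure_cache_dir().
    model = ColorizeModel()

    # "old_photo.jpg" is a hypothetical input file.
    photo = Image.open("old_photo.jpg").convert("RGB")

    # colorize() captions the photo with BLIP, runs SDXL + ControlNet against
    # a 512x512 grayscale control image, then merges only the generated a/b
    # chroma channels back onto the grayscale luminance in LAB space.
    colorized, caption = model.colorize(photo, num_inference_steps=8)
    colorized.save("colorized.png")
    print("Prompt built from caption:", caption)

The LAB merge is the load-bearing design choice: because the diffusion output contributes only chroma, the result cannot drift from the luminance and structure of the input photograph.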
app/config.py CHANGED
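The hunk below turns each heavy component of the pipeline into an environment-driven setting. As a quick sketch of overriding the new knobs introduced in this diff (the override values are illustrative assumptions, and the variables must be set before `app.config` is first imported, since `os.getenv` runs when the Settings class body is evaluated):

    import os

    # Hypothetical overrides: a larger captioner and the 4-step Lightning UNet.
    os.environ["CAPTION_MODEL_ID"] = "Salesforce/blip-image-captioning-large"
    os.environ["LIGHTNING_WEIGHTS"] = "sdxl_lightning_4step_unet.safetensors"
    os.environ["NUM_INFERENCE_STEPS"] = "8"

    from app.config import settings

    print(settings.CAPTION_MODEL_ID, settings.LIGHTNING_WEIGHTS, settings.NUM_INFERENCE_STEPS)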
@@ -18,8 +18,12 @@ class Settings(BaseSettings):
     BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
 
     # Model / inference settings
-    MODEL_ID: str = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
-    NUM_INFERENCE_STEPS: int = int(os.getenv("NUM_INFERENCE_STEPS", "30"))
+    MODEL_ID: str = os.getenv("MODEL_ID", "fffiloni/controlnet-colorization-sdxl")
+    BASE_MODEL_ID: str = os.getenv("BASE_MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
+    LIGHTNING_REPO: str = os.getenv("LIGHTNING_REPO", "ByteDance/SDXL-Lightning")
+    LIGHTNING_WEIGHTS: str = os.getenv("LIGHTNING_WEIGHTS", "sdxl_lightning_8step_unet.safetensors")
+    CAPTION_MODEL_ID: str = os.getenv("CAPTION_MODEL_ID", "Salesforce/blip-image-captioning-base")
+    NUM_INFERENCE_STEPS: int = int(os.getenv("NUM_INFERENCE_STEPS", "20"))
     POSITIVE_PROMPT: str = os.getenv(
         "POSITIVE_PROMPT",
         "high quality color photo, vibrant natural colors, detailed lighting"