LogicGoInfotechSpaces committed
Commit 5e6062c · 1 Parent(s): 8d0a1ae

Switch to FastAI GAN colorization model (Hammad712/GAN-Colorization-Model)

Files changed (4):
  1. app/colorize_model.py +91 -167
  2. app/config.py +6 -1
  3. app/main.py +3 -0
  4. requirements.txt +1 -0
app/colorize_model.py CHANGED
@@ -1,6 +1,6 @@
 """
-Colorize model wrapper replicating the behaviour of the
-`fffiloni/text-guided-image-colorization` Space.
+Colorize model wrapper using FastAI GAN Colorization Model
+Hammad712/GAN-Colorization-Model
 """
 
 from __future__ import annotations
@@ -11,15 +11,8 @@ from typing import Tuple
 
 import torch
 from PIL import Image
-from diffusers import (
-    AutoencoderKL,
-    ControlNetModel,
-    StableDiffusionXLControlNetPipeline,
-    UNet2DConditionModel,
-)
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-from transformers import BlipForConditionalGeneration, BlipProcessor
+from fastai.vision.all import *
+from huggingface_hub import from_pretrained_fastai
 
 from app.config import settings
 
@@ -30,7 +23,7 @@ def _ensure_cache_dir() -> str:
     cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache"
     try:
         os.makedirs(cache_dir, exist_ok=True)
-    except Exception as exc:  # pragma: no cover
+    except Exception as exc:
         logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
     os.environ["HF_HOME"] = cache_dir
     os.environ["TRANSFORMERS_CACHE"] = cache_dir
@@ -39,167 +32,98 @@ def _ensure_cache_dir() -> str:
     return cache_dir
 
 
-def _apply_lab_merge(original_luminance: Image.Image, color_map: Image.Image) -> Image.Image:
-    base_lab = original_luminance.convert("LAB")
-    color_lab = color_map.convert("LAB")
-    l_channel, _, _ = base_lab.split()
-    _, a_channel, b_channel = color_lab.split()
-    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
-    return merged.convert("RGB")
-
-
-def _clean_caption(prompt: str) -> str:
-    remove_terms = [
-        "black and white", "black & white", "monochrome", "bw photo",
-        "historical", "restored", "low contrast", "desaturated", "overcast",
-    ]
-    cleaned = prompt
-    for term in remove_terms:
-        cleaned = cleaned.replace(term, "")
-    return cleaned.strip(" ,")
-
-
 class ColorizeModel:
-    """Colorization model that runs the SDXL + ControlNet pipeline locally."""
+    """Colorization model using FastAI GAN model."""
 
     def __init__(self, model_id: str | None = None) -> None:
         self.cache_dir = _ensure_cache_dir()
-        self.hf_token = (
-            os.getenv("HF_TOKEN")
-            or os.getenv("HUGGINGFACE_HUB_TOKEN")
-            or os.getenv("HUGGINGFACE_API_TOKEN")
-        )
-        if not self.hf_token:
-            logger.warning("HF token not provided – attempting to download public models only.")
-
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
         os.environ.setdefault("OMP_NUM_THREADS", "1")
 
-        self.controlnet_id = model_id or settings.MODEL_ID
-        self.base_model_id = settings.BASE_MODEL_ID
-        self.lightning_repo = settings.LIGHTNING_REPO
-        self.lightning_weights = settings.LIGHTNING_WEIGHTS
-        self.caption_model_id = settings.CAPTION_MODEL_ID
-
-        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
-        self.guidance_scale = settings.GUIDANCE_SCALE
-        self.controlnet_scale = settings.CONTROLNET_SCALE
-        self.positive_prompt = settings.POSITIVE_PROMPT
-        self.negative_prompt = settings.NEGATIVE_PROMPT
-        self.caption_prefix = settings.CAPTION_PREFIX
-        self.seed = settings.COLORIZE_SEED
-
-        self._load_caption_model()
-        self._load_pipeline()
-
-    def _load_caption_model(self) -> None:
-        logger.info("Loading BLIP captioning model: %s", self.caption_model_id)
-        self.caption_processor = BlipProcessor.from_pretrained(
-            self.caption_model_id,
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-        )
-        self.caption_model = BlipForConditionalGeneration.from_pretrained(
-            self.caption_model_id,
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-            torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
-        ).to(self.device)
-
-    def _load_pipeline(self) -> None:
-        logger.info("Loading ControlNet model: %s", self.controlnet_id)
-        controlnet = ControlNetModel.from_pretrained(
-            self.controlnet_id,
-            torch_dtype=self.dtype,
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-        )
-
-        logger.info("Loading SDXL base model components: %s", self.base_model_id)
-        vae = AutoencoderKL.from_pretrained(
-            self.base_model_id,
-            subfolder="vae",
-            torch_dtype=self.dtype,
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-        )
-        unet = UNet2DConditionModel.from_config(
-            self.base_model_id,
-            subfolder="unet",
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-        )
-        lightning_path = hf_hub_download(
-            repo_id=self.lightning_repo,
-            filename=self.lightning_weights,
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-        )
-        unet.load_state_dict(load_file(lightning_path))
-
-        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-            self.base_model_id,
-            vae=vae,
-            unet=unet,
-            controlnet=controlnet,
-            torch_dtype=self.dtype,
-            cache_dir=self.cache_dir,
-            token=self.hf_token,
-            safety_checker=None,
-            requires_safety_checker=False,
-        )
-        self.pipe.set_progress_bar_config(disable=True)
-        self.pipe.to(self.device, dtype=self.dtype)
-        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
-            try:
-                self.pipe.enable_xformers_memory_efficient_attention()
-            except Exception as exc:  # pragma: no cover
-                logger.warning("Could not enable xFormers optimizations: %s", exc)
-
-        logger.info("Colorization pipeline ready.")
-
-    def caption_image(self, image: Image.Image) -> str:
-        inputs = self.caption_processor(
-            image,
-            self.caption_prefix,
-            return_tensors="pt",
-        ).to(self.device)
-
-        if self.device.type != "cuda":
-            inputs = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
-        with torch.inference_mode():
-            caption_ids = self.caption_model.generate(**inputs)
-        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
-        return _clean_caption(caption)
+        # Use FastAI model ID from config or default
+        self.model_id = model_id or settings.MODEL_ID
+        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")
+
+        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
+        try:
+            self.learn = from_pretrained_fastai(self.model_id)
+            logger.info("FastAI GAN Colorization model loaded successfully")
+        except Exception as e:
+            error_msg = (
+                f"Failed to load FastAI model '{self.model_id}'. "
+                f"Error: {str(e)}\n"
+                f"Please check the MODEL_ID environment variable. "
+                f"Default model: 'Hammad712/GAN-Colorization-Model'"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg) from e
 
     def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
-        original_size = image.size
-        control_image = image.convert("L").convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
-
-        caption = self.caption_image(image)
-        prompt_components = [self.positive_prompt, caption]
-        prompt = ", ".join([p for p in prompt_components if p])
-        steps = num_inference_steps or self.num_inference_steps
-        generator = torch.Generator(device=self.device).manual_seed(self.seed)
-
-        logger.info("Running ControlNet pipeline with prompt: %s", prompt)
-        result = self.pipe(
-            prompt=prompt,
-            negative_prompt=self.negative_prompt or None,
-            image=control_image,
-            control_image=control_image,
-            num_inference_steps=steps,
-            guidance_scale=self.guidance_scale,
-            controlnet_conditioning_scale=self.controlnet_scale,
-            generator=generator,
-        )
-
-        generated = result.images[0]
-        colorized = _apply_lab_merge(control_image, generated)
-        if colorized.size != original_size:
-            colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
-
-        return colorized, caption
-
+        """
+        Colorize a grayscale or color image using FastAI GAN model.
+
+        Args:
+            image: PIL Image (grayscale or color)
+            num_inference_steps: Ignored for FastAI model (kept for API compatibility)
+
+        Returns:
+            Tuple of (colorized PIL Image, caption string)
+        """
+        try:
+            original_size = image.size
+
+            # Ensure image is RGB
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+
+            # FastAI predict expects a PIL Image
+            logger.info("Running FastAI GAN colorization...")
+
+            # Use the model's predict method.
+            # FastAI predict for image models typically returns the output image
+            # directly or as the first element of a tuple
+            prediction = self.learn.predict(image)
+
+            # Extract the colorized image from prediction,
+            # handling the different return types from FastAI
+            if isinstance(prediction, (list, tuple)):
+                # If tuple/list, the first element is usually the prediction
+                colorized = prediction[0] if len(prediction) > 0 else image
+            else:
+                # Direct return
+                colorized = prediction
+
+            # Ensure we have a PIL Image
+            if not isinstance(colorized, Image.Image):
+                # If it's a tensor, convert to PIL
+                if isinstance(colorized, torch.Tensor):
+                    if colorized.dim() == 4:
+                        colorized = colorized[0]  # Remove batch dimension
+                    if colorized.dim() == 3:
+                        # Convert CHW to HWC and move to CPU
+                        colorized = colorized.permute(1, 2, 0).cpu()
+                        # Clamp float values to [0, 1] and scale to [0, 255]
+                        if colorized.dtype in (torch.float32, torch.float16):
+                            colorized = torch.clamp(colorized, 0, 1)
+                            colorized = (colorized * 255).byte()
+                        colorized = Image.fromarray(colorized.numpy(), "RGB")
+                    else:
+                        raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
+                else:
+                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
+
+            # Ensure RGB mode
+            if colorized.mode != "RGB":
+                colorized = colorized.convert("RGB")
+
+            # Resize back to the original size if needed
+            if colorized.size != original_size:
+                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
+
+            logger.info("Colorization completed successfully")
+            return colorized, self.output_caption
+
+        except Exception as e:
+            logger.error("Error during colorization: %s", str(e))
+            raise RuntimeError(f"Colorization failed: {str(e)}") from e
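For reference, a minimal sketch of what the new backend boils down to outside the wrapper class. This is illustrative, not code from this commit: it assumes the Hub repo hosts fastai weights that from_pretrained_fastai can load and that learn.predict() accepts a PIL image; the exact return shape depends on how the learner was exported, which is why the wrapper handles tuples, tensors, and plain images. The input filename is hypothetical.

    # Illustrative sketch (not part of this commit): exercising the new backend directly.
    from huggingface_hub import from_pretrained_fastai
    from PIL import Image

    learn = from_pretrained_fastai("Hammad712/GAN-Colorization-Model")

    image = Image.open("old_photo.jpg").convert("RGB")  # hypothetical input file
    prediction = learn.predict(image)

    # fastai learners commonly return a (decoded, raw, probs) tuple; like the
    # wrapper above, take the first element when a tuple comes back.
    colorized = prediction[0] if isinstance(prediction, (list, tuple)) else prediction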
app/config.py CHANGED
@@ -18,7 +18,8 @@ class Settings(BaseSettings):
     BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
 
     # Model / inference settings
-    MODEL_ID: str = os.getenv("MODEL_ID", "fffiloni/controlnet-colorization-sdxl")
+    MODEL_ID: str = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
+    MODEL_BACKEND: str = os.getenv("MODEL_BACKEND", "fastai")
     BASE_MODEL_ID: str = os.getenv("BASE_MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
     LIGHTNING_REPO: str = os.getenv("LIGHTNING_REPO", "ByteDance/SDXL-Lightning")
     LIGHTNING_WEIGHTS: str = os.getenv("LIGHTNING_WEIGHTS", "sdxl_lightning_8step_unet.safetensors")
@@ -36,6 +37,10 @@ class Settings(BaseSettings):
     CONTROLNET_SCALE: float = float(os.getenv("CONTROLNET_SCALE", "1.0"))
     CAPTION_PREFIX: str = os.getenv("CAPTION_PREFIX", "a photography of")
     COLORIZE_SEED: int = int(os.getenv("COLORIZE_SEED", "123"))
+    FASTAI_OUTPUT_CAPTION: str = os.getenv(
+        "FASTAI_OUTPUT_CAPTION",
+        "Colorized using GAN-Colorization-Model"
+    )
     INFERENCE_PROVIDER: str = os.getenv("INFERENCE_PROVIDER", "hf-inference")
     INFERENCE_TIMEOUT: int = int(os.getenv("INFERENCE_TIMEOUT", "180"))
 
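A quick sketch of how these settings resolve at runtime. Because each default calls os.getenv() at class-definition time, environment overrides must be in place before app.config is first imported in the process; the setting names mirror the commit, but the override values are illustrative.

    # Illustrative sketch: overriding the new settings via environment variables.
    import os

    # Must happen before app.config is imported anywhere in the process.
    os.environ["MODEL_ID"] = "Hammad712/GAN-Colorization-Model"
    os.environ["FASTAI_OUTPUT_CAPTION"] = "Colorized with our GAN backend"

    from app.config import settings

    assert settings.MODEL_ID == "Hammad712/GAN-Colorization-Model"
    assert settings.MODEL_BACKEND == "fastai"  # default when MODEL_BACKEND is unset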
app/main.py CHANGED
@@ -3,6 +3,9 @@ FastAPI application for image colorization using ColorizeNet model
 with Firebase App Check integration
 """
 import os
+# Set OMP_NUM_THREADS before any torch imports to avoid libgomp warnings
+os.environ.setdefault("OMP_NUM_THREADS", "1")
+
 import uuid
 import logging
 from pathlib import Path
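The new lines in app/main.py encode an ordering constraint: the OpenMP runtime reads OMP_NUM_THREADS when it initialises, which typically happens the first time torch is imported, so setting the variable after that import has no effect. A minimal sketch of the safe order, assuming torch has not been imported earlier in the process:

    import os
    os.environ.setdefault("OMP_NUM_THREADS", "1")  # must run before torch is imported

    import torch  # OpenMP typically initialises here and picks up the value
    print(torch.get_num_threads())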
requirements.txt CHANGED
@@ -14,4 +14,5 @@ firebase-admin>=6.0.0
 pydantic-settings>=2.0.0
 huggingface-hub>=0.16.0
 safetensors>=0.3.0
+fastai>=2.7.13
 
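With the new fastai pin in place, a small startup check can confirm which versions actually resolved in the Space. This snippet is illustrative and not part of the commit; it uses only the standard library.

    # Illustrative sketch: log the resolved versions of the key dependencies.
    from importlib.metadata import version

    for pkg in ("fastai", "huggingface-hub"):
        print(pkg, version(pkg))  # fastai should report >= 2.7.13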