LogicGoInfotechSpaces committed
Commit 8f6f449 · 1 Parent(s): eb42092

Align pipeline with text-guided colorization Space

Files changed (3):
  1. app/colorize_model.py +259 -254
  2. app/config.py +14 -2
  3. app/main.py +3 -2
app/colorize_model.py CHANGED
@@ -1,275 +1,280 @@
 """
-ColorizeNet model wrapper for image colorization
 """

 import logging
 import os

 import torch
-import numpy as np
 from PIL import Image
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, StableDiffusionImg2ImgPipeline
-from diffusers.utils import load_image
-from transformers import pipeline
-from huggingface_hub import hf_hub_download

 from app.config import settings

 logger = logging.getLogger(__name__)


 class ColorizeModel:
-    """Wrapper for ColorizeNet model"""
-
-    def __init__(self, model_id: str | None = None):
-        """
-        Initialize the ColorizeNet model
-
-        Args:
-            model_id: Hugging Face model ID for ColorizeNet
-        """
-        if model_id is None:
-            model_id = settings.MODEL_ID
-        self.model_id = model_id
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info("Using device: %s", self.device)
-        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
-        # Check for Hugging Face token (try both environment variable names)
-        self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
-
-        # Configure writable cache to avoid permission issues on Spaces
-        # Prefer DATA_DIR if available, otherwise fallback to /tmp
-        data_dir = os.getenv("DATA_DIR")
-        if not data_dir:
-            data_dir = "/tmp"
-        hf_cache_dir = os.path.join(data_dir, "hf_cache")
-
-        # Set cache environment variables
-        os.environ["HF_HOME"] = hf_cache_dir
-        os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
-        os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
-
-        try:
-            os.makedirs(hf_cache_dir, exist_ok=True)
-            logger.info("HF cache directory: %s", hf_cache_dir)
-        except Exception as e:
-            # Fallback to /tmp/hf_cache if DATA_DIR was set but not writable
-            tmp_cache_dir = os.path.join("/tmp", "hf_cache")
-            logger.warning("Failed to create cache in %s: %s, trying %s", data_dir, str(e), tmp_cache_dir)
-            hf_cache_dir = tmp_cache_dir
-            os.environ["HF_HOME"] = hf_cache_dir
-            os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
-            os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
-            try:
-                os.makedirs(hf_cache_dir, exist_ok=True)
-                logger.info("HF cache directory (tmp): %s", hf_cache_dir)
-            except Exception as e_tmp:
-                # Final fallback to user home (local dev)
-                logger.warning("Failed to create cache in /tmp: %s, trying user home", str(e_tmp))
-                default_home_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
-                hf_cache_dir = default_home_cache
-                os.environ["HF_HOME"] = hf_cache_dir
-                os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
-                os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
-                try:
-                    os.makedirs(hf_cache_dir, exist_ok=True)
-                    logger.info("HF cache directory (home): %s", hf_cache_dir)
-                except Exception as e2:
-                    logger.error("Failed to create cache directory: %s", str(e2))
-                    raise RuntimeError(f"Cannot create Hugging Face cache directory: {str(e2)}")

-        else:
-            # Ensure environment variables reflect the final cache dir
-            os.environ["HF_HOME"] = hf_cache_dir
-            os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_dir
-            os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
-        # Avoid libgomp warning by setting a valid integer
         os.environ.setdefault("OMP_NUM_THREADS", "1")
-
-        try:
-            # Decide whether to use ControlNet based on model_id
-            wants_controlnet = "control" in self.model_id.lower()

-            if wants_controlnet:
-                # Try loading as ControlNet with Stable Diffusion
-                logger.info("Attempting to load model as ControlNet: %s", self.model_id)
                 try:
-                    # Load ControlNet model
-                    self.controlnet = ControlNetModel.from_pretrained(
-                        self.model_id,
-                        torch_dtype=self.dtype,
-                        token=self.hf_token,
-                        cache_dir=hf_cache_dir
-                    )
-
-                    # Try SDXL first, fallback to SD 1.5
-                    try:
-                        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-                            "stabilityai/stable-diffusion-xl-base-1.0",
-                            controlnet=self.controlnet,
-                            torch_dtype=self.dtype,
-                            safety_checker=None,
-                            requires_safety_checker=False,
-                            token=self.hf_token,
-                            cache_dir=hf_cache_dir
-                        )
-                        logger.info("Loaded with SDXL base model")
-                    except Exception:
-                        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
-                            "runwayml/stable-diffusion-v1-5",
-                            controlnet=self.controlnet,
-                            torch_dtype=self.dtype,
-                            safety_checker=None,
-                            requires_safety_checker=False,
-                            token=self.hf_token,
-                            cache_dir=hf_cache_dir
-                        )
-                        logger.info("Loaded with SD 1.5 base model")
-
-                    self.pipe.to(self.device)
-
-                    # Enable memory efficient attention if available
-                    if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
-                        try:
-                            self.pipe.enable_xformers_memory_efficient_attention()
-                            logger.info("XFormers memory efficient attention enabled")
-                        except Exception as e:
-                            logger.warning("Could not enable XFormers: %s", str(e))
-
-                    logger.info("ColorizeNet model loaded successfully as ControlNet")
-                    self.model_type = "controlnet"
-                except Exception as e:
-                    logger.warning("Failed to load as ControlNet: %s", str(e))
-                    wants_controlnet = False  # fall through to pipeline
-
-            if not wants_controlnet:
-                # Load as image-to-image pipeline
-                logger.info("Trying to load as image-to-image pipeline...")
-                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-                    self.model_id,
-                    torch_dtype=self.dtype,
-                    safety_checker=None,
-                    requires_safety_checker=False,
-                    use_safetensors=True,
-                    cache_dir=hf_cache_dir,
-                    token=self.hf_token
-                ).to(self.device)
-                logger.info("ColorizeNet model loaded using image-to-image pipeline")
-                self.model_type = "pipeline"
-
-        except Exception as e:
-            logger.error("Failed to load ColorizeNet model: %s", str(e))
-            raise RuntimeError(f"Could not load ColorizeNet model: {str(e)}")
-
-    def preprocess_image(self, image: Image.Image) -> Image.Image:
-        """
-        Preprocess image for colorization
-
-        Args:
-            image: PIL Image
-
-        Returns:
-            Preprocessed PIL Image
-        """
-        # Convert to grayscale if needed
-        if image.mode != "L":
-            # Convert to grayscale
-            image = image.convert("L")
-
-        # Convert back to RGB (grayscale image with 3 channels)
-        image = image.convert("RGB")
-
-        # Resize to standard size (512x512 for SD models)
-        image = image.resize((512, 512), Image.Resampling.LANCZOS)
-
-        return image
-
-    def colorize(self, image: Image.Image, num_inference_steps: int = None) -> Image.Image:
-        """
-        Colorize a grayscale image
-
-        Args:
-            image: PIL Image (grayscale or color)
-            num_inference_steps: Number of inference steps (auto-adjusted for CPU/GPU)
-
-        Returns:
-            Colorized PIL Image
-        """
         try:
-            # Optimize inference steps based on device
-            if num_inference_steps is None:
-                # Use fewer steps on CPU for faster processing
-                num_inference_steps = 8 if self.device == "cpu" else 20
-
-            # Preprocess image
-            control_image = self.preprocess_image(image)
             original_size = image.size
-
-            # Prepare prompt for colorization
-            prompt = "colorize this black and white image, high quality, detailed, vibrant colors, natural colors"
-            negative_prompt = "black and white, grayscale, monochrome, low quality, blurry, desaturated"
-
-            # Adjust guidance scale for CPU (lower = faster)
-            guidance_scale = 5.0 if self.device == "cpu" else 7.5
-
-            # Generate colorized image based on model type
-            if self.model_type == "controlnet":
-                # Use ControlNet pipeline
-                result = self.pipe(
-                    prompt=prompt,
-                    image=control_image,
-                    negative_prompt=negative_prompt,
-                    num_inference_steps=num_inference_steps,
-                    guidance_scale=guidance_scale,
-                    controlnet_conditioning_scale=1.0,
-                    generator=torch.Generator(device=self.device).manual_seed(42)
-                )
-
-                if isinstance(result, dict) and "images" in result:
-                    colorized = result["images"][0]
-                elif isinstance(result, list) and len(result) > 0:
-                    colorized = result[0]
-                else:
-                    colorized = result
-            else:
-                # Use pipeline directly
-                result = self.pipe(
-                    prompt=prompt,
-                    image=control_image,
-                    negative_prompt=negative_prompt,
-                    num_inference_steps=num_inference_steps,
-                    guidance_scale=guidance_scale,
-                    strength=1.0
-                )
-
-                if isinstance(result, dict) and "images" in result:
-                    colorized = result["images"][0]
-                elif isinstance(result, list) and len(result) > 0:
-                    colorized = result[0]
-                else:
-                    colorized = result
-
-            # Ensure we have a PIL Image
-            if not isinstance(colorized, Image.Image):
-                if isinstance(colorized, np.ndarray):
-                    # Handle numpy array
-                    if colorized.dtype != np.uint8:
-                        colorized = (colorized * 255).astype(np.uint8)
-                    if len(colorized.shape) == 3 and colorized.shape[2] == 3:
-                        colorized = Image.fromarray(colorized, 'RGB')
-                    else:
-                        colorized = Image.fromarray(colorized)
-                elif torch.is_tensor(colorized):
-                    # Handle torch tensor
-                    colorized = colorized.cpu().permute(1, 2, 0).numpy()
-                    colorized = (colorized * 255).astype(np.uint8)
-                    colorized = Image.fromarray(colorized, 'RGB')
-                else:
-                    raise ValueError(f"Unexpected output type: {type(colorized)}")
-
-            # Resize back to original size
-            if original_size != (512, 512):
                 colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
-
-            return colorized
-
-        except Exception as e:
-            logger.error("Error during colorization: %s", str(e))
             raise
 
 """
+Colorize model wrapper replicating the behaviour of the
+`fffiloni/text-guided-image-colorization` Space.
 """
+
+from __future__ import annotations
+
 import logging
 import os
+from typing import Tuple
+
 import torch
 from PIL import Image
+from diffusers import (
+    AutoencoderKL,
+    ControlNetModel,
+    StableDiffusionXLControlNetPipeline,
+    UNet2DConditionModel,
+)
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from transformers import BlipForConditionalGeneration, BlipProcessor
+
 from app.config import settings

 logger = logging.getLogger(__name__)

+
+def _ensure_cache_dir() -> str:
+    """Ensure we have a writable Hugging Face cache directory."""
+    data_dir = os.getenv("DATA_DIR")
+    candidate_dirs = []
+    if data_dir:
+        candidate_dirs.append(os.path.join(data_dir, "hf_cache"))
+    candidate_dirs.extend(
+        [
+            os.path.join("/tmp", "hf_cache"),
+            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
+        ]
+    )
+
+    for path in candidate_dirs:
+        try:
+            os.makedirs(path, exist_ok=True)
+            logger.info("Using HF cache directory: %s", path)
+            os.environ["HF_HOME"] = path
+            os.environ["HUGGINGFACE_HUB_CACHE"] = path
+            os.environ["TRANSFORMERS_CACHE"] = path
+            return path
+        except Exception as exc:  # pragma: no cover - best effort
+            logger.warning("Failed to create cache dir %s: %s", path, exc)
+
+    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")
+
+
+def _apply_color(luminance_image: Image.Image, color_map: Image.Image) -> Image.Image:
+    """Merge the L channel of the grayscale control image with the AB channels of the generated image."""
+    image_lab = luminance_image.convert("LAB")
+    color_map_lab = color_map.convert("LAB")
+    l_channel, _, _ = image_lab.split()
+    _, a_channel, b_channel = color_map_lab.split()
+    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
+    return merged.convert("RGB")
+
+
+def _remove_unlikely_words(prompt: str) -> str:
+    """Clean up BLIP captions to avoid misleading descriptors."""
+    unlikely_words = []
+
+    decades = [f"{i}s" for i in range(1900, 2000)]
+    years = [f"{i}" for i in range(1900, 2000)]
+    years_with_word = [f"year {i}" for i in range(1900, 2000)]
+    circa_years = [f"circa {i}" for i in range(1900, 2000)]
+
+    expanded = [
+        [f"{d[0]} {d[1]} {d[2]} {d[3]} s" for d in decades],
+        [f"{d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
+        [f"year {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
+        [f"circa {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
+    ]
+
+    manual_terms = [
+        "black and white,", "black and white", "black & white,", "black & white",
+        "circa", "monochrome,", "monochrome", "bw", "bw,", "b&w", "b&w,",
+        "grainy", "grainy photo", "grainy photograph", "grainy footage",
+        "black-and-white", "black - and - white", "black on white",
+        "historical photo", "historic photo", "restored", "desaturated",
+        "low contrast", "blurry", "overcast", "taken in", "photo taken in",
+        ", photo", ", photograph",
+    ]
+
+    for seq in expanded:
+        unlikely_words.extend(seq)
+    unlikely_words.extend(decades + years + years_with_word + circa_years + manual_terms)
+
+    cleaned = prompt
+    for word in unlikely_words:
+        cleaned = cleaned.replace(word, "")
+    return cleaned.strip(" ,")
+
+
 class ColorizeModel:
+    """Colorization model wrapper."""
+
+    CONTROLNET_REPO = "nickpai/sdxl_light_caption_output"
+    CONTROLNET_SUBDIR = os.path.join("checkpoint-30000", "controlnet")
+    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+    LIGHTNING_REPO = "ByteDance/SDXL-Lightning"
+    LIGHTNING_WEIGHTS = "sdxl_lightning_8step_unet.safetensors"
+    CAPTION_MODEL = "Salesforce/blip-image-captioning-large"
+
+    def __init__(self, model_id: str | None = None) -> None:
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info("Using device: %s", self.device)
+
+        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
         os.environ.setdefault("OMP_NUM_THREADS", "1")

+        self.hf_token = (
+            os.getenv("HF_TOKEN")
+            or os.getenv("HUGGINGFACE_HUB_TOKEN")
+            or None
+        )
+        self.cache_dir = _ensure_cache_dir()
+
+        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
+        self.guidance_scale = settings.GUIDANCE_SCALE
+        self.controlnet_scale = settings.CONTROLNET_SCALE
+        self.positive_prompt = settings.POSITIVE_PROMPT
+        self.negative_prompt = settings.NEGATIVE_PROMPT
+        self.caption_prefix = settings.CAPTION_PREFIX
+        self.seed = settings.COLORIZE_SEED
+
+        self.model_id = model_id or settings.MODEL_ID
+
+        self._load_pipeline()
+        self._load_caption_model()
+        self.last_caption: str | None = None
+
+    # --------------------------------------------------------------------- #
+    # Initialisation helpers
+    # --------------------------------------------------------------------- #
+    def _download_controlnet(self) -> str:
+        logger.info("Downloading ControlNet snapshot: %s", self.CONTROLNET_REPO)
+        local_dir = os.path.join(self.cache_dir, "sdxl_light_caption_output")
+        path = snapshot_download(
+            repo_id=self.CONTROLNET_REPO,
+            local_dir=local_dir,
+            local_dir_use_symlinks=False,
+            token=self.hf_token,
+        )
+        controlnet_path = os.path.join(path, self.CONTROLNET_SUBDIR)
+        if not os.path.isdir(controlnet_path):
+            raise RuntimeError(f"ControlNet weights not found at {controlnet_path}")
+        return controlnet_path
+
+    def _load_pipeline(self) -> None:
+        controlnet_path = self._download_controlnet()
+
+        logger.info("Loading SDXL components...")
+        vae = AutoencoderKL.from_pretrained(
+            self.BASE_MODEL,
+            subfolder="vae",
+            torch_dtype=self.dtype,
+            token=self.hf_token,
+        )
+        unet = UNet2DConditionModel.from_config(
+            self.BASE_MODEL,
+            subfolder="unet",
+            token=self.hf_token,
+        )
+        lightning_path = hf_hub_download(
+            repo_id=self.LIGHTNING_REPO,
+            filename=self.LIGHTNING_WEIGHTS,
+            token=self.hf_token,
+        )
+        unet.load_state_dict(load_file(lightning_path))
+
+        controlnet = ControlNetModel.from_pretrained(
+            controlnet_path,
+            torch_dtype=self.dtype,
+        )
+
+        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            self.BASE_MODEL,
+            vae=vae,
+            unet=unet,
+            controlnet=controlnet,
+            torch_dtype=self.dtype,
+            safety_checker=None,
+            requires_safety_checker=False,
+            token=self.hf_token,
+        )
+        self.pipe.set_progress_bar_config(disable=True)
+
+        if self.device.type == "cuda":
+            self.pipe.to(self.device, dtype=self.dtype)
+            if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
                 try:
+                    self.pipe.enable_xformers_memory_efficient_attention()
+                except Exception as exc:  # pragma: no cover
+                    logger.warning("Could not enable xformers attention: %s", exc)
+        else:
+            self.pipe.to(self.device, dtype=self.dtype)
+
+        logger.info("Colorization pipeline ready.")
+
+    def _load_caption_model(self) -> None:
+        logger.info("Loading BLIP captioning model...")
+        processor = BlipProcessor.from_pretrained(self.CAPTION_MODEL, token=self.hf_token)
+        model = BlipForConditionalGeneration.from_pretrained(
+            self.CAPTION_MODEL,
+            torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
+            token=self.hf_token,
+        )
+        self.caption_processor = processor
+        self.caption_model = model.to(self.device)
+
+    # --------------------------------------------------------------------- #
+    # Public API
+    # --------------------------------------------------------------------- #
+    def caption_image(self, image: Image.Image) -> str:
+        """Generate a cleaned caption for the image."""
+        inputs = self.caption_processor(
+            image,
+            self.caption_prefix,
+            return_tensors="pt",
+        ).to(self.device)
+
+        # BLIP on CPU expects float32 pixel values; leave integer token ids untouched
+        if self.device.type != "cuda":
+            inputs = {
+                k: v.to(torch.float32) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+                for k, v in inputs.items()
+            }
+
+        with torch.inference_mode():
+            caption_ids = self.caption_model.generate(**inputs)
+        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
+        cleaned_caption = _remove_unlikely_words(caption)
+        return cleaned_caption or caption
+
+    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
+        """Colorize a grayscale image."""
         try:
             original_size = image.size
+            control_image = image.convert("L").convert("RGB").resize(
+                (512, 512), Image.Resampling.LANCZOS
+            )
+
+            caption = self.caption_image(image)
+            self.last_caption = caption
+
+            prompt_parts = [caption]
+            if self.positive_prompt:
+                prompt_parts.insert(0, self.positive_prompt)
+            final_prompt = ", ".join([part for part in prompt_parts if part])
+
+            negative_prompt = self.negative_prompt or None
+            steps = num_inference_steps or self.num_inference_steps
+            generator = torch.Generator(device=self.device).manual_seed(self.seed)
+
+            logger.info("Running SDXL pipeline with prompt: %s", final_prompt)
+            result = self.pipe(
+                prompt=final_prompt,
+                negative_prompt=negative_prompt,
+                image=control_image,
+                num_inference_steps=steps,
+                guidance_scale=self.guidance_scale,
+                controlnet_conditioning_scale=self.controlnet_scale,
+                generator=generator,
+            )
+
+            generated_image = result.images[0]
+            colorized = _apply_color(control_image, generated_image)
+            if colorized.size != original_size:
                 colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
+
+            return colorized, caption
+        except Exception as exc:
+            logger.exception("Error during colorization: %s", exc)
             raise
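Note that the rewrite keeps the input's luminance and takes only the A/B chroma channels from the SDXL output (`_apply_color`), so original detail survives even if the diffusion result drifts. For reference, a minimal usage sketch of the new wrapper; the image paths are hypothetical, while the class name and the `(image, caption)` return shape come from the diff above:

    from PIL import Image

    from app.colorize_model import ColorizeModel

    # Instantiating downloads the SDXL base weights, the SDXL-Lightning
    # 8-step UNet, the caption ControlNet snapshot, and the BLIP captioner.
    model = ColorizeModel()

    gray = Image.open("old_photo.jpg")         # hypothetical input
    colorized, caption = model.colorize(gray)  # returns image + BLIP caption
    colorized.save("old_photo_color.png")      # hypothetical output
    print(caption)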
 
app/config.py CHANGED
@@ -18,8 +18,20 @@ class Settings(BaseSettings):
     BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")

     # Model settings
-    MODEL_ID: str = os.getenv("MODEL_ID", "lllyasviel/control_v11f1e_sd15_color")
-    NUM_INFERENCE_STEPS: int = int(os.getenv("NUM_INFERENCE_STEPS", "20"))
+    MODEL_ID: str = os.getenv("MODEL_ID", "nickpai/sdxl_light_caption_output")
+    NUM_INFERENCE_STEPS: int = int(os.getenv("NUM_INFERENCE_STEPS", "8"))
+    POSITIVE_PROMPT: str = os.getenv(
+        "POSITIVE_PROMPT",
+        "high quality color photo, vibrant natural colors, detailed lighting"
+    )
+    NEGATIVE_PROMPT: str = os.getenv(
+        "NEGATIVE_PROMPT",
+        "low quality, monochrome, black and white, desaturated, blurry, grainy"
+    )
+    GUIDANCE_SCALE: float = float(os.getenv("GUIDANCE_SCALE", "1.0"))
+    CONTROLNET_SCALE: float = float(os.getenv("CONTROLNET_SCALE", "1.0"))
+    CAPTION_PREFIX: str = os.getenv("CAPTION_PREFIX", "a photography of")
+    COLORIZE_SEED: int = int(os.getenv("COLORIZE_SEED", "123"))

     # Storage settings
     UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "uploads")
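All of the new knobs read from the environment with the defaults shown. A minimal sketch of overriding them, assuming the usual behaviour that `os.getenv()` defaults are baked in when `app.config` is first imported (values here are illustrative):

    import os

    # Must be set before app.config is imported.
    os.environ["NUM_INFERENCE_STEPS"] = "8"  # matches the SDXL-Lightning 8-step UNet
    os.environ["GUIDANCE_SCALE"] = "1.0"
    os.environ["COLORIZE_SEED"] = "123"

    from app.config import settings

    assert settings.NUM_INFERENCE_STEPS == 8
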
app/main.py CHANGED
@@ -254,7 +254,7 @@ async def colorize_image(

         # Colorize the image
         logger.info("Colorizing image...")
-        colorized_image = colorize_model.colorize(image)
+        colorized_image, caption = colorize_model.colorize(image)

         # Save colorized image
         file_id = str(uuid.uuid4())
@@ -274,7 +274,8 @@
             "result_id": file_id,
             "download_url": download_url,
             "api_download_url": api_download_url,
-            "filename": result_filename
+            "filename": result_filename,
+            "caption": caption
         }
     except Exception as e:
         logger.error("Error colorizing image: %s", str(e))
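With the handler now unpacking `(image, caption)`, clients see one extra response field. A hedged sketch of calling the endpoint; the route path and upload field name are assumptions, while the response keys match the handler above:

    import requests

    with open("old_photo.jpg", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/colorize",  # route path assumed
            files={"file": fh},                # field name assumed
        )
    resp.raise_for_status()
    data = resp.json()
    # keys: result_id, download_url, api_download_url, filename, caption
    print(data["caption"])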