Lasya18 committed on
Commit c75da4e (verified)
1 Parent(s): 01c9ed4

Upload utils.py

Files changed (1)
  1. utils.py +391 -0
utils.py ADDED
@@ -0,0 +1,391 @@
"""
Utility functions for the Interior Style Transfer Pipeline
"""
import cv2
import numpy as np
from PIL import Image
import os
from typing import Callable, List, Optional, Tuple
import json

def load_image_safe(image_path: str, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Safely load an image with error handling.

    Args:
        image_path: Path to the image file
        target_size: Optional target size (width, height)

    Returns:
        Loaded image as a numpy array (BGR)

    Raises:
        ValueError: If the image cannot be loaded
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")

    # Try to load with OpenCV first
    image = cv2.imread(image_path)
    if image is None:
        # Fall back to PIL
        try:
            pil_image = Image.open(image_path)
            image = np.array(pil_image)
            if len(image.shape) == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            elif len(image.shape) == 3 and image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        except Exception as e:
            raise ValueError(f"Could not load image {image_path}: {e}")

    if target_size:
        image = cv2.resize(image, target_size)

    return image

def save_image_safe(image: np.ndarray, output_path: str,
                    quality: int = 95) -> bool:
    """
    Safely save an image with error handling.

    Args:
        image: Image to save as a numpy array (BGR)
        output_path: Output file path
        quality: JPEG quality (1-100)

    Returns:
        True if successful, False otherwise
    """
    try:
        # Ensure the output directory exists (skip when saving to the current directory)
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Save with OpenCV
        success = cv2.imwrite(output_path, image)

        if not success:
            # Fall back to PIL
            if len(image.shape) == 3 and image.shape[2] == 3:
                pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            else:
                pil_image = Image.fromarray(image)

            pil_image.save(output_path, quality=quality)
            success = True

        return success
    except Exception as e:
        print(f"Error saving image to {output_path}: {e}")
        return False
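
# Usage sketch (illustrative only): round-trips a synthetic gray image through
# save_image_safe/load_image_safe. The output path is hypothetical.
def _demo_image_io() -> None:
    sample = np.full((512, 512, 3), 128, dtype=np.uint8)  # plain gray stand-in for a room photo
    if save_image_safe(sample, "demo_output/sample.jpg", quality=90):
        reloaded = load_image_safe("demo_output/sample.jpg", target_size=(256, 256))
        print("Reloaded image shape:", reloaded.shape)  # expected: (256, 256, 3)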

def validate_image_pair(user_room: np.ndarray, inspiration_room: np.ndarray) -> Tuple[bool, str]:
    """
    Validate that two images are suitable for style transfer.

    Args:
        user_room: User room image
        inspiration_room: Inspiration room image

    Returns:
        Tuple of (is_valid, error_message)
    """
    # Check image dimensions
    if user_room.shape != inspiration_room.shape:
        return False, f"Image dimensions don't match: {user_room.shape} vs {inspiration_room.shape}"

    # Check minimum size
    min_size = 256
    if user_room.shape[0] < min_size or user_room.shape[1] < min_size:
        return False, f"Images too small. Minimum size: {min_size}x{min_size}"

    # Check aspect ratio (roughly square images work best)
    aspect_ratio = user_room.shape[1] / user_room.shape[0]
    if aspect_ratio < 0.5 or aspect_ratio > 2.0:
        return False, f"Extreme aspect ratio: {aspect_ratio:.2f}. Square images work best."

    # Check whether either image is too dark or too bright
    user_brightness = np.mean(cv2.cvtColor(user_room, cv2.COLOR_BGR2GRAY))
    inspiration_brightness = np.mean(cv2.cvtColor(inspiration_room, cv2.COLOR_BGR2GRAY))

    if user_brightness < 30 or user_brightness > 225:
        return False, f"User room too {'dark' if user_brightness < 30 else 'bright'}: {user_brightness:.1f}"

    if inspiration_brightness < 30 or inspiration_brightness > 225:
        return False, f"Inspiration room too {'dark' if inspiration_brightness < 30 else 'bright'}: {inspiration_brightness:.1f}"

    return True, "Images are valid for style transfer"
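
# Usage sketch (illustrative only): validates two synthetic, same-sized images.
def _demo_validation() -> None:
    room_a = np.full((512, 512, 3), 120, dtype=np.uint8)
    room_b = np.full((512, 512, 3), 140, dtype=np.uint8)
    is_valid, message = validate_image_pair(room_a, room_b)
    print(f"Valid: {is_valid} - {message}")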

def create_comparison_image(original: np.ndarray, result: np.ndarray,
                            title: str = "Style Transfer Comparison") -> np.ndarray:
    """
    Create a side-by-side comparison image.

    Args:
        original: Original user room image
        result: Style transfer result
        title: Title for the comparison

    Returns:
        Comparison image
    """
    # Ensure both images have the same dimensions
    if original.shape != result.shape:
        result = cv2.resize(result, (original.shape[1], original.shape[0]))

    # Create the comparison image
    comparison = np.hstack([original, result])

    # Add the title
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2

    # Calculate the text position
    text_size = cv2.getTextSize(title, font, font_scale, thickness)[0]
    text_x = (comparison.shape[1] - text_size[0]) // 2
    text_y = 50

    # Add a white background behind the title text
    cv2.rectangle(comparison, (text_x - 10, text_y - 30),
                  (text_x + text_size[0] + 10, text_y + 10), (255, 255, 255), -1)

    # Add the title text
    cv2.putText(comparison, title, (text_x, text_y), font, font_scale, (0, 0, 0), thickness)

    # Add labels
    cv2.putText(comparison, "Original", (50, comparison.shape[0] - 30),
                font, 0.7, (255, 255, 255), 2)
    cv2.putText(comparison, "Result", (original.shape[1] + 50, comparison.shape[0] - 30),
                font, 0.7, (255, 255, 255), 2)

    return comparison

def enhance_image_quality(image: np.ndarray,
                          sharpness: float = 0.3,
                          contrast: float = 1.1,
                          saturation: float = 1.1) -> np.ndarray:
    """
    Enhance image quality with sharpening, contrast and saturation filters.

    Args:
        image: Input image (BGR)
        sharpness: Sharpening strength (0.0 to 1.0)
        contrast: Contrast multiplier
        saturation: Saturation multiplier

    Returns:
        Enhanced image
    """
    enhanced = image.copy()

    # Sharpening: blend an identity kernel with a Laplacian so the kernel
    # sums to 1, preserving overall brightness while emphasising edges.
    if sharpness > 0:
        identity = np.zeros((3, 3), dtype=np.float32)
        identity[1, 1] = 1.0
        laplacian = np.array([[-1, -1, -1],
                              [-1, 8, -1],
                              [-1, -1, -1]], dtype=np.float32)
        kernel = identity + sharpness * laplacian
        enhanced = cv2.filter2D(enhanced, -1, kernel)

    # Contrast adjustment
    if contrast != 1.0:
        enhanced = np.clip(enhanced * contrast, 0, 255).astype(np.uint8)

    # Saturation adjustment
    if saturation != 1.0:
        hsv = cv2.cvtColor(enhanced, cv2.COLOR_BGR2HSV).astype(np.float32)
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * saturation, 0, 255)
        enhanced = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

    return enhanced
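
# Usage sketch (illustrative only): sharpens a synthetic gradient image.
def _demo_enhancement() -> None:
    gradient = np.tile(np.arange(256, dtype=np.uint8), (256, 1))
    gradient = cv2.cvtColor(gradient, cv2.COLOR_GRAY2BGR)
    sharpened = enhance_image_quality(gradient, sharpness=0.5, contrast=1.0, saturation=1.0)
    # Because the sharpening kernel sums to 1, mean brightness stays roughly unchanged.
    print("Mean before/after sharpening:", gradient.mean(), sharpened.mean())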

def create_progress_bar(total: int, description: str = "Processing") -> Callable[[int], None]:
    """
    Create a simple console progress bar.

    Args:
        total: Total number of steps
        description: Description of the process

    Returns:
        Function that updates the progress display
    """
    def update_progress(current: int) -> None:
        percentage = (current / total) * 100
        bar_length = 30
        filled_length = int(bar_length * current // total)
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        print(f'\r{description}: |{bar}| {percentage:.1f}% ({current}/{total})', end='')
        if current == total:
            print()

    return update_progress
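
# Usage sketch (illustrative only): the returned closure is called once per completed step.
def _demo_progress() -> None:
    update = create_progress_bar(total=5, description="Styling")
    for step in range(1, 6):
        update(step)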

def save_metadata(metadata: dict, output_path: str) -> bool:
    """
    Save metadata to a JSON file.

    Args:
        metadata: Dictionary of metadata
        output_path: Output file path

    Returns:
        True if successful, False otherwise
    """
    try:
        # Ensure the output directory exists (skip when saving to the current directory)
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

        return True
    except Exception as e:
        print(f"Error saving metadata to {output_path}: {e}")
        return False

def load_metadata(metadata_path: str) -> Optional[dict]:
    """
    Load metadata from a JSON file.

    Args:
        metadata_path: Path to the metadata file

    Returns:
        Loaded metadata dictionary, or None on failure
    """
    try:
        with open(metadata_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading metadata from {metadata_path}: {e}")
        return None

def calculate_image_similarity(img1: np.ndarray, img2: np.ndarray) -> float:
    """
    Calculate the similarity between two images using structural similarity (SSIM).

    Args:
        img1: First image
        img2: Second image

    Returns:
        Similarity score (0.0 to 1.0, higher is more similar)
    """
    try:
        from skimage.metrics import structural_similarity as ssim

        # Ensure the same dimensions
        if img1.shape != img2.shape:
            img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

        # Convert to grayscale for SSIM
        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

        # Calculate SSIM
        similarity = ssim(gray1, gray2)
        return max(0.0, similarity)  # Ensure non-negative

    except ImportError:
        # Fallback: simple MSE-based similarity when scikit-image is unavailable
        if img1.shape != img2.shape:
            img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

        mse = np.mean((img1.astype(np.float32) - img2.astype(np.float32)) ** 2)
        max_mse = 255 ** 2
        similarity = 1.0 - (mse / max_mse)
        return max(0.0, similarity)
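
# Usage sketch (illustrative only): an image compared with itself scores 1.0,
# and the score drops as noise is added.
def _demo_similarity() -> None:
    base = np.full((256, 256, 3), 100, dtype=np.uint8)
    noise = np.random.randint(-40, 40, base.shape)
    noisy = np.clip(base.astype(np.int16) + noise, 0, 255).astype(np.uint8)
    print("Self-similarity:", calculate_image_similarity(base, base))
    print("Noisy similarity:", calculate_image_similarity(base, noisy))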

def create_thumbnail(image: np.ndarray, max_size: int = 200) -> np.ndarray:
    """
    Create a thumbnail version of an image.

    Args:
        image: Input image
        max_size: Maximum dimension size

    Returns:
        Thumbnail image
    """
    height, width = image.shape[:2]

    if height <= max_size and width <= max_size:
        return image.copy()

    # Calculate new dimensions while maintaining the aspect ratio
    if height > width:
        new_height = max_size
        new_width = int(width * max_size / height)
    else:
        new_width = max_size
        new_height = int(height * max_size / width)

    thumbnail = cv2.resize(image, (new_width, new_height))
    return thumbnail

def batch_resize_images(images: List[np.ndarray],
                        target_size: Tuple[int, int]) -> List[np.ndarray]:
    """
    Resize a list of images to the same target size.

    Args:
        images: List of input images
        target_size: Target size (width, height)

    Returns:
        List of resized images
    """
    resized_images = []

    for image in images:
        resized = cv2.resize(image, target_size)
        resized_images.append(resized)

    return resized_images

def create_image_grid(images: List[np.ndarray],
                      grid_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """
    Create a grid layout of images.

    Args:
        images: List of images to arrange in a grid
        grid_size: Grid dimensions (rows, cols). If None, auto-calculated

    Returns:
        Grid image
    """
    if not images:
        return np.array([])

    if grid_size is None:
        # Auto-calculate the grid size
        n_images = len(images)
        cols = int(np.ceil(np.sqrt(n_images)))
        rows = int(np.ceil(n_images / cols))
        grid_size = (rows, cols)

    rows, cols = grid_size

    # Ensure all images have the same size
    target_size = (images[0].shape[1], images[0].shape[0])
    resized_images = batch_resize_images(images, target_size)

    # Build the grid row by row
    grid_rows = []
    for i in range(rows):
        row_images = []
        for j in range(cols):
            idx = i * cols + j
            if idx < len(resized_images):
                row_images.append(resized_images[idx])
            else:
                # Fill empty cells with black
                empty_image = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
                row_images.append(empty_image)

        row = np.hstack(row_images)
        grid_rows.append(row)

    grid = np.vstack(grid_rows)
    return grid
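
# Minimal end-to-end sketch (illustrative only; uses synthetic data, not real room photos).
# Running the module directly exercises several of the utilities above.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    user = rng.integers(60, 200, size=(512, 512, 3), dtype=np.uint8)
    styled = cv2.GaussianBlur(user, (15, 15), 0)  # stand-in for a style transfer result

    ok, msg = validate_image_pair(user, styled)
    print(msg)

    comparison = create_comparison_image(user, styled)
    thumb = create_thumbnail(comparison, max_size=200)
    grid = create_image_grid([user, styled, thumb])
    print("Comparison:", comparison.shape, "thumbnail:", thumb.shape, "grid:", grid.shape)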