Camais03 committed on
Commit e7d3e33 · verified · 1 Parent(s): ac1737a

Upload 6 files

utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # Make utils a proper Python package
utils/file_utils.py ADDED
@@ -0,0 +1,127 @@
+ """
+ File utilities for the Image Tagger application.
+ """
+
+ import os
+ import time
+
+ def save_tags_to_file(image_path, all_tags, original_filename=None, custom_dir=None, overwrite=False):
+     """
+     Save tags to a text file in a dedicated 'saved_tags' folder or a custom directory.
+
+     Args:
+         image_path: Path to the original image
+         all_tags: List of all tags to save
+         original_filename: Original filename if uploaded through Streamlit
+         custom_dir: Custom directory to save tags to (if None, uses the 'saved_tags' folder)
+         overwrite: If False, append a timestamp rather than overwrite an existing file
+
+     Returns:
+         Path to the saved file
+     """
+     # Determine the save directory
+     if custom_dir and os.path.isdir(custom_dir):
+         save_dir = custom_dir
+     else:
+         # Create a dedicated folder for saved tags in the app's root directory
+         app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+         save_dir = os.path.join(app_dir, "saved_tags")
+
+     # Ensure the directory exists
+     os.makedirs(save_dir, exist_ok=True)
+
+     # Determine the filename
+     if original_filename:
+         # For uploaded files, use the original filename
+         base_name = os.path.splitext(original_filename)[0]
+     else:
+         # For non-uploaded files, use the image path
+         base_name = os.path.splitext(os.path.basename(image_path))[0]
+
+     # Create the output path
+     output_path = os.path.join(save_dir, f"{base_name}.txt")
+
+     # If overwrite is False and the file exists, add a timestamp to avoid overwriting
+     if not overwrite and os.path.exists(output_path):
+         timestamp = time.strftime("%Y%m%d-%H%M%S")
+         output_path = os.path.join(save_dir, f"{base_name}_{timestamp}.txt")
+
+     # Write the tags to the file
+     with open(output_path, 'w', encoding='utf-8') as f:
+         if all_tags:
+             # Add a comma after each tag, including the last one
+             tag_text = ", ".join(all_tags) + ","
+             f.write(tag_text)
+
+     return output_path
+
+ def get_default_save_locations():
+     """
+     Get default save locations for tag files.
+
+     Returns:
+         List of default save locations
+     """
+     # App directory
+     app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+     save_dir = os.path.join(app_dir, "saved_tags")
+
+     # Common user directories
+     desktop_dir = os.path.expanduser("~/Desktop")
+     download_dir = os.path.expanduser("~/Downloads")
+     documents_dir = os.path.expanduser("~/Documents")
+
+     # List of default save locations
+     save_locations = [
+         save_dir,
+         desktop_dir,
+         download_dir,
+         documents_dir,
+     ]
+
+     # Ensure the directories exist
+     for folder in save_locations:
+         os.makedirs(folder, exist_ok=True)
+
+     return save_locations
+
+ def apply_category_limits(result, category_limits):
+     """
+     Apply category limits to a result dictionary.
+
+     Args:
+         result: Result dictionary containing tags and all_tags
+         category_limits: Dictionary mapping categories to their tag limits
+                          (0 = exclude category, -1 = no limit/include all)
+
+     Returns:
+         Updated result dictionary with limits applied
+     """
+     if not category_limits or not result['success']:
+         return result
+
+     # Get the filtered tags
+     filtered_tags = result['tags']
+
+     # Apply limits to each category
+     for category, cat_tags in list(filtered_tags.items()):
+         # Get the limit for this category, defaulting to -1 (no limit)
+         limit = category_limits.get(category, -1)
+
+         if limit == 0:
+             # Exclude this category entirely
+             del filtered_tags[category]
+         elif limit > 0 and len(cat_tags) > limit:
+             # Limit to the top N tags for this category
+             filtered_tags[category] = cat_tags[:limit]
+
+     # Regenerate the all_tags list after applying limits
+     all_tags = []
+     for category, cat_tags in filtered_tags.items():
+         for tag, _ in cat_tags:
+             all_tags.append(tag)
+
+     # Update the result with the limited tags
+     result['tags'] = filtered_tags
+     result['all_tags'] = all_tags
+
+     return result
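
For orientation, a minimal usage sketch of these helpers; the file name, tag list, and limits below are hypothetical:

    from utils.file_utils import save_tags_to_file, apply_category_limits

    # Hypothetical result in the shape the tagging pipeline produces
    result = {
        'success': True,
        'tags': {
            'general': [('outdoors', 0.92), ('sky', 0.81), ('tree', 0.64)],
            'meta': [('photo', 0.77)],
        },
        'all_tags': ['outdoors', 'sky', 'tree', 'photo'],
    }

    # Keep at most two 'general' tags and drop 'meta' entirely
    result = apply_category_limits(result, {'general': 2, 'meta': 0})

    # Writes "outdoors, sky," to saved_tags/example.txt next to the app
    path = save_tags_to_file('example.jpg', result['all_tags'], overwrite=True)
    print(path)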
utils/image_processing.py ADDED
@@ -0,0 +1,511 @@
+ """
+ Image processing functions for the Image Tagger application.
+ """
+
+ import os
+ import traceback
+ import glob
+
+
+ def process_image(image_path, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1):
+     """
+     Process a single image and return the tags.
+
+     Args:
+         image_path: Path to the image
+         model: The image tagger model
+         thresholds: Thresholds dictionary
+         metadata: Metadata dictionary
+         threshold_profile: Selected threshold profile
+         active_threshold: Overall threshold value
+         active_category_thresholds: Category-specific thresholds
+         min_confidence: Minimum confidence to include in results
+
+     Returns:
+         Dictionary with tags, all probabilities, and other info
+     """
+     try:
+         # Run inference directly using the model's predict method
+         if threshold_profile in ["Category-specific", "High Precision", "High Recall"]:
+             results = model.predict(
+                 image_path=image_path,
+                 category_thresholds=active_category_thresholds
+             )
+         else:
+             results = model.predict(
+                 image_path=image_path,
+                 threshold=active_threshold
+             )
+
+         # Extract and organize all probabilities
+         all_probs = {}
+         probs = results['refined_probabilities'][0]  # Remove batch dimension
+
+         for idx in range(len(probs)):
+             prob_value = probs[idx].item()
+             if prob_value >= min_confidence:
+                 tag, category = model.dataset.get_tag_info(idx)
+
+                 if category not in all_probs:
+                     all_probs[category] = []
+
+                 all_probs[category].append((tag, prob_value))
+
+         # Sort tags by probability within each category
+         for category in all_probs:
+             all_probs[category] = sorted(
+                 all_probs[category],
+                 key=lambda x: x[1],
+                 reverse=True
+             )
+
+         # Get the filtered tags based on the selected threshold
+         tags = {}
+         for category, cat_tags in all_probs.items():
+             threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
+             tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
+
+         # Create a flat list of all tags above the threshold
+         all_tags = []
+         for category, cat_tags in tags.items():
+             for tag, _ in cat_tags:
+                 all_tags.append(tag)
+
+         return {
+             'tags': tags,
+             'all_probs': all_probs,
+             'all_tags': all_tags,
+             'success': True
+         }
+
+     except Exception as e:
+         print(f"Error processing {image_path}: {str(e)}")
+         traceback.print_exc()
+         return {
+             'tags': {},
+             'all_probs': {},
+             'all_tags': [],
+             'success': False,
+             'error': str(e)
+         }
+
+ def apply_category_limits(result, category_limits):
+     """
+     Apply category limits to a result dictionary.
+
+     Args:
+         result: Result dictionary containing tags and all_tags
+         category_limits: Dictionary mapping categories to their tag limits
+                          (0 = exclude category, -1 = no limit/include all)
+
+     Returns:
+         Updated result dictionary with limits applied
+     """
+     if not category_limits or not result['success']:
+         return result
+
+     # Get the filtered tags
+     filtered_tags = result['tags']
+
+     # Apply limits to each category
+     for category, cat_tags in list(filtered_tags.items()):
+         # Get the limit for this category, defaulting to -1 (no limit)
+         limit = category_limits.get(category, -1)
+
+         if limit == 0:
+             # Exclude this category entirely
+             del filtered_tags[category]
+         elif limit > 0 and len(cat_tags) > limit:
+             # Limit to the top N tags for this category
+             filtered_tags[category] = cat_tags[:limit]
+
+     # Regenerate the all_tags list after applying limits
+     all_tags = []
+     for category, cat_tags in filtered_tags.items():
+         for tag, _ in cat_tags:
+             all_tags.append(tag)
+
+     # Update the result with the limited tags
+     result['tags'] = filtered_tags
+     result['all_tags'] = all_tags
+
+     return result
+
+ def batch_process_images(folder_path, model, thresholds, metadata, threshold_profile, active_threshold,
+                          active_category_thresholds, save_dir=None, progress_callback=None,
+                          min_confidence=0.1, batch_size=1, category_limits=None):
+     """
+     Process all images in a folder, with optional batching for improved performance.
+
+     Args:
+         folder_path: Path to the folder containing images
+         model: The image tagger model
+         thresholds: Thresholds dictionary
+         metadata: Metadata dictionary
+         threshold_profile: Selected threshold profile
+         active_threshold: Overall threshold value
+         active_category_thresholds: Category-specific thresholds
+         save_dir: Directory to save tag files (if None, uses the default)
+         progress_callback: Optional callback for progress updates
+         min_confidence: Minimum confidence threshold
+         batch_size: Number of images to process at once (default: 1)
+         category_limits: Dictionary mapping categories to their tag limits
+                          (0 = exclude category, -1 = no limit/include all)
+
+     Returns:
+         Dictionary with results for each image
+     """
+     from .file_utils import save_tags_to_file  # Import here to avoid circular imports
+     import time
+
+     print(f"Starting batch processing on {folder_path} with batch size {batch_size}")
+     start_time = time.time()
+
+     # Find all image files in the folder
+     image_extensions = ['*.jpg', '*.jpeg', '*.png']
+     image_files = []
+
+     for ext in image_extensions:
+         image_files.extend(glob.glob(os.path.join(folder_path, ext)))
+         image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))
+
+     # Use a set to remove duplicate files (Windows filesystems are case-insensitive)
+     if os.name == 'nt':  # Windows
+         # Use lowercase paths for comparison on Windows
+         unique_paths = set()
+         unique_files = []
+         for file_path in image_files:
+             normalized_path = os.path.normpath(file_path).lower()
+             if normalized_path not in unique_paths:
+                 unique_paths.add(normalized_path)
+                 unique_files.append(file_path)
+         image_files = unique_files
+
+     # Sort files for a consistent processing order
+     image_files.sort()
+
+     if not image_files:
+         return {
+             'success': False,
+             'error': f"No images found in {folder_path}",
+             'results': {}
+         }
+
+     print(f"Found {len(image_files)} images to process")
+
+     # Use the provided save directory or create a default one
+     if save_dir is None:
+         app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+         save_dir = os.path.join(app_dir, "saved_tags")
+
+     # Ensure the directory exists
+     os.makedirs(save_dir, exist_ok=True)
+
+     # Process images in batches
+     results = {}
+     total_images = len(image_files)
+     processed = 0
+
+     for i in range(0, total_images, batch_size):
+         batch_start = time.time()
+         # Get the current batch of images
+         batch_files = image_files[i:i+batch_size]
+         batch_size_actual = len(batch_files)
+
+         print(f"Processing batch {i//batch_size + 1}/{(total_images + batch_size - 1)//batch_size}: {batch_size_actual} images")
+
+         if batch_size > 1:
+             # True batch processing for multiple images at once
+             try:
+                 batch_results = process_image_batch(
+                     image_paths=batch_files,
+                     model=model,
+                     thresholds=thresholds,
+                     metadata=metadata,
+                     threshold_profile=threshold_profile,
+                     active_threshold=active_threshold,
+                     active_category_thresholds=active_category_thresholds,
+                     min_confidence=min_confidence
+                 )
+
+                 # Process and save results for each image in the batch
+                 for j, image_path in enumerate(batch_files):
+                     # Update progress if a callback was provided
+                     if progress_callback:
+                         progress_callback(processed + j, total_images, image_path)
+
+                     if j < len(batch_results):
+                         result = batch_results[j]
+
+                         # Apply category limits if specified
+                         if category_limits and result['success']:
+                             result = apply_category_limits(result, category_limits)
+                             print(f"Applied limits for {os.path.basename(image_path)}, remaining tags: {len(result['all_tags'])}")
+
+                         # Save the tags to a file
+                         if result['success']:
+                             output_path = save_tags_to_file(
+                                 image_path=image_path,
+                                 all_tags=result['all_tags'],
+                                 custom_dir=save_dir,
+                                 overwrite=True
+                             )
+                             result['output_path'] = str(output_path)
+
+                         # Store the result
+                         results[image_path] = result
+                     else:
+                         # Batch processing returned fewer results than expected
+                         results[image_path] = {
+                             'success': False,
+                             'error': 'Batch processing error: missing result',
+                             'all_tags': []
+                         }
+
+             except Exception as e:
+                 print(f"Batch processing error: {str(e)}")
+                 traceback.print_exc()
+
+                 # Fall back to processing this batch's images one by one
+                 for j, image_path in enumerate(batch_files):
+                     if progress_callback:
+                         progress_callback(processed + j, total_images, image_path)
+
+                     result = process_image(
+                         image_path=image_path,
+                         model=model,
+                         thresholds=thresholds,
+                         metadata=metadata,
+                         threshold_profile=threshold_profile,
+                         active_threshold=active_threshold,
+                         active_category_thresholds=active_category_thresholds,
+                         min_confidence=min_confidence
+                     )
+
+                     # Apply category limits if specified
+                     if category_limits and result['success']:
+                         result = apply_category_limits(result, category_limits)
+
+                     if result['success']:
+                         output_path = save_tags_to_file(
+                             image_path=image_path,
+                             all_tags=result['all_tags'],
+                             custom_dir=save_dir,
+                             overwrite=True
+                         )
+                         result['output_path'] = str(output_path)
+
+                     results[image_path] = result
+         else:
+             # Process one by one when batch_size is 1
+             for j, image_path in enumerate(batch_files):
+                 if progress_callback:
+                     progress_callback(processed + j, total_images, image_path)
+
+                 result = process_image(
+                     image_path=image_path,
+                     model=model,
+                     thresholds=thresholds,
+                     metadata=metadata,
+                     threshold_profile=threshold_profile,
+                     active_threshold=active_threshold,
+                     active_category_thresholds=active_category_thresholds,
+                     min_confidence=min_confidence
+                 )
+
+                 # Apply category limits if specified
+                 if category_limits and result['success']:
+                     result = apply_category_limits(result, category_limits)
+
+                 if result['success']:
+                     output_path = save_tags_to_file(
+                         image_path=image_path,
+                         all_tags=result['all_tags'],
+                         custom_dir=save_dir,
+                         overwrite=True
+                     )
+                     result['output_path'] = str(output_path)
+
+                 results[image_path] = result
+
+         # Update the processed count
+         processed += batch_size_actual
+
+         # Calculate batch timing
+         batch_end = time.time()
+         batch_time = batch_end - batch_start
+         print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")
+
+     # Final progress update
+     if progress_callback:
+         progress_callback(total_images, total_images, None)
+
+     end_time = time.time()
+     total_time = end_time - start_time
+     print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")
+
+     return {
+         'success': True,
+         'total': total_images,
+         'processed': len(results),
+         'results': results,
+         'save_dir': save_dir,
+         'time_elapsed': end_time - start_time
+     }
+
+ def process_image_batch(image_paths, model, thresholds, metadata, threshold_profile, active_threshold, active_category_thresholds, min_confidence=0.1):
+     """
+     Process a batch of images at once.
+
+     Args:
+         image_paths: List of paths to the images
+         model: The image tagger model
+         thresholds: Thresholds dictionary
+         metadata: Metadata dictionary
+         threshold_profile: Selected threshold profile
+         active_threshold: Overall threshold value
+         active_category_thresholds: Category-specific thresholds
+         min_confidence: Minimum confidence to include in results
+
+     Returns:
+         List of dictionaries with tags, all probabilities, and other info for each image
+     """
+     try:
+         import torch
+         from PIL import Image
+         import torchvision.transforms as transforms
+
+         # Identify the model type, for better error handling
+         model_type = model.__class__.__name__
+         print(f"Running batch processing with model type: {model_type}")
+
+         # Prepare the transformation for the images
+         transform = transforms.Compose([
+             transforms.Resize((512, 512)),  # Adjust based on your model's expected input
+             transforms.ToTensor(),
+         ])
+
+         # Get model information
+         device = next(model.parameters()).device
+         dtype = next(model.parameters()).dtype
+         print(f"Model is using device: {device}, dtype: {dtype}")
+
+         # Load and preprocess all images
+         batch_tensor = []
+         valid_images = []
+
+         for img_path in image_paths:
+             try:
+                 img = Image.open(img_path).convert('RGB')
+                 img_tensor = transform(img)
+                 img_tensor = img_tensor.to(device=device, dtype=dtype)
+                 batch_tensor.append(img_tensor)
+                 valid_images.append(img_path)
+             except Exception as e:
+                 print(f"Error loading image {img_path}: {str(e)}")
+
+         if not batch_tensor:
+             return []
+
+         # Stack all tensors into a single batch
+         batch_input = torch.stack(batch_tensor)
+
+         # Process the entire batch at once
+         with torch.no_grad():
+             try:
+                 # Forward pass on the whole batch
+                 output = model(batch_input)
+
+                 # Handle the tuple output format
+                 if isinstance(output, tuple):
+                     probs_batch = torch.sigmoid(output[1])
+                 else:
+                     probs_batch = torch.sigmoid(output)
+
+                 # Process each image's results
+                 results = []
+                 for i, img_path in enumerate(valid_images):
+                     probs = probs_batch[i].unsqueeze(0)  # Add the batch dimension back
+
+                     # Extract and organize all probabilities
+                     all_probs = {}
+                     for idx in range(probs.size(1)):
+                         prob_value = probs[0, idx].item()
+                         if prob_value >= min_confidence:
+                             tag, category = model.dataset.get_tag_info(idx)
+
+                             if category not in all_probs:
+                                 all_probs[category] = []
+
+                             all_probs[category].append((tag, prob_value))
+
+                     # Sort tags by probability
+                     for category in all_probs:
+                         all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)
+
+                     # Get the filtered tags
+                     tags = {}
+                     for category, cat_tags in all_probs.items():
+                         threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
+                         tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
+
+                     # Create a flat list of all tags above the threshold
+                     all_tags = []
+                     for category, cat_tags in tags.items():
+                         for tag, _ in cat_tags:
+                             all_tags.append(tag)
+
+                     results.append({
+                         'tags': tags,
+                         'all_probs': all_probs,
+                         'all_tags': all_tags,
+                         'success': True
+                     })
+
+                 return results
+
+             except RuntimeError as e:
+                 # If we hit CUDA out-of-memory or another runtime error,
+                 # fall back to processing the images one by one
+                 print(f"Error in batch processing: {str(e)}")
+                 print("Falling back to one-by-one processing...")
+
+                 results = []
+                 for img_tensor, img_path in zip(batch_tensor, valid_images):
+                     try:
+                         input_tensor = img_tensor.unsqueeze(0)
+                         output = model(input_tensor)
+
+                         if isinstance(output, tuple):
+                             probs = torch.sigmoid(output[1])
+                         else:
+                             probs = torch.sigmoid(output)
+
+                         # Same post-processing as the batched path above
+                         all_probs = {}
+                         for idx in range(probs.size(1)):
+                             prob_value = probs[0, idx].item()
+                             if prob_value >= min_confidence:
+                                 tag, category = model.dataset.get_tag_info(idx)
+                                 all_probs.setdefault(category, []).append((tag, prob_value))
+
+                         for category in all_probs:
+                             all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)
+
+                         tags = {}
+                         for category, cat_tags in all_probs.items():
+                             threshold = active_category_thresholds.get(category, active_threshold) if active_category_thresholds else active_threshold
+                             tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
+
+                         all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]
+
+                         results.append({
+                             'tags': tags,
+                             'all_probs': all_probs,
+                             'all_tags': all_tags,
+                             'success': True
+                         })
+
+                     except Exception as e:
+                         print(f"Error processing image {img_path}: {str(e)}")
+                         results.append({
+                             'tags': {},
+                             'all_probs': {},
+                             'all_tags': [],
+                             'success': False,
+                             'error': str(e)
+                         })
+
+                 return results
+
+     except Exception as e:
+         print(f"Error in batch processing: {str(e)}")
+         traceback.print_exc()
+         return []  # Callers treat an empty list as "no results for this batch"
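
As a usage note, a minimal driver sketch for the batch path above; `model` is assumed to be a loaded ImageTagger exposing the `predict` method and `dataset` attribute these functions rely on, and the folder path and thresholds below are hypothetical:

    from utils.image_processing import batch_process_images

    def show_progress(done, total, current_path):
        # current_path is None on the final update
        print(f"[{done}/{total}] {current_path or 'done'}")

    summary = batch_process_images(
        folder_path="images/",        # hypothetical input folder
        model=model,                  # assumed: a loaded ImageTagger
        thresholds={}, metadata={},
        threshold_profile="Overall",
        active_threshold=0.35,
        active_category_thresholds=None,
        progress_callback=show_progress,
        batch_size=4,
    )
    print(summary['processed'], "images tagged; files in", summary['save_dir'])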
utils/model_loader.py ADDED
@@ -0,0 +1,379 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import GroupNorm, LayerNorm
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ import timm
+
+ class ViTWrapper(nn.Module):
+     """Wrapper to make a timm ViT compatible with feature extraction for ImageTagger"""
+     def __init__(self, vit_model):
+         super().__init__()
+         self.vit = vit_model
+         self.out_indices = (-1,)  # mimic timm features_only
+
+         # Get the patch size and embedding dim from the model
+         self.patch_size = vit_model.patch_embed.patch_size[0]
+         self.embed_dim = vit_model.embed_dim
+
+     def forward(self, x):
+         B = x.size(0)
+
+         # ➊ patch tokens
+         x = self.vit.patch_embed(x)                      # (B, N, C)
+
+         # ➋ prepend CLS
+         cls_tok = self.vit.cls_token.expand(B, -1, -1)   # (B, 1, C)
+         x = torch.cat((cls_tok, x), dim=1)               # (B, 1+N, C)
+
+         # ➌ add positional encodings (full, incl. CLS)
+         if self.vit.pos_embed is not None:
+             x = x + self.vit.pos_embed[:, : x.size(1), :]
+
+         x = self.vit.pos_drop(x)
+
+         for blk in self.vit.blocks:
+             x = blk(x)
+
+         x = self.vit.norm(x)                             # (B, 1+N, C)
+
+         # ➍ split back out
+         cls_final = x[:, 0]                              # (B, C)
+         patch_tokens = x[:, 1:]                          # (B, N, C)
+
+         # ➎ reshape patches to (B, C, H, W)
+         B, N, C = patch_tokens.shape
+         h = w = int(N ** 0.5)  # square assumption
+         patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)
+
+         # Return **both**: (patch map, CLS)
+         return patch_features, cls_final
+
+     def set_grad_checkpointing(self, enable=True):
+         """Enable gradient checkpointing if supported"""
+         if hasattr(self.vit, 'set_grad_checkpointing'):
+             self.vit.set_grad_checkpointing(enable)
+             return True
+         return False
+
+ class ImageTagger(nn.Module):
+     """
+     ImageTagger with a Vision Transformer backbone
+     """
+     def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
+                  num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
+                  use_gradient_checkpointing=False, img_size=224):
+         super().__init__()
+
+         # Store checkpointing config
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+         self.model_name = model_name
+         self.img_size = img_size
+         self.pretrained = pretrained
+
+         # Debug and stats flags
+         self._flags = {
+             'debug': False,
+             'model_stats': True
+         }
+
+         # Core model config
+         self.dataset = dataset
+         self.tag_context_size = tag_context_size
+         self.total_tags = total_tags
+
+         print(f"🏗️ Building ImageTagger with ViT backbone and {total_tags} tags")
+         print(f"   Backbone: {model_name}")
+         print(f"   Image size: {img_size}x{img_size}")
+         print(f"   Tag context size: {tag_context_size}")
+         print(f"   Gradient checkpointing: {use_gradient_checkpointing}")
+         print("   🎯 Custom embeddings, PyTorch-native attention, no ground-truth inclusion")
+
+         # 1. Vision Transformer backbone
+         print("📦 Loading Vision Transformer backbone...")
+         self._load_vit_backbone()
+
+         # Get the backbone dimensions by running a test forward pass
+         self._determine_backbone_dimensions()
+
+         self.embedding_dim = self.backbone.embed_dim
+
+         # 2. Custom tag embeddings (no CLIP)
+         print("🎯 Using custom tag embeddings (no CLIP)")
+         self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
+
+         # 3. Shared-weights approach - tag bias for the initial predictions
+         print("🔗 Using shared weights between the initial head and tag embeddings")
+         self.tag_bias = nn.Parameter(torch.zeros(total_tags))
+
+         # 4. Image token extraction (for attention AND global pooling)
+         self.image_token_proj = nn.Identity()
+
+         # 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation)
+         self.cross_attention = nn.MultiheadAttention(
+             embed_dim=self.embedding_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True  # Use (batch, seq, feature) format
+         )
+         self.cross_norm = nn.LayerNorm(self.embedding_dim)
+
+         # Initialize weights
+         self._init_weights()
+
+         # Enable gradient checkpointing
+         if self.use_gradient_checkpointing:
+             self._enable_gradient_checkpointing()
+
+         print("✅ ImageTagger with ViT initialized!")
+         self._print_parameter_count()
+
+     def _load_vit_backbone(self):
+         """Load the Vision Transformer model from timm"""
+         print(f"   Loading from timm: {self.model_name}")
+
+         # Load the ViT model (not features_only; we want the full model for token extraction)
+         vit_model = timm.create_model(
+             self.model_name,
+             pretrained=self.pretrained,  # honor the constructor argument
+             img_size=self.img_size,
+             num_classes=0  # Remove the classification head
+         )
+
+         # Wrap it in our compatibility layer
+         self.backbone = ViTWrapper(vit_model)
+
+         print("   ✅ ViT loaded successfully")
+         print(f"   Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
+         print(f"   Embed dim: {self.backbone.embed_dim}")
+
+     def _determine_backbone_dimensions(self):
+         """Determine the backbone output dimensions"""
+         print("   🔍 Determining backbone dimensions...")
+
+         with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
+             # Create a dummy input
+             dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
+
+             # Get features
+             backbone_features, cls_dummy = self.backbone(dummy_input)
+             feature_tensor = backbone_features
+
+             self.backbone_dim = feature_tensor.shape[1]
+             self.feature_map_size = feature_tensor.shape[2]
+
+         print(f"   Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
+         print(f"   Total patch tokens: {self.feature_map_size * self.feature_map_size}")
+
+     def _enable_gradient_checkpointing(self):
+         """Enable gradient checkpointing for memory efficiency"""
+         print("🔄 Enabling gradient checkpointing...")
+
+         # Enable checkpointing for the ViT backbone
+         if self.backbone.set_grad_checkpointing(True):
+             print("   ✅ ViT backbone checkpointing enabled")
+         else:
+             print("   ⚠️ ViT backbone doesn't support built-in checkpointing; will checkpoint manually")
+
+     def _checkpoint_backbone(self, x):
+         """Wrapper for the backbone with gradient checkpointing"""
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
+         else:
+             return self.backbone(x)
+
+     def _checkpoint_image_proj(self, x):
+         """Wrapper for the image projection with gradient checkpointing"""
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
+         else:
+             return self.image_token_proj(x)
+
+     def _checkpoint_cross_attention(self, query, key, value):
+         """Wrapper for cross-attention with gradient checkpointing"""
+         def _attention_forward(q, k, v):
+             attended_features, _ = self.cross_attention(query=q, key=k, value=v)
+             return self.cross_norm(attended_features)
+
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
+         else:
+             return _attention_forward(query, key, value)
+
+     def _checkpoint_candidate_selection(self, initial_logits):
+         """Wrapper for candidate selection with gradient checkpointing"""
+         def _candidate_forward(logits):
+             return self._get_candidate_tags(logits)
+
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
+         else:
+             return _candidate_forward(initial_logits)
+
+     def _checkpoint_final_scoring(self, attended_features, candidate_indices):
+         """Wrapper for final scoring with gradient checkpointing"""
+         def _scoring_forward(features, indices):
+             emb = self.tag_embedding(indices)
+             # BF16 in, BF16 out
+             return (features * emb).sum(dim=-1)
+
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
+         else:
+             return _scoring_forward(attended_features, candidate_indices)
+
+     def _init_weights(self):
+         """Initialize weights for the new modules"""
+         def _init_layer(layer):
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+                 if layer.bias is not None:
+                     nn.init.zeros_(layer.bias)
+             elif isinstance(layer, nn.Conv2d):
+                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                 if layer.bias is not None:
+                     nn.init.zeros_(layer.bias)
+             elif isinstance(layer, nn.Embedding):
+                 nn.init.normal_(layer.weight, mean=0, std=0.02)
+
+         # Initialize the new components
+         self.image_token_proj.apply(_init_layer)
+
+         # Initialize tag embeddings with a normal distribution
+         nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
+
+         # Initialize the tag bias
+         nn.init.zeros_(self.tag_bias)
+
+     def _print_parameter_count(self):
+         """Print parameter statistics"""
+         total_params = sum(p.numel() for p in self.parameters())
+         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         backbone_params = sum(p.numel() for p in self.backbone.parameters())
+
+         print("📊 Parameter Statistics:")
+         print(f"   Total parameters: {total_params/1e6:.1f}M")
+         print(f"   Trainable parameters: {trainable_params/1e6:.1f}M")
+         print(f"   Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
+         print(f"   Backbone parameters: {backbone_params/1e6:.1f}M")
+
+         if self.use_gradient_checkpointing:
+             print("   🔄 Gradient checkpointing enabled for memory efficiency")
+
+     @property
+     def debug(self):
+         return self._flags['debug']
+
+     @property
+     def model_stats(self):
+         return self._flags['model_stats']
+
+     def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
+         """Select candidate tags - no ground-truth inclusion"""
+         batch_size = initial_logits.size(0)
+
+         # Simply select the top-K candidates based on the initial predictions
+         top_probs, top_indices = torch.topk(
+             torch.sigmoid(initial_logits),
+             k=min(self.tag_context_size, self.total_tags),
+             dim=1, largest=True, sorted=True
+         )
+
+         return top_indices
+
+     def _analyze_predictions(self, predictions, tag_indices):
+         """Analyze prediction patterns"""
+         if not self.model_stats:
+             return {}
+
+         if torch._dynamo.is_compiling():
+             return {}
+
+         with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
+             probs = torch.sigmoid(predictions)
+             relevant_probs = torch.gather(probs, 1, tag_indices)
+
+             return {
+                 'prediction_confidence': relevant_probs.mean().item(),
+                 'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
+                 'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
+                 'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
+             }
+
+     def forward(self, x, targets=None, hard_negatives=None):
+         """
+         Forward pass with the ViT backbone, CLS-token support, and gradient checkpointing.
+         All arithmetic tensors stay in the backbone's dtype (BF16 under autocast,
+         FP32 otherwise). Anything that must mix dtypes is cast to match.
+         """
+         batch_size = x.size(0)
+         model_stats = {}  # populated below when stats are enabled
+
+         # ------------------------------------------------------------------
+         # 1. Backbone → patch map + CLS token
+         # ------------------------------------------------------------------
+         patch_map, cls_token = self._checkpoint_backbone(x)  # patch_map: [B, C, H, W]
+                                                              # cls_token: [B, C]
+
+         # ------------------------------------------------------------------
+         # 2. Tokens → global image vector
+         # ------------------------------------------------------------------
+         image_tokens_4d = self._checkpoint_image_proj(patch_map)   # [B, C, H, W]
+         image_tokens = image_tokens_4d.flatten(2).transpose(1, 2)  # [B, N, C]
+
+         # "Dual-pool": mean-pool patches ⊕ CLS
+         global_features = 0.5 * (image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token)  # [B, C]
+
+         compute_dtype = global_features.dtype  # BF16 or FP32
+
+         # ------------------------------------------------------------------
+         # 3. Initial logits (shared weights)
+         # ------------------------------------------------------------------
+         tag_weights = self.tag_embedding.weight.to(compute_dtype)  # [T, C]
+         tag_bias = self.tag_bias.to(compute_dtype)                 # [T]
+
+         initial_logits = global_features @ tag_weights.t() + tag_bias  # [B, T]
+         initial_logits = initial_logits.to(compute_dtype)              # keep dtype uniform
+         initial_preds = initial_logits                                 # alias
+
+         # ------------------------------------------------------------------
+         # 4. Candidate set
+         # ------------------------------------------------------------------
+         candidate_indices = self._checkpoint_candidate_selection(initial_logits)  # [B, K]
+
+         tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype)  # [B, K, C]
+
+         attended_features = self._checkpoint_cross_attention(  # [B, K, C]
+             tag_embeddings, image_tokens, image_tokens
+         )
+
+         # ------------------------------------------------------------------
+         # 5. Score candidates & scatter back
+         # ------------------------------------------------------------------
+         candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)  # [B, K]
+
+         # --- align dtypes so scatter never throws ---
+         if candidate_logits.dtype != initial_logits.dtype:
+             candidate_logits = candidate_logits.to(initial_logits.dtype)
+
+         refined_logits = initial_logits.clone()
+         refined_logits.scatter_(1, candidate_indices, candidate_logits)
+         refined_preds = refined_logits
+
+         # ------------------------------------------------------------------
+         # 6. Optional stats
+         # ------------------------------------------------------------------
+         if self.model_stats and targets is not None and not torch._dynamo.is_compiling():
+             model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds,
+                                                                                 candidate_indices)
+             model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds,
+                                                                                 candidate_indices)
+
+         return {
+             'initial_predictions': initial_preds,
+             'refined_predictions': refined_preds,
+             'selected_candidates': candidate_indices,
+             'model_stats': model_stats
+         }
+
+     def predict
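
The diff view is truncated at `def predict` above. For orientation, a minimal smoke-test sketch of the forward pass; the tag count and dataset stub below are hypothetical (the model itself only stores `dataset`, so a stub suffices here), and `pretrained=False` simply avoids downloading weights:

    import torch

    class DummyDataset:
        # Hypothetical stub; the tagging utilities call get_tag_info(idx)
        def get_tag_info(self, idx):
            return f"tag_{idx}", "general"

    model = ImageTagger(total_tags=1000, dataset=DummyDataset(), pretrained=False)
    model.eval()

    x = torch.randn(2, 3, 224, 224)  # dummy batch of two images
    with torch.no_grad():
        out = model(x)

    # Refined logits become per-tag probabilities; candidates are the top-K tag indices
    probs = torch.sigmoid(out['refined_predictions'])  # [2, 1000]
    print(out['selected_candidates'].shape)            # [2, 256] with the default context size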
utils/onnx_processing.py ADDED
@@ -0,0 +1,729 @@
1
+ """
2
+ ONNX-based batch image processing for the Image Tagger application.
3
+ Updated with proper ImageNet normalization and new metadata format.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import time
9
+ import traceback
10
+ import numpy as np
11
+ import glob
12
+ import onnxruntime as ort
13
+ from PIL import Image
14
+ import torchvision.transforms as transforms
15
+ from concurrent.futures import ThreadPoolExecutor
16
+
17
+ def preprocess_image(image_path, image_size=512):
18
+ """
19
+ Process an image for ImageTagger inference with proper ImageNet normalization
20
+ """
21
+ if not os.path.exists(image_path):
22
+ raise ValueError(f"Image not found at path: {image_path}")
23
+
24
+ # ImageNet normalization - CRITICAL for your model
25
+ transform = transforms.Compose([
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(
28
+ mean=[0.485, 0.456, 0.406],
29
+ std=[0.229, 0.224, 0.225]
30
+ )
31
+ ])
32
+
33
+ try:
34
+ with Image.open(image_path) as img:
35
+ # Convert RGBA or Palette images to RGB
36
+ if img.mode in ('RGBA', 'P'):
37
+ img = img.convert('RGB')
38
+
39
+ # Get original dimensions
40
+ width, height = img.size
41
+ aspect_ratio = width / height
42
+
43
+ # Calculate new dimensions to maintain aspect ratio
44
+ if aspect_ratio > 1:
45
+ new_width = image_size
46
+ new_height = int(new_width / aspect_ratio)
47
+ else:
48
+ new_height = image_size
49
+ new_width = int(new_height * aspect_ratio)
50
+
51
+ # Resize with LANCZOS filter
52
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
53
+
54
+ # Create new image with padding (use ImageNet mean for padding)
55
+ # Using RGB values close to ImageNet mean: (0.485*255, 0.456*255, 0.406*255)
56
+ pad_color = (124, 116, 104)
57
+ new_image = Image.new('RGB', (image_size, image_size), pad_color)
58
+ paste_x = (image_size - new_width) // 2
59
+ paste_y = (image_size - new_height) // 2
60
+ new_image.paste(img, (paste_x, paste_y))
61
+
62
+ # Apply transforms (including ImageNet normalization)
63
+ img_tensor = transform(new_image)
64
+ return img_tensor.numpy()
65
+
66
+ except Exception as e:
67
+ raise Exception(f"Error processing {image_path}: {str(e)}")
68
+
69
+ def process_single_image_onnx(image_path, model_path, metadata, threshold_profile="Overall",
70
+ active_threshold=0.35, active_category_thresholds=None,
71
+ min_confidence=0.1):
72
+ """
73
+ Process a single image using ONNX model with new metadata format
74
+
75
+ Args:
76
+ image_path: Path to the image file
77
+ model_path: Path to the ONNX model file
78
+ metadata: Model metadata dictionary
79
+ threshold_profile: The threshold profile being used
80
+ active_threshold: Overall threshold value
81
+ active_category_thresholds: Category-specific thresholds
82
+ min_confidence: Minimum confidence to include in results
83
+
84
+ Returns:
85
+ Dictionary with tags and probabilities
86
+ """
87
+ try:
88
+ # Create ONNX tagger for this image (or reuse an existing one)
89
+ if hasattr(process_single_image_onnx, 'tagger'):
90
+ tagger = process_single_image_onnx.tagger
91
+ else:
92
+ # Create new tagger
93
+ tagger = ONNXImageTagger(model_path, metadata)
94
+ # Cache it for future calls
95
+ process_single_image_onnx.tagger = tagger
96
+
97
+ # Preprocess the image
98
+ start_time = time.time()
99
+ img_array = preprocess_image(image_path)
100
+
101
+ # Run inference
102
+ results = tagger.predict_batch(
103
+ [img_array],
104
+ threshold=active_threshold,
105
+ category_thresholds=active_category_thresholds,
106
+ min_confidence=min_confidence
107
+ )
108
+ inference_time = time.time() - start_time
109
+
110
+ if results:
111
+ result = results[0]
112
+ result['inference_time'] = inference_time
113
+ result['success'] = True
114
+ return result
115
+ else:
116
+ return {
117
+ 'success': False,
118
+ 'error': 'Failed to process image',
119
+ 'all_tags': [],
120
+ 'all_probs': {},
121
+ 'tags': {}
122
+ }
123
+
124
+ except Exception as e:
125
+ print(f"Error in process_single_image_onnx: {str(e)}")
126
+ traceback.print_exc()
127
+ return {
128
+ 'success': False,
129
+ 'error': str(e),
130
+ 'all_tags': [],
131
+ 'all_probs': {},
132
+ 'tags': {}
133
+ }
134
+
135
+ def preprocess_images_parallel(image_paths, image_size=512, max_workers=8):
136
+ """Process multiple images in parallel"""
137
+ processed_images = []
138
+ valid_paths = []
139
+
140
+ # Define a worker function
141
+ def process_single_image(path):
142
+ try:
143
+ return preprocess_image(path, image_size), path
144
+ except Exception as e:
145
+ print(f"Error processing {path}: {str(e)}")
146
+ return None, path
147
+
148
+ # Process images in parallel
149
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
150
+ results = list(executor.map(process_single_image, image_paths))
151
+
152
+ # Filter results
153
+ for img_array, path in results:
154
+ if img_array is not None:
155
+ processed_images.append(img_array)
156
+ valid_paths.append(path)
157
+
158
+ return processed_images, valid_paths
159
+
160
+ def apply_category_limits(result, category_limits):
161
+ """
162
+ Apply category limits to a result dictionary.
163
+
164
+ Args:
165
+ result: Result dictionary containing tags and all_tags
166
+ category_limits: Dictionary mapping categories to their tag limits
167
+ (0 = exclude category, -1 = no limit/include all)
168
+
169
+ Returns:
170
+ Updated result dictionary with limits applied
171
+ """
172
+ if not category_limits or not result['success']:
173
+ return result
174
+
175
+ # Get the filtered tags
176
+ filtered_tags = result['tags']
177
+
178
+ # Apply limits to each category
179
+ for category, cat_tags in list(filtered_tags.items()):
180
+ # Get limit for this category, default to -1 (no limit)
181
+ limit = category_limits.get(category, -1)
182
+
183
+ if limit == 0:
184
+ # Exclude this category entirely
185
+ del filtered_tags[category]
186
+ elif limit > 0 and len(cat_tags) > limit:
187
+ # Limit to top N tags for this category
188
+ filtered_tags[category] = cat_tags[:limit]
189
+
190
+ # Regenerate all_tags list after applying limits
191
+ all_tags = []
192
+ for category, cat_tags in filtered_tags.items():
193
+ for tag, _ in cat_tags:
194
+ all_tags.append(tag)
195
+
196
+ # Update the result with limited tags
197
+ result['tags'] = filtered_tags
198
+ result['all_tags'] = all_tags
199
+
200
+ return result
201
+
202
+ class ONNXImageTagger:
203
+ """ONNX-based image tagger for fast batch inference with updated metadata format"""
204
+
205
+ def __init__(self, model_path, metadata):
206
+ # Load model
207
+ self.model_path = model_path
208
+ try:
209
+ self.session = ort.InferenceSession(
210
+ model_path,
211
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
212
+ )
213
+ print(f"Using providers: {self.session.get_providers()}")
214
+ except Exception as e:
215
+ print(f"CUDA not available, using CPU: {e}")
216
+ self.session = ort.InferenceSession(
217
+ model_path,
218
+ providers=['CPUExecutionProvider']
219
+ )
220
+ print(f"Using providers: {self.session.get_providers()}")
221
+
222
+ # Store metadata (passed as dict, not loaded from file)
223
+ self.metadata = metadata
224
+
225
+ # Extract tag mappings from new metadata structure
226
+ if 'dataset_info' in metadata:
227
+ # New metadata format
228
+ self.tag_mapping = metadata['dataset_info']['tag_mapping']
229
+ self.idx_to_tag = self.tag_mapping['idx_to_tag']
230
+ self.tag_to_category = self.tag_mapping['tag_to_category']
231
+ self.total_tags = metadata['dataset_info']['total_tags']
232
+ else:
233
+ # Fallback for older format
234
+ self.idx_to_tag = metadata.get('idx_to_tag', {})
235
+ self.tag_to_category = metadata.get('tag_to_category', {})
236
+ self.total_tags = metadata.get('total_tags', len(self.idx_to_tag))
237
+
238
+ # Get input name
239
+ self.input_name = self.session.get_inputs()[0].name
240
+ print(f"Model loaded successfully. Input name: {self.input_name}")
241
+ print(f"Total tags: {self.total_tags}, Categories: {len(set(self.tag_to_category.values()))}")
242
+
243
+ def predict_batch(self, image_arrays, threshold=0.5, category_thresholds=None, min_confidence=0.1):
244
+ """Run batch inference on preprocessed image arrays"""
245
+ # Stack arrays into batch
246
+ batch_input = np.stack(image_arrays)
247
+
248
+ # Run inference
249
+ start_time = time.time()
250
+ outputs = self.session.run(None, {self.input_name: batch_input})
251
+ inference_time = time.time() - start_time
252
+ print(f"Batch inference completed in {inference_time:.4f} seconds ({inference_time/len(image_arrays):.4f} s/image)")
253
+
254
+ # Process outputs - handle both single and multi-output models
255
+ if len(outputs) >= 2:
256
+ # Multi-output model (initial_predictions, refined_predictions, selected_candidates)
257
+ initial_logits = outputs[0]
258
+ refined_logits = outputs[1]
259
+ # Use refined predictions as main output
260
+ main_logits = refined_logits
261
+ print(f"Using refined predictions (shape: {refined_logits.shape})")
262
+ else:
263
+ # Single output model
264
+ main_logits = outputs[0]
265
+ print(f"Using single output (shape: {main_logits.shape})")
266
+
267
+ # Apply sigmoid to get probabilities
268
+ main_probs = 1.0 / (1.0 + np.exp(-main_logits))
269
+
270
+ # Process results for each image in batch
271
+ batch_results = []
272
+
273
+ for i in range(main_probs.shape[0]):
274
+ probs = main_probs[i]
275
+
276
+ # Extract and organize all probabilities
277
+ all_probs = {}
278
+ for idx in range(probs.shape[0]):
279
+ prob_value = float(probs[idx])
280
+ if prob_value >= min_confidence:
281
+ idx_str = str(idx)
282
+ tag_name = self.idx_to_tag.get(idx_str, f"unknown-{idx}")
283
+ category = self.tag_to_category.get(tag_name, "general")
284
+
285
+ if category not in all_probs:
286
+ all_probs[category] = []
287
+
288
+ all_probs[category].append((tag_name, prob_value))
289
+
290
+ # Sort tags by probability within each category
291
+ for category in all_probs:
292
+ all_probs[category] = sorted(
293
+ all_probs[category],
294
+ key=lambda x: x[1],
295
+ reverse=True
296
+ )
297
+
298
+ # Get the filtered tags based on the selected threshold
299
+ tags = {}
300
+ for category, cat_tags in all_probs.items():
301
+ # Use category-specific threshold if available
302
+ if category_thresholds and category in category_thresholds:
303
+ cat_threshold = category_thresholds[category]
304
+ else:
305
+ cat_threshold = threshold
306
+
307
+ tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= cat_threshold]
308
+
309
+ # Create a flat list of all tags above threshold
310
+ all_tags = []
311
+ for category, cat_tags in tags.items():
312
+ for tag, _ in cat_tags:
313
+ all_tags.append(tag)
314
+
315
+ batch_results.append({
316
+ 'tags': tags,
317
+ 'all_probs': all_probs,
318
+ 'all_tags': all_tags,
319
+ 'success': True
320
+ })
321
+
322
+ return batch_results
323
+
324
+ def batch_process_images_onnx(folder_path, model_path, metadata_path, threshold_profile,
325
+ active_threshold, active_category_thresholds, save_dir=None,
326
+ progress_callback=None, min_confidence=0.1, batch_size=16,
327
+ category_limits=None):
328
+ """
329
+ Process all images in a folder using the ONNX model with new metadata format.
330
+
331
+ Args:
332
+ folder_path: Path to folder containing images
333
+ model_path: Path to the ONNX model file
334
+ metadata_path: Path to the model metadata file
335
+ threshold_profile: Selected threshold profile
336
+ active_threshold: Overall threshold value
337
+ active_category_thresholds: Category-specific thresholds
338
+ save_dir: Directory to save tag files (if None uses default)
339
+ progress_callback: Optional callback for progress updates
340
+ min_confidence: Minimum confidence threshold
341
+ batch_size: Number of images to process at once
342
+ category_limits: Dictionary mapping categories to their tag limits
343
+
344
+ Returns:
345
+ Dictionary with results for each image
346
+ """
347
+ from utils.file_utils import save_tags_to_file # Import here to avoid circular imports
348
+
349
+ # Find all image files in the folder
350
+ image_extensions = ['*.jpg', '*.jpeg', '*.png']
351
+ image_files = []
352
+
353
+ for ext in image_extensions:
354
+ image_files.extend(glob.glob(os.path.join(folder_path, ext)))
355
+ image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))
356
+
357
+ # Remove duplicates (Windows case-insensitive filesystems)
358
+ if os.name == 'nt': # Windows
359
+ unique_paths = set()
360
+ unique_files = []
361
+ for file_path in image_files:
362
+ normalized_path = os.path.normpath(file_path).lower()
363
+ if normalized_path not in unique_paths:
364
+ unique_paths.add(normalized_path)
365
+ unique_files.append(file_path)
366
+ image_files = unique_files
367
+
368
+ if not image_files:
369
+ return {
370
+ 'success': False,
371
+ 'error': f"No images found in {folder_path}",
372
+ 'results': {}
373
+ }
374
+
375
+ # Use the provided save directory or create a default one
376
+ if save_dir is None:
377
+ app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
378
+ save_dir = os.path.join(app_dir, "saved_tags")
379
+
380
+ # Ensure the directory exists
381
+ os.makedirs(save_dir, exist_ok=True)
382
+
383
+ # Load metadata
384
+ try:
385
+         with open(metadata_path, 'r') as f:
+             metadata = json.load(f)
+     except Exception as e:
+         return {
+             'success': False,
+             'error': f"Failed to load metadata: {e}",
+             'results': {}
+         }
+
+     # Create the ONNX tagger
+     try:
+         tagger = ONNXImageTagger(model_path, metadata)
+     except Exception as e:
+         return {
+             'success': False,
+             'error': f"Failed to load model: {e}",
+             'results': {}
+         }
+
+     # Process images in batches
+     results = {}
+     total_images = len(image_files)
+     processed = 0
+
+     start_time = time.time()
+
+     for i in range(0, total_images, batch_size):
+         batch_start = time.time()
+
+         # Get the current batch of images
+         batch_files = image_files[i:i + batch_size]
+         batch_size_actual = len(batch_files)
+
+         # Update progress if a callback was provided
+         if progress_callback:
+             progress_callback(processed, total_images, batch_files[0] if batch_files else None)
+
+         print(f"Processing batch {i // batch_size + 1}/{(total_images + batch_size - 1) // batch_size}: {batch_size_actual} images")
+
+         try:
+             # Preprocess the images in parallel
+             processed_images, valid_paths = preprocess_images_parallel(batch_files)
+
+             if processed_images:
+                 # Run batch prediction
+                 batch_results = tagger.predict_batch(
+                     processed_images,
+                     threshold=active_threshold,
+                     category_thresholds=active_category_thresholds,
+                     min_confidence=min_confidence
+                 )
+
+                 # Process the result for each image
+                 for j, (image_path, result) in enumerate(zip(valid_paths, batch_results)):
+                     # Update progress if a callback was provided
+                     if progress_callback:
+                         progress_callback(processed + j, total_images, image_path)
+
+                     # Apply category limits if specified
+                     if category_limits and result['success']:
+                         print(f"Applying limits to {os.path.basename(image_path)}: {len(result['all_tags'])} → ", end="")
+                         result = apply_category_limits(result, category_limits)
+                         print(f"{len(result['all_tags'])} tags")
+
+                     # Save the tags to a file
+                     if result['success']:
+                         try:
+                             output_path = save_tags_to_file(
+                                 image_path=image_path,
+                                 all_tags=result['all_tags'],
+                                 custom_dir=save_dir,
+                                 overwrite=True
+                             )
+                             result['output_path'] = str(output_path)
+                         except Exception as e:
+                             print(f"Error saving tags for {image_path}: {e}")
+                             result['save_error'] = str(e)
+
+                     # Store the result
+                     results[image_path] = result
+
+             processed += batch_size_actual
+
+             # Report batch timing
+             batch_time = time.time() - batch_start
+             print(f"Batch processed in {batch_time:.2f} seconds ({batch_time / batch_size_actual:.2f} seconds per image)")
+
+         except Exception as e:
+             print(f"Error processing batch: {e}")
+             traceback.print_exc()
+
+             # Fall back to processing the failed batch one image at a time
+             for j, image_path in enumerate(batch_files):
+                 try:
+                     # Update progress if a callback was provided
+                     if progress_callback:
+                         progress_callback(processed + j, total_images, image_path)
+
+                     # Preprocess and run inference on a single image
+                     img_array = preprocess_image(image_path)
+                     single_results = tagger.predict_batch(
+                         [img_array],
+                         threshold=active_threshold,
+                         category_thresholds=active_category_thresholds,
+                         min_confidence=min_confidence
+                     )
+
+                     if single_results:
+                         result = single_results[0]
+
+                         # Apply category limits if specified
+                         if category_limits and result['success']:
+                             result = apply_category_limits(result, category_limits)
+
+                         # Save the tags to a file
+                         if result['success']:
+                             try:
+                                 output_path = save_tags_to_file(
+                                     image_path=image_path,
+                                     all_tags=result['all_tags'],
+                                     custom_dir=save_dir,
+                                     overwrite=True
+                                 )
+                                 result['output_path'] = str(output_path)
+                             except Exception as e:
+                                 print(f"Error saving tags for {image_path}: {e}")
+                                 result['save_error'] = str(e)
+
+                         results[image_path] = result
+                     else:
+                         results[image_path] = {
+                             'success': False,
+                             'error': 'Failed to process image',
+                             'all_tags': []
+                         }
+
+                 except Exception as img_e:
+                     print(f"Error processing single image {image_path}: {img_e}")
+                     results[image_path] = {
+                         'success': False,
+                         'error': str(img_e),
+                         'all_tags': []
+                     }
+
+             processed += batch_size_actual
+
+     # Final progress update
+     if progress_callback:
+         progress_callback(total_images, total_images, None)
+
+     total_time = time.time() - start_time
+     # Guard against division by zero when no images were supplied
+     print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time / max(total_images, 1):.2f} seconds per image")
+
+     return {
+         'success': True,
+         'total': total_images,
+         'processed': len(results),
+         'results': results,
+         'save_dir': save_dir,
+         'time_elapsed': total_time
+     }
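
For reference, `progress_callback` is only ever invoked with the three-argument shape `(images_done, total_images, current_image_path_or_None)`, so any callable matching that shape can be passed in. A minimal console sketch (the function and variable names here are illustrative, not part of the app):

    import os

    def print_progress(done, total, current_path):
        # current_path is None on the final update, so guard before basename()
        name = os.path.basename(current_path) if current_path else "done"
        pct = 100.0 * done / max(total, 1)  # avoid dividing by zero on an empty run
        print(f"[{pct:5.1f}%] {done}/{total} - {name}")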
+
+
+ def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=256):
+     """
+     Test an ImageTagger ONNX model, handling all model outputs and both the new
+     and the legacy metadata formats.
+
+     Args:
+         model_path: Path to the ONNX model file
+         metadata_path: Path to the metadata JSON file
+         image_path: Path to the test image
+         threshold: Confidence threshold for predictions
+         top_k: Maximum number of predictions to keep per category
+
+     Returns:
+         Dictionary mapping each category to a list of (tag, probability) pairs
+     """
+     import onnxruntime as ort
+     import numpy as np
+     import json
+     import time
+     from collections import defaultdict
+
+     print(f"Loading ImageTagger ONNX model from {model_path}")
+
+     # Load metadata with proper error handling
+     try:
+         with open(metadata_path, 'r') as f:
+             metadata = json.load(f)
+     except Exception as e:
+         raise ValueError(f"Failed to load metadata: {e}")
+
+     # Extract tag mappings from the new metadata structure
+     try:
+         if 'dataset_info' in metadata:
+             # New metadata format
+             dataset_info = metadata['dataset_info']
+             tag_mapping = dataset_info['tag_mapping']
+             idx_to_tag = tag_mapping['idx_to_tag']
+             tag_to_category = tag_mapping['tag_to_category']
+             total_tags = dataset_info['total_tags']
+         else:
+             # Fallback for the older format
+             idx_to_tag = metadata.get('idx_to_tag', {})
+             tag_to_category = metadata.get('tag_to_category', {})
+             total_tags = metadata.get('total_tags', len(idx_to_tag))
+
+         print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")
+
+     except KeyError as e:
+         raise ValueError(f"Invalid metadata structure, missing key: {e}")
+
+     # Initialize the ONNX session, preferring CUDA when it is available
+     providers = []
+     if ort.get_device() == 'GPU':
+         providers.append('CUDAExecutionProvider')
+     providers.append('CPUExecutionProvider')
+
+     try:
+         session = ort.InferenceSession(model_path, providers=providers)
+         active_provider = session.get_providers()[0]
+         print(f"Using provider: {active_provider}")
+
+         # Print model info
+         inputs = session.get_inputs()
+         outputs = session.get_outputs()
+         print(f"Model inputs: {len(inputs)}")
+         print(f"Model outputs: {len(outputs)}")
+         for i, output in enumerate(outputs):
+             print(f"  Output {i}: {output.name} {output.shape}")
+
+     except Exception as e:
+         raise RuntimeError(f"Failed to create ONNX session: {e}")
+
+     # Preprocess the image
+     print(f"Processing image: {image_path}")
+     try:
+         # Get the expected input size from metadata, defaulting to 512
+         img_size = metadata.get('model_info', {}).get('img_size', 512)
+         img_tensor = preprocess_image(image_path, image_size=img_size)
+         img_numpy = img_tensor[np.newaxis, :]  # Add a batch dimension
+         print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")
+
+     except Exception as e:
+         raise ValueError(f"Image preprocessing failed: {e}")
+
+     # Run inference
+     input_name = session.get_inputs()[0].name
+     print("Running inference...")
+
+     start_time = time.time()
+     try:
+         outputs = session.run(None, {input_name: img_numpy})
+         inference_time = time.time() - start_time
+         print(f"Inference completed in {inference_time:.4f} seconds")
+
+     except Exception as e:
+         raise RuntimeError(f"Inference failed: {e}")
+
+     # Handle the model outputs
+     if len(outputs) >= 2:
+         initial_logits = outputs[0]
+         refined_logits = outputs[1]
+         selected_candidates = outputs[2] if len(outputs) > 2 else None  # exposed for inspection, unused below
+
+         # Use the refined predictions as the main output
+         main_logits = refined_logits
+         print(f"Using refined predictions (shape: {refined_logits.shape})")
+     else:
+         # Fall back to a single output
+         main_logits = outputs[0]
+         print(f"Using single output (shape: {main_logits.shape})")
+
+     # Apply sigmoid to turn logits into probabilities
+     main_probs = 1.0 / (1.0 + np.exp(-main_logits))
+
+     # Apply the threshold and collect predicted indices
+     predictions_mask = (main_probs >= threshold)
+     indices = np.where(predictions_mask[0])[0]
+
+     if len(indices) == 0:
+         print(f"No predictions above threshold {threshold}")
+         # Show the top 5 regardless of threshold
+         top_indices = np.argsort(main_probs[0])[-5:][::-1]
+         print("Top 5 predictions:")
+         for idx in top_indices:
+             tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
+             prob = float(main_probs[0, idx])
+             print(f"  {tag_name}: {prob:.3f}")
+         return {}
+
+     # Group predictions by category
+     tags_by_category = defaultdict(list)
+     for idx in indices:
+         tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
+         category = tag_to_category.get(tag_name, "general")
+         prob = float(main_probs[0, idx])
+         tags_by_category[category].append((tag_name, prob))
+
+     # Sort by probability within each category and keep at most top_k tags
+     for category in tags_by_category:
+         tags_by_category[category] = sorted(
+             tags_by_category[category],
+             key=lambda x: x[1],
+             reverse=True
+         )[:top_k]
+
+     # Print the results
+     total_predictions = sum(len(tags) for tags in tags_by_category.values())
+     print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")
+
+     # Category order for consistent display
+     category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
+
+     for category in category_order:
+         if category in tags_by_category:
+             tags = tags_by_category[category]
+             print(f"\n{category.upper()} ({len(tags)}):")
+             for tag, prob in tags:
+                 print(f"  {tag}: {prob:.3f}")
+
+     # Show any other categories not in the standard order
+     for category in sorted(tags_by_category.keys()):
+         if category not in category_order:
+             tags = tags_by_category[category]
+             print(f"\n{category.upper()} ({len(tags)}):")
+             for tag, prob in tags:
+                 print(f"  {tag}: {prob:.3f}")
+
+     # Performance stats
+     print("\nPerformance:")
+     print(f"  Inference time: {inference_time:.4f}s")
+     print(f"  Provider: {active_provider}")
+     print(f"  Max confidence: {main_probs.max():.3f}")
+     if total_predictions > 0:
+         avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
+         print(f"  Average confidence: {avg_conf:.3f}")
+
+     return dict(tags_by_category)
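
A quick smoke test of an exported model could then look like the sketch below; the file paths and threshold are placeholders, not values defined by this repo:

    # Placeholder paths; substitute the real exported model, metadata, and image.
    if __name__ == "__main__":
        tags = test_onnx_imagetagger(
            model_path="model.onnx",
            metadata_path="metadata.json",
            image_path="examples/sample.jpg",
            threshold=0.35,
            top_k=50,
        )
        print(f"Categories returned: {sorted(tags.keys())}")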
utils/ui_components.py ADDED
@@ -0,0 +1,137 @@
+ """
+ UI components for the Image Tagger application.
+ """
+
+ import os
+ import streamlit as st
+ from PIL import Image
+
+
+ def display_progress_bar(prob):
+     """
+     Create an HTML progress bar for displaying a probability.
+
+     Args:
+         prob: Probability value between 0 and 1
+
+     Returns:
+         HTML string for the progress bar
+     """
+     # Convert the probability to a percentage
+     percentage = int(prob * 100)
+
+     # Choose a color based on the confidence level
+     if prob >= 0.8:
+         color = "green"
+     elif prob >= 0.5:
+         color = "orange"
+     else:
+         color = "red"
+
+     # Return HTML for a styled progress bar
+     return f"""
+     <div style="margin-bottom: 5px; display: flex; align-items: center;">
+         <div style="flex-grow: 1; background-color: #f0f0f0; border-radius: 3px; height: 8px; position: relative;">
+             <div style="position: absolute; width: {percentage}%; background-color: {color}; height: 8px; border-radius: 3px;"></div>
+         </div>
+         <div style="margin-left: 8px; min-width: 40px; text-align: right; font-size: 0.9em;">{percentage}%</div>
+     </div>
+     """
+
+
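Because the function returns raw HTML rather than rendering anything itself, a caller would pass the string through Streamlit's HTML-enabled markdown; a minimal sketch (the 0.87 value is arbitrary):

    html = display_progress_bar(0.87)
    st.markdown(html, unsafe_allow_html=True)  # unsafe_allow_html is required for raw markup
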
+ def show_example_images(examples_dir):
+     """
+     Display example images from a directory.
+
+     Args:
+         examples_dir: Directory containing example images
+
+     Returns:
+         Selected image path or None
+     """
+     selected_image = None
+
+     if os.path.exists(examples_dir):
+         example_files = [f for f in os.listdir(examples_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+
+         if example_files:
+             st.write("Select an example image:")
+
+             # Create a two-column layout for the examples
+             example_cols = st.columns(2)
+
+             for i, example_file in enumerate(example_files):
+                 col_idx = i % 2
+                 with example_cols[col_idx]:
+                     example_path = os.path.join(examples_dir, example_file)
+
+                     # Display a thumbnail
+                     try:
+                         img = Image.open(example_path)
+                         st.image(img, width=150, caption=example_file)
+
+                         # Button to select this example
+                         if st.button("Use", key=f"example_{i}"):
+                             selected_image = example_path
+                             st.session_state.original_filename = example_file
+
+                             # Display the full image
+                             st.image(img, use_container_width=True)
+                             st.success(f"Example '{example_file}' selected!")
+                     except Exception as e:
+                         st.error(f"Error loading {example_file}: {e}")
+         else:
+             st.info("No example images found.")
+             st.write("Add some JPG or PNG images to the 'examples' directory.")
+     else:
+         st.info("Examples directory not found.")
+         st.write("Create an 'examples' directory and add some JPG or PNG images.")
+
+     return selected_image
+
+
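A caller would typically persist the returned path so a Streamlit rerun does not lose the selection; a sketch, where `selected_image_path` is a hypothetical session-state key rather than one this app defines:

    selected = show_example_images("examples")
    if selected:
        st.session_state.selected_image_path = selected  # hypothetical key
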
+ def display_batch_results(batch_results):
+     """
+     Display batch processing results.
+
+     Args:
+         batch_results: Dictionary with batch processing results
+     """
+     if batch_results['success']:
+         st.success(f"✅ Processed {batch_results['processed']} of {batch_results['total']} images")
+
+         # Show details in an expander
+         with st.expander("Batch Processing Results", expanded=True):
+             # Count successes and failures
+             successes = sum(1 for r in batch_results['results'].values() if r['success'])
+             failures = batch_results['total'] - successes
+
+             st.write(f"- Successfully tagged: {successes}")
+             st.write(f"- Failed to process: {failures}")
+
+             if failures > 0:
+                 # Show the errors
+                 st.write("### Processing Errors")
+                 for img_path, result in batch_results['results'].items():
+                     if not result['success']:
+                         st.write(f"- **{os.path.basename(img_path)}**: {result.get('error', 'Unknown error')}")
+
+             # Show the location of the output files
+             if successes > 0:
+                 st.write("### Output Files")
+                 st.write("Tag files have been saved to the 'saved_tags' folder.")
+
+                 # Show the first few as examples
+                 st.write("Example outputs:")
+                 sample_results = [(path, res) for path, res in batch_results['results'].items() if res['success']][:3]
+                 for img_path, result in sample_results:
+                     output_path = result.get('output_path', '')
+                     if output_path and os.path.exists(output_path):
+                         st.write(f"- **{os.path.basename(output_path)}**")
+
+                         # Show the file contents in a code block
+                         with open(output_path, 'r', encoding='utf-8') as f:
+                             content = f.read()
+                         st.code(content, language='text')
+     else:
+         st.error(f"Batch processing failed: {batch_results.get('error', 'Unknown error')}")
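
To exercise this renderer without running a real batch, a dictionary carrying only the keys the function actually reads ('success', 'total', 'processed', 'results', and per-image 'success'/'error'/'output_path'/'all_tags') is enough; all values below are fabricated placeholders:

    fake_batch = {
        'success': True,
        'total': 2,
        'processed': 2,
        'results': {
            'a.jpg': {'success': True, 'output_path': 'saved_tags/a.txt', 'all_tags': ['tag_a', 'tag_b']},
            'b.jpg': {'success': False, 'error': 'Failed to process image', 'all_tags': []},
        },
    }
    display_batch_results(fake_batch)  # the st.code preview only appears if output_path exists on disk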