MogensR committed
Commit 9015f7f · 1 Parent(s): 4142570

Create core/models.py

Files changed (1)
  1. core/models.py +559 -0
core/models.py ADDED
@@ -0,0 +1,559 @@
+ """
+ Model management and optimization for BackgroundFX Pro.
+ Fixes MatAnyone quality issues and manages model loading.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import gaussian_blur  # torch.nn.functional has no gaussian_blur
+ from typing import Dict, Any, Optional, Tuple, List
+ from dataclasses import dataclass
+ import numpy as np
+ from pathlib import Path
+ import logging
+ import gc
+ from functools import lru_cache
+ import warnings
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ModelConfig:
+     """Configuration for model management."""
+     sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt"
+     matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth"
+     device: str = "cuda"
+     dtype: torch.dtype = torch.float16
+     optimize_memory: bool = True
+     use_amp: bool = True
+     cache_size: int = 5
+     enable_quality_fixes: bool = True
+     matanyone_enhancement: bool = True
+     use_tensorrt: bool = False
+     batch_size: int = 1
+
+
+ class ModelCache:
+     """Intelligent model caching system."""
+
+     def __init__(self, max_size: int = 5):
+         self.cache = {}
+         self.max_size = max_size
+         self.access_count = {}
+         self.memory_usage = {}
+
+     def add(self, key: str, model: Any, memory_size: float):
+         """Add model to cache with memory tracking."""
+         if len(self.cache) >= self.max_size:
+             # Evict the least-accessed entry (frequency-based, not strict LRU)
+             lru_key = min(self.access_count, key=self.access_count.get)
+             self.remove(lru_key)
+
+         self.cache[key] = model
+         self.access_count[key] = 0
+         self.memory_usage[key] = memory_size
+
+     def get(self, key: str) -> Optional[Any]:
+         """Get model from cache."""
+         if key in self.cache:
+             self.access_count[key] += 1
+             return self.cache[key]
+         return None
+
+     def remove(self, key: str):
+         """Remove model from cache and free memory."""
+         if key in self.cache:
+             model = self.cache[key]
+             del self.cache[key]
+             del self.access_count[key]
+             del self.memory_usage[key]
+
+             # Force cleanup
+             del model
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+     def clear(self):
+         """Clear entire cache."""
+         keys = list(self.cache.keys())
+         for key in keys:
+             self.remove(key)
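+
+
+ # Illustrative sketch (not part of the original commit): how ModelCache is
+ # meant to be driven. memory_size is an opaque bookkeeping number here (the
+ # unit is not specified by this module); eviction drops the least-accessed
+ # entry once max_size is reached.
+ def _demo_model_cache() -> None:
+     cache = ModelCache(max_size=2)
+     cache.add("a", nn.Identity(), memory_size=0.1)
+     cache.add("b", nn.Identity(), memory_size=0.2)
+     _ = cache.get("a")                                # bump access count for "a"
+     cache.add("c", nn.Identity(), memory_size=0.3)    # evicts "b" (fewest accesses)
+     assert cache.get("b") is None and cache.get("c") is not None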
+
+
+ class MatAnyoneModel(nn.Module):
+     """Enhanced MatAnyone model with quality fixes."""
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.base_model = None
+         # Keep the enhancer on the target device so it can consume GPU tensors
+         self.quality_enhancer = (
+             QualityEnhancer().to(config.device).eval()
+             if config.enable_quality_fixes else None
+         )
+         self.loaded = False
+
+     def load(self):
+         """Load MatAnyone model with optimizations."""
+         if self.loaded:
+             return
+
+         try:
+             # Load checkpoint
+             checkpoint_path = Path(self.config.matanyone_checkpoint)
+             if not checkpoint_path.exists():
+                 logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}")
+                 return
+
+             # Load model weights
+             state_dict = torch.load(
+                 checkpoint_path,
+                 map_location=self.config.device
+             )
+
+             # Initialize base model (placeholder - replace with actual MatAnyone architecture)
+             self.base_model = self._build_matanyone_architecture()
+
+             # Load weights with compatibility fixes
+             self._load_weights_safe(state_dict)
+
+             # Optimize model
+             if self.config.optimize_memory:
+                 self._optimize_model()
+
+             self.loaded = True
+             logger.info("MatAnyone model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to load MatAnyone model: {e}")
+             self.loaded = False
+
+     def _build_matanyone_architecture(self) -> nn.Module:
+         """Build MatAnyone architecture."""
+         # This is a placeholder - replace with actual MatAnyone architecture
+         class MatAnyoneBase(nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.encoder = nn.Sequential(
+                     nn.Conv2d(4, 64, 3, padding=1),
+                     nn.ReLU(),
+                     nn.Conv2d(64, 128, 3, stride=2, padding=1),
+                     nn.ReLU(),
+                     nn.Conv2d(128, 256, 3, stride=2, padding=1),
+                     nn.ReLU(),
+                 )
+                 self.decoder = nn.Sequential(
+                     nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
+                     nn.ReLU(),
+                     nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
+                     nn.ReLU(),
+                     nn.Conv2d(64, 4, 3, padding=1),
+                     nn.Sigmoid()
+                 )
+
+             def forward(self, x):
+                 features = self.encoder(x)
+                 output = self.decoder(features)
+                 return output
+
+         return MatAnyoneBase().to(self.config.device)
+
+     def _load_weights_safe(self, state_dict: Dict):
+         """Safely load weights with compatibility handling."""
+         model_dict = self.base_model.state_dict()
+
+         # Filter compatible weights
+         compatible_dict = {}
+         for k, v in state_dict.items():
+             # Remove module prefix if present
+             if k.startswith('module.'):
+                 k = k[7:]
+
+             if k in model_dict and model_dict[k].shape == v.shape:
+                 compatible_dict[k] = v
+             else:
+                 logger.warning(f"Skipping incompatible weight: {k}")
+
+         # Load compatible weights
+         model_dict.update(compatible_dict)
+         self.base_model.load_state_dict(model_dict, strict=False)
+
+         logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")
+
+     def _optimize_model(self):
+         """Optimize model for inference."""
+         if not self.base_model:
+             return
+
+         self.base_model.eval()
+
+         # Convert to half precision if using GPU
+         if self.config.dtype == torch.float16 and self.config.device != "cpu":
+             self.base_model = self.base_model.half()
+
+         # Disable gradient computation
+         for param in self.base_model.parameters():
+             param.requires_grad = False
+
+         # TensorRT optimization (optional hook: _optimize_with_tensorrt is not
+         # defined in this module, so the guarded call below just logs a warning
+         # unless an implementation is provided elsewhere)
+         if self.config.use_tensorrt:
+             try:
+                 self._optimize_with_tensorrt()
+             except Exception as e:
+                 logger.warning(f"TensorRT optimization failed: {e}")
+
+     def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Enhanced forward pass with quality fixes."""
+         if not self.loaded:
+             self.load()
+
+         if not self.base_model:
+             # Graceful fallback when no checkpoint is available
+             return {
+                 'alpha': mask,
+                 'foreground': image,
+                 'confidence': torch.zeros(image.shape[0], device=image.device)
+             }
+
+         # Normalize mask to (B, H, W)
+         if mask.dim() == 2:
+             mask = mask.unsqueeze(0)
+
+         # Prepare input
+         x = torch.cat([image, mask.unsqueeze(1)], dim=1)
+
+         # Fix input quality issues
+         if self.config.matanyone_enhancement:
+             x = self._preprocess_input(x)
+
+         # Match the model's parameter dtype (may be float16 after optimization)
+         x = x.to(next(self.base_model.parameters()).dtype)
+
+         # Forward pass with mixed precision; no gradients needed at inference
+         with torch.no_grad(), torch.cuda.amp.autocast(enabled=self.config.use_amp and torch.cuda.is_available()):
+             output = self.base_model(x)
+
+         # The strided encoder/decoder round-trip can change odd spatial sizes
+         if output.shape[-2:] != x.shape[-2:]:
+             output = F.interpolate(output, size=x.shape[-2:], mode='bilinear', align_corners=False)
+
+         # Parse output (post-process in float32 for numerical stability)
+         output = output.float()
+         alpha = output[:, 3:4, :, :]
+         foreground = output[:, :3, :, :]
+
+         # Apply quality enhancement
+         if self.quality_enhancer:
+             alpha = self.quality_enhancer.enhance_alpha(alpha, mask)
+             foreground = self.quality_enhancer.enhance_foreground(foreground, image)
+
+         # Post-process to fix common MatAnyone issues
+         alpha = self._fix_matanyone_artifacts(alpha, mask)
+
+         return {
+             'alpha': alpha,
+             'foreground': foreground,
+             'confidence': self._compute_confidence(alpha, mask)
+         }
+
+     def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
+         """Preprocess input to improve MatAnyone quality."""
+         # Denoise input
+         if x.shape[2] > 64:  # Only for reasonable resolutions
+             x = self._bilateral_filter_torch(x)
+
+         # Normalize properly
+         x = torch.clamp(x, 0, 1)
+
+         # Enhance edges in mask channel
+         mask_channel = x[:, 3:4, :, :]
+         mask_enhanced = self._enhance_mask_edges(mask_channel)
+         x = torch.cat([x[:, :3, :, :], mask_enhanced], dim=1)
+
+         return x
+
+     def _fix_matanyone_artifacts(self, alpha: torch.Tensor,
+                                  original_mask: torch.Tensor) -> torch.Tensor:
+         """Fix common MatAnyone artifacts."""
+         # Fix edge bleeding
+         alpha = self._fix_edge_bleeding(alpha, original_mask)
+
+         # Fix transparency issues
+         alpha = self._fix_transparency_issues(alpha)
+
+         # Ensure consistency with original mask
+         alpha = self._ensure_mask_consistency(alpha, original_mask)
+
+         return alpha
+
+     def _fix_edge_bleeding(self, alpha: torch.Tensor,
+                            original_mask: torch.Tensor) -> torch.Tensor:
+         """Fix edge bleeding artifacts."""
+         # Edge detection expects a 4D (B, 1, H, W) tensor
+         mask_4d = original_mask.unsqueeze(1) if original_mask.dim() == 3 else original_mask
+
+         # Detect edges
+         edges = self._detect_edges_torch(mask_4d)
+
+         # Create edge mask
+         edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2)
+
+         # Refine alpha near edges
+         alpha_refined = alpha.clone()
+         edge_region = edge_mask > 0.1
+
+         # Blend alpha with the original mask near edges
+         if edge_region.any():
+             alpha_refined[edge_region] = (
+                 0.7 * alpha[edge_region] +
+                 0.3 * mask_4d.expand_as(alpha)[edge_region]
+             )
+
+         return alpha_refined
+
+     def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor:
+         """Fix transparency artifacts."""
+         # Identify problematic transparency values
+         mid_range = (alpha > 0.2) & (alpha < 0.8)
+
+         # Push mid-range values toward 0 or 1
+         alpha_fixed = alpha.clone()
+         alpha_fixed[mid_range] = torch.where(
+             alpha[mid_range] > 0.5,
+             torch.clamp(alpha[mid_range] * 1.2, max=1.0),
+             torch.clamp(alpha[mid_range] * 0.8, min=0.0)
+         )
+
+         # Smooth transitions (gaussian_blur is torchvision's; F.gaussian_blur does not exist)
+         alpha_fixed = gaussian_blur(alpha_fixed, kernel_size=[3, 3])
+
+         return alpha_fixed
+
+     def _ensure_mask_consistency(self, alpha: torch.Tensor,
+                                  original_mask: torch.Tensor) -> torch.Tensor:
+         """Ensure consistency with original mask."""
+         # Expand mask dimensions if needed
+         if original_mask.dim() == 2:
+             original_mask = original_mask.unsqueeze(0).unsqueeze(0)
+         elif original_mask.dim() == 3:
+             original_mask = original_mask.unsqueeze(1)
+
+         # Where original mask is 0, alpha should also be 0
+         alpha = torch.where(original_mask < 0.1, torch.zeros_like(alpha), alpha)
+
+         # Where original mask is 1, alpha should be close to 1
+         alpha = torch.where(original_mask > 0.9, torch.full_like(alpha, 0.95), alpha)
+
+         return alpha
+
+     def _compute_confidence(self, alpha: torch.Tensor,
+                             original_mask: torch.Tensor) -> torch.Tensor:
+         """Compute confidence score for the output."""
+         # Expand dimensions if needed
+         if original_mask.dim() < alpha.dim():
+             original_mask = original_mask.unsqueeze(1).expand_as(alpha)
+
+         # Confidence = 1 - mean absolute deviation from the input mask
+         diff = torch.abs(alpha - original_mask)
+         confidence = 1.0 - torch.mean(diff, dim=(1, 2, 3))
+
+         return confidence
+
+     def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply an (approximate) bilateral filter in PyTorch."""
+         # Simple approximation using torchvision's Gaussian blur; a true
+         # bilateral filter would need an edge-aware (e.g. custom CUDA) kernel
+         return gaussian_blur(x, kernel_size=[5, 5])
+
+     def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor:
+         """Enhance edges in mask channel."""
+         # Detect edges
+         edges = self._detect_edges_torch(mask)
+
+         # Enhance mask with edges
+         enhanced = mask + 0.3 * edges
+         enhanced = torch.clamp(enhanced, 0, 1)
+
+         return enhanced
+
+     def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor:
+         """Detect edges using Sobel filters (expects a (B, 1, H, W) tensor)."""
+         # Sobel kernels
+         sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
+                                dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
+         sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
+                                dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
+
+         # Apply Sobel filters
+         edges_x = F.conv2d(x, sobel_x, padding=1)
+         edges_y = F.conv2d(x, sobel_y, padding=1)
+
+         # Compute edge magnitude
+         edges = torch.sqrt(edges_x ** 2 + edges_y ** 2)
+
+         return edges
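+
+
+ # Illustrative sketch (not part of the original commit): exercising the
+ # graceful-fallback path of MatAnyoneModel when no checkpoint is installed.
+ # With the default checkpoint path missing, forward() returns the inputs
+ # untouched, which keeps the pipeline alive on a fresh machine.
+ def _demo_matanyone_fallback() -> None:
+     cfg = ModelConfig(device="cpu", dtype=torch.float32, use_amp=False)
+     model = MatAnyoneModel(cfg)
+     image = torch.rand(1, 3, 64, 64)   # (B, C, H, W) in [0, 1]
+     mask = torch.rand(1, 64, 64)       # (B, H, W) soft mask
+     out = model(image, mask)
+     assert out['alpha'].shape == (1, 64, 64)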
+
+
+ class SAM2Model:
+     """SAM2 model wrapper with optimizations."""
+
+     def __init__(self, config: ModelConfig):
+         self.config = config
+         self.model = None
+         self.predictor = None
+         self.loaded = False
+
+     def load(self):
+         """Load SAM2 model."""
+         if self.loaded:
+             return
+
+         try:
+             # Import SAM2 (assuming it's installed)
+             from sam2.build_sam import build_sam2
+             from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+             # Build model
+             self.model = build_sam2(
+                 config_file="sam2_hiera_l.yaml",
+                 ckpt_path=self.config.sam2_checkpoint,
+                 device=self.config.device
+             )
+
+             # Create predictor
+             self.predictor = SAM2ImagePredictor(self.model)
+
+             self.loaded = True
+             logger.info("SAM2 model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to load SAM2 model: {e}")
+             self.loaded = False
+
+     def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray:
+         """Generate segmentation mask."""
+         if not self.loaded:
+             self.load()
+
+         if not self.predictor:
+             return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+
+         # Set image
+         self.predictor.set_image(image)
+
+         # Use prompts if provided; otherwise fall back to a centre-point prompt
+         # (the image predictor has no automatic mode -- full automatic
+         # segmentation would go through SAM2AutomaticMaskGenerator instead)
+         if not prompts:
+             h, w = image.shape[:2]
+             prompts = {
+                 'points': np.array([[w // 2, h // 2]], dtype=np.float32),
+                 'labels': np.array([1], dtype=np.int32)
+             }
+
+         masks, scores, _ = self.predictor.predict(
+             point_coords=prompts.get('points'),
+             point_labels=prompts.get('labels'),
+             box=prompts.get('box'),
+             multimask_output=True
+         )
+         # Select best-scoring mask
+         mask = masks[np.argmax(scores)]
+
+         return mask.astype(np.float32)
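+
+
+ # Illustrative sketch (not part of the original commit): the prompt dict
+ # format SAM2Model.predict expects. Keys mirror SAM2ImagePredictor.predict
+ # arguments; point coordinates are (x, y) pixels, labels are 1 for foreground
+ # and 0 for background, and the optional box is (x0, y0, x1, y1).
+ def _demo_sam2_prompts() -> Dict[str, np.ndarray]:
+     return {
+         'points': np.array([[320, 240]], dtype=np.float32),
+         'labels': np.array([1], dtype=np.int32),
+         'box': np.array([100, 80, 540, 400], dtype=np.float32),
+     }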
+
+
+ class QualityEnhancer(nn.Module):
+     """Neural quality enhancement module.
+
+     Note: the refiner weights are randomly initialized here; without trained
+     weights the enhancement is only a structural placeholder.
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.alpha_refiner = nn.Sequential(
+             nn.Conv2d(1, 16, 3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(16, 16, 3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(16, 1, 3, padding=1),
+             nn.Sigmoid()
+         )
+
+         self.foreground_enhancer = nn.Sequential(
+             nn.Conv2d(3, 32, 3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, 3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 3, 3, padding=1),
+             nn.Tanh()
+         )
+
+     def enhance_alpha(self, alpha: torch.Tensor,
+                       original_mask: torch.Tensor) -> torch.Tensor:
+         """Enhance alpha channel quality."""
+         # Refine with neural network
+         refined = self.alpha_refiner(alpha)
+
+         # Blend with original for stability
+         enhanced = 0.7 * refined + 0.3 * alpha
+
+         return torch.clamp(enhanced, 0, 1)
+
+     def enhance_foreground(self, foreground: torch.Tensor,
+                            original_image: torch.Tensor) -> torch.Tensor:
+         """Enhance foreground quality."""
+         # Compute residual
+         residual = self.foreground_enhancer(foreground)
+
+         # Add residual
+         enhanced = foreground + 0.1 * residual
+
+         return torch.clamp(enhanced, 0, 1)
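+
+
+ # Illustrative sketch (not part of the original commit): QualityEnhancer is a
+ # plain nn.Module, so it can be probed standalone on random tensors.
+ def _demo_quality_enhancer() -> None:
+     enhancer = QualityEnhancer().eval()
+     with torch.no_grad():
+         alpha = torch.rand(1, 1, 32, 32)
+         refined = enhancer.enhance_alpha(alpha, original_mask=torch.rand(1, 32, 32))
+     assert refined.shape == alpha.shape and refined.min() >= 0 and refined.max() <= 1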
+
+
+ class ModelManager:
+     """Central model management system."""
+
+     def __init__(self, config: Optional[ModelConfig] = None):
+         self.config = config or ModelConfig()
+         # Fall back to CPU when CUDA is requested but unavailable
+         if self.config.device == "cuda" and not torch.cuda.is_available():
+             logger.warning("CUDA not available, falling back to CPU")
+             self.config.device = "cpu"
+         self.cache = ModelCache(max_size=self.config.cache_size)
+         self.models = {}
+
+         # Initialize models
+         self.sam2 = SAM2Model(self.config)
+         self.matanyone = MatAnyoneModel(self.config)
+
+     def load_all(self):
+         """Load all models."""
+         logger.info("Loading all models...")
+         self.sam2.load()
+         self.matanyone.load()
+         logger.info("All models loaded")
+
+     def get_sam2(self) -> SAM2Model:
+         """Get SAM2 model."""
+         if not self.sam2.loaded:
+             self.sam2.load()
+         return self.sam2
+
+     def get_matanyone(self) -> MatAnyoneModel:
+         """Get MatAnyone model."""
+         if not self.matanyone.loaded:
+             self.matanyone.load()
+         return self.matanyone
+
+     def process_frame(self, image: np.ndarray,
+                       mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
+         """Process single frame through pipeline."""
+         # Convert to tensor
+         image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+         image_tensor = image_tensor.to(self.config.device)
+
+         # Get or generate mask
+         if mask is None:
+             mask = self.sam2.predict(image)
+
+         # Add the batch dimension MatAnyoneModel expects: (B, H, W)
+         mask_tensor = torch.from_numpy(mask).float().unsqueeze(0).to(self.config.device)
+
+         # Process with MatAnyone
+         result = self.matanyone(image_tensor, mask_tensor)
+
+         # Convert back to numpy
+         output = {
+             'alpha': result['alpha'].squeeze().cpu().numpy(),
+             'foreground': result['foreground'].squeeze().permute(1, 2, 0).cpu().numpy() * 255,
+             'confidence': result['confidence'].cpu().numpy()
+         }
+
+         return output
+
+     def cleanup(self):
+         """Cleanup models and free memory."""
+         self.cache.clear()
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+
+ # Export classes
+ __all__ = [
+     'ModelManager',
+     'SAM2Model',
+     'MatAnyoneModel',
+     'ModelConfig',
+     'ModelCache',
+     'QualityEnhancer'
+ ]
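+
+
+ # Minimal end-to-end smoke test (illustrative, not part of the original
+ # commit). With no checkpoints installed it exercises the graceful-fallback
+ # paths, so it is safe to run on a bare machine: MatAnyone simply echoes its
+ # inputs back through process_frame.
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     manager = ModelManager(ModelConfig(device="cpu", dtype=torch.float32, use_amp=False))
+     frame = (np.random.rand(120, 160, 3) * 255).astype(np.uint8)
+     rough_mask = (np.random.rand(120, 160) > 0.5).astype(np.float32)
+     result = manager.process_frame(frame, rough_mask)
+     print({k: getattr(v, 'shape', v) for k, v in result.items()})
+     manager.cleanup()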