Create matanyone_fixed/inference/inference_core.py
Browse files
matanyone_fixed/inference/inference_core.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fixed MatAnyone Inference Core
|
| 3 |
+
Removes tensor-to-numpy conversion bugs that cause F.pad() errors
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Optional, Union, Tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
|
| 13 |
+
"""
|
| 14 |
+
FIXED VERSION: Ensures tensor input stays as tensor
|
| 15 |
+
"""
|
| 16 |
+
if not isinstance(in_tensor, torch.Tensor):
|
| 17 |
+
raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
|
| 18 |
+
|
| 19 |
+
h, w = in_tensor.shape[-2:]
|
| 20 |
+
|
| 21 |
+
# Calculate padding needed
|
| 22 |
+
new_h = (h + d - 1) // d * d
|
| 23 |
+
new_w = (w + d - 1) // d * d
|
| 24 |
+
|
| 25 |
+
lh, uh = (new_h - h) // 2, (new_h - h) // 2 + (new_h - h) % 2
|
| 26 |
+
lw, uw = (new_w - w) // 2, (new_w - w) // 2 + (new_w - w) % 2
|
| 27 |
+
|
| 28 |
+
pad_array = (lw, uw, lh, uh)
|
| 29 |
+
|
| 30 |
+
# CRITICAL FIX: Ensure tensor stays as tensor
|
| 31 |
+
out = F.pad(in_tensor, pad_array, mode='reflect')
|
| 32 |
+
|
| 33 |
+
return out, pad_array
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
|
| 37 |
+
"""Remove padding from tensor"""
|
| 38 |
+
if not isinstance(in_tensor, torch.Tensor):
|
| 39 |
+
raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
|
| 40 |
+
|
| 41 |
+
lw, uw, lh, uh = pad
|
| 42 |
+
h, w = in_tensor.shape[-2:]
|
| 43 |
+
|
| 44 |
+
# Remove padding
|
| 45 |
+
if lh > 0:
|
| 46 |
+
in_tensor = in_tensor[..., lh:, :]
|
| 47 |
+
if uh > 0:
|
| 48 |
+
in_tensor = in_tensor[..., :-uh, :]
|
| 49 |
+
if lw > 0:
|
| 50 |
+
in_tensor = in_tensor[..., :, lw:]
|
| 51 |
+
if uw > 0:
|
| 52 |
+
in_tensor = in_tensor[..., :, :-uw]
|
| 53 |
+
|
| 54 |
+
return in_tensor
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class InferenceCore:
|
| 58 |
+
"""
|
| 59 |
+
FIXED MatAnyone Inference Core
|
| 60 |
+
Handles video matting with proper tensor operations
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, model: torch.nn.Module):
|
| 64 |
+
self.model = model
|
| 65 |
+
self.model.eval()
|
| 66 |
+
self.device = next(model.parameters()).device
|
| 67 |
+
self.pad = None
|
| 68 |
+
|
| 69 |
+
# Memory storage for temporal consistency
|
| 70 |
+
self.image_feature_store = {}
|
| 71 |
+
self.frame_count = 0
|
| 72 |
+
|
| 73 |
+
def _ensure_tensor_format(self,
|
| 74 |
+
image: Union[torch.Tensor, np.ndarray],
|
| 75 |
+
prob: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 76 |
+
"""
|
| 77 |
+
CRITICAL FIX: Ensure all inputs are properly formatted tensors
|
| 78 |
+
"""
|
| 79 |
+
# Convert image to tensor if needed
|
| 80 |
+
if isinstance(image, np.ndarray):
|
| 81 |
+
if image.ndim == 3 and image.shape[-1] == 3: # HWC format
|
| 82 |
+
image = torch.from_numpy(image.transpose(2, 0, 1)).float() # Convert to CHW
|
| 83 |
+
elif image.ndim == 3 and image.shape[0] == 3: # CHW format
|
| 84 |
+
image = torch.from_numpy(image).float()
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"Unexpected image shape: {image.shape}")
|
| 87 |
+
|
| 88 |
+
# Ensure image is on correct device and has correct format
|
| 89 |
+
if not isinstance(image, torch.Tensor):
|
| 90 |
+
raise TypeError(f"Image must be tensor after conversion, got {type(image)}")
|
| 91 |
+
|
| 92 |
+
image = image.float().to(self.device)
|
| 93 |
+
|
| 94 |
+
# Ensure CHW format (3, H, W)
|
| 95 |
+
if image.ndim == 3 and image.shape[0] == 3:
|
| 96 |
+
pass # Already correct
|
| 97 |
+
elif image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
|
| 98 |
+
image = image.squeeze(0) # Remove batch dimension
|
| 99 |
+
else:
|
| 100 |
+
raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")
|
| 101 |
+
|
| 102 |
+
# Handle probability mask if provided
|
| 103 |
+
if prob is not None:
|
| 104 |
+
if isinstance(prob, np.ndarray):
|
| 105 |
+
prob = torch.from_numpy(prob).float()
|
| 106 |
+
|
| 107 |
+
if not isinstance(prob, torch.Tensor):
|
| 108 |
+
raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")
|
| 109 |
+
|
| 110 |
+
prob = prob.float().to(self.device)
|
| 111 |
+
|
| 112 |
+
# Ensure HW format for prob
|
| 113 |
+
while prob.ndim > 2:
|
| 114 |
+
prob = prob.squeeze(0)
|
| 115 |
+
|
| 116 |
+
if prob.ndim != 2:
|
| 117 |
+
raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")
|
| 118 |
+
|
| 119 |
+
return image, prob
|
| 120 |
+
|
| 121 |
+
def step(self,
|
| 122 |
+
image: Union[torch.Tensor, np.ndarray],
|
| 123 |
+
prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
|
| 124 |
+
**kwargs) -> torch.Tensor:
|
| 125 |
+
"""
|
| 126 |
+
FIXED step method with proper tensor handling
|
| 127 |
+
"""
|
| 128 |
+
# Convert inputs to proper tensor format
|
| 129 |
+
image, prob = self._ensure_tensor_format(image, prob)
|
| 130 |
+
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
# Pad image for processing
|
| 133 |
+
image_padded, self.pad = pad_divide_by(image, 16)
|
| 134 |
+
|
| 135 |
+
# Add batch dimension for model
|
| 136 |
+
image_batch = image_padded.unsqueeze(0) # (1, 3, H_pad, W_pad)
|
| 137 |
+
|
| 138 |
+
if prob is not None:
|
| 139 |
+
# Pad probability mask to match image
|
| 140 |
+
h_pad, w_pad = image_padded.shape[-2:]
|
| 141 |
+
h_orig, w_orig = prob.shape
|
| 142 |
+
|
| 143 |
+
# Resize prob to match padded image size
|
| 144 |
+
prob_resized = F.interpolate(
|
| 145 |
+
prob.unsqueeze(0).unsqueeze(0), # (1, 1, H, W)
|
| 146 |
+
size=(h_pad, w_pad),
|
| 147 |
+
mode='bilinear',
|
| 148 |
+
align_corners=False
|
| 149 |
+
).squeeze() # (H_pad, W_pad)
|
| 150 |
+
|
| 151 |
+
prob_batch = prob_resized.unsqueeze(0).unsqueeze(0) # (1, 1, H_pad, W_pad)
|
| 152 |
+
|
| 153 |
+
# Forward pass with probability guidance
|
| 154 |
+
try:
|
| 155 |
+
if hasattr(self.model, 'forward_with_prob'):
|
| 156 |
+
output = self.model.forward_with_prob(image_batch, prob_batch)
|
| 157 |
+
else:
|
| 158 |
+
# Fallback: concatenate prob as additional channel
|
| 159 |
+
input_tensor = torch.cat([image_batch, prob_batch], dim=1) # (1, 4, H_pad, W_pad)
|
| 160 |
+
output = self.model(input_tensor)
|
| 161 |
+
except Exception:
|
| 162 |
+
# Final fallback: just use image
|
| 163 |
+
output = self.model(image_batch)
|
| 164 |
+
else:
|
| 165 |
+
# Forward pass without probability guidance
|
| 166 |
+
output = self.model(image_batch)
|
| 167 |
+
|
| 168 |
+
# Extract alpha channel (assume model outputs alpha as last channel or single channel)
|
| 169 |
+
if output.shape[1] == 1:
|
| 170 |
+
alpha = output.squeeze(1) # (1, H_pad, W_pad)
|
| 171 |
+
elif output.shape[1] > 1:
|
| 172 |
+
alpha = output[:, -1:, :, :] # Take last channel as alpha
|
| 173 |
+
else:
|
| 174 |
+
raise ValueError(f"Unexpected model output shape: {output.shape}")
|
| 175 |
+
|
| 176 |
+
# Remove padding
|
| 177 |
+
alpha_unpadded = unpad_tensor(alpha, self.pad)
|
| 178 |
+
|
| 179 |
+
# Remove batch dimension and ensure 2D output
|
| 180 |
+
alpha_final = alpha_unpadded.squeeze(0) # (H, W)
|
| 181 |
+
|
| 182 |
+
# Ensure values are in [0, 1] range
|
| 183 |
+
alpha_final = torch.clamp(alpha_final, 0.0, 1.0)
|
| 184 |
+
|
| 185 |
+
self.frame_count += 1
|
| 186 |
+
|
| 187 |
+
return alpha_final
|
| 188 |
+
|
| 189 |
+
def clear_memory(self):
|
| 190 |
+
"""Clear stored features for memory management"""
|
| 191 |
+
self.image_feature_store.clear()
|
| 192 |
+
self.frame_count = 0
|