MogensR commited on
Commit
f5fcafb
·
verified ·
1 Parent(s): a767d84

Create matanyone_fixed/inference/inference_core.py

Browse files
matanyone_fixed/inference/inference_core.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fixed MatAnyone Inference Core
3
+ Removes tensor-to-numpy conversion bugs that cause F.pad() errors
4
+ """
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from typing import Optional, Union, Tuple
10
+
11
+
12
+ def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
13
+ """
14
+ FIXED VERSION: Ensures tensor input stays as tensor
15
+ """
16
+ if not isinstance(in_tensor, torch.Tensor):
17
+ raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
18
+
19
+ h, w = in_tensor.shape[-2:]
20
+
21
+ # Calculate padding needed
22
+ new_h = (h + d - 1) // d * d
23
+ new_w = (w + d - 1) // d * d
24
+
25
+ lh, uh = (new_h - h) // 2, (new_h - h) // 2 + (new_h - h) % 2
26
+ lw, uw = (new_w - w) // 2, (new_w - w) // 2 + (new_w - w) % 2
27
+
28
+ pad_array = (lw, uw, lh, uh)
29
+
30
+ # CRITICAL FIX: Ensure tensor stays as tensor
31
+ out = F.pad(in_tensor, pad_array, mode='reflect')
32
+
33
+ return out, pad_array
34
+
35
+
36
+ def unpad_tensor(in_tensor: torch.Tensor, pad: Tuple[int, int, int, int]) -> torch.Tensor:
37
+ """Remove padding from tensor"""
38
+ if not isinstance(in_tensor, torch.Tensor):
39
+ raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)}")
40
+
41
+ lw, uw, lh, uh = pad
42
+ h, w = in_tensor.shape[-2:]
43
+
44
+ # Remove padding
45
+ if lh > 0:
46
+ in_tensor = in_tensor[..., lh:, :]
47
+ if uh > 0:
48
+ in_tensor = in_tensor[..., :-uh, :]
49
+ if lw > 0:
50
+ in_tensor = in_tensor[..., :, lw:]
51
+ if uw > 0:
52
+ in_tensor = in_tensor[..., :, :-uw]
53
+
54
+ return in_tensor
55
+
56
+
57
+ class InferenceCore:
58
+ """
59
+ FIXED MatAnyone Inference Core
60
+ Handles video matting with proper tensor operations
61
+ """
62
+
63
+ def __init__(self, model: torch.nn.Module):
64
+ self.model = model
65
+ self.model.eval()
66
+ self.device = next(model.parameters()).device
67
+ self.pad = None
68
+
69
+ # Memory storage for temporal consistency
70
+ self.image_feature_store = {}
71
+ self.frame_count = 0
72
+
73
+ def _ensure_tensor_format(self,
74
+ image: Union[torch.Tensor, np.ndarray],
75
+ prob: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
76
+ """
77
+ CRITICAL FIX: Ensure all inputs are properly formatted tensors
78
+ """
79
+ # Convert image to tensor if needed
80
+ if isinstance(image, np.ndarray):
81
+ if image.ndim == 3 and image.shape[-1] == 3: # HWC format
82
+ image = torch.from_numpy(image.transpose(2, 0, 1)).float() # Convert to CHW
83
+ elif image.ndim == 3 and image.shape[0] == 3: # CHW format
84
+ image = torch.from_numpy(image).float()
85
+ else:
86
+ raise ValueError(f"Unexpected image shape: {image.shape}")
87
+
88
+ # Ensure image is on correct device and has correct format
89
+ if not isinstance(image, torch.Tensor):
90
+ raise TypeError(f"Image must be tensor after conversion, got {type(image)}")
91
+
92
+ image = image.float().to(self.device)
93
+
94
+ # Ensure CHW format (3, H, W)
95
+ if image.ndim == 3 and image.shape[0] == 3:
96
+ pass # Already correct
97
+ elif image.ndim == 4 and image.shape[0] == 1 and image.shape[1] == 3:
98
+ image = image.squeeze(0) # Remove batch dimension
99
+ else:
100
+ raise ValueError(f"Image must be (3,H,W) or (1,3,H,W), got {image.shape}")
101
+
102
+ # Handle probability mask if provided
103
+ if prob is not None:
104
+ if isinstance(prob, np.ndarray):
105
+ prob = torch.from_numpy(prob).float()
106
+
107
+ if not isinstance(prob, torch.Tensor):
108
+ raise TypeError(f"Prob must be tensor after conversion, got {type(prob)}")
109
+
110
+ prob = prob.float().to(self.device)
111
+
112
+ # Ensure HW format for prob
113
+ while prob.ndim > 2:
114
+ prob = prob.squeeze(0)
115
+
116
+ if prob.ndim != 2:
117
+ raise ValueError(f"Prob must be (H,W) after processing, got {prob.shape}")
118
+
119
+ return image, prob
120
+
121
+ def step(self,
122
+ image: Union[torch.Tensor, np.ndarray],
123
+ prob: Optional[Union[torch.Tensor, np.ndarray]] = None,
124
+ **kwargs) -> torch.Tensor:
125
+ """
126
+ FIXED step method with proper tensor handling
127
+ """
128
+ # Convert inputs to proper tensor format
129
+ image, prob = self._ensure_tensor_format(image, prob)
130
+
131
+ with torch.no_grad():
132
+ # Pad image for processing
133
+ image_padded, self.pad = pad_divide_by(image, 16)
134
+
135
+ # Add batch dimension for model
136
+ image_batch = image_padded.unsqueeze(0) # (1, 3, H_pad, W_pad)
137
+
138
+ if prob is not None:
139
+ # Pad probability mask to match image
140
+ h_pad, w_pad = image_padded.shape[-2:]
141
+ h_orig, w_orig = prob.shape
142
+
143
+ # Resize prob to match padded image size
144
+ prob_resized = F.interpolate(
145
+ prob.unsqueeze(0).unsqueeze(0), # (1, 1, H, W)
146
+ size=(h_pad, w_pad),
147
+ mode='bilinear',
148
+ align_corners=False
149
+ ).squeeze() # (H_pad, W_pad)
150
+
151
+ prob_batch = prob_resized.unsqueeze(0).unsqueeze(0) # (1, 1, H_pad, W_pad)
152
+
153
+ # Forward pass with probability guidance
154
+ try:
155
+ if hasattr(self.model, 'forward_with_prob'):
156
+ output = self.model.forward_with_prob(image_batch, prob_batch)
157
+ else:
158
+ # Fallback: concatenate prob as additional channel
159
+ input_tensor = torch.cat([image_batch, prob_batch], dim=1) # (1, 4, H_pad, W_pad)
160
+ output = self.model(input_tensor)
161
+ except Exception:
162
+ # Final fallback: just use image
163
+ output = self.model(image_batch)
164
+ else:
165
+ # Forward pass without probability guidance
166
+ output = self.model(image_batch)
167
+
168
+ # Extract alpha channel (assume model outputs alpha as last channel or single channel)
169
+ if output.shape[1] == 1:
170
+ alpha = output.squeeze(1) # (1, H_pad, W_pad)
171
+ elif output.shape[1] > 1:
172
+ alpha = output[:, -1:, :, :] # Take last channel as alpha
173
+ else:
174
+ raise ValueError(f"Unexpected model output shape: {output.shape}")
175
+
176
+ # Remove padding
177
+ alpha_unpadded = unpad_tensor(alpha, self.pad)
178
+
179
+ # Remove batch dimension and ensure 2D output
180
+ alpha_final = alpha_unpadded.squeeze(0) # (H, W)
181
+
182
+ # Ensure values are in [0, 1] range
183
+ alpha_final = torch.clamp(alpha_final, 0.0, 1.0)
184
+
185
+ self.frame_count += 1
186
+
187
+ return alpha_final
188
+
189
+ def clear_memory(self):
190
+ """Clear stored features for memory management"""
191
+ self.image_feature_store.clear()
192
+ self.frame_count = 0