MogensR committed on
Commit
508dfbb
·
verified ·
1 Parent(s): 5a3f7f6

Update matanyone_fixed/utils/get_default_model.py

Browse files
matanyone_fixed/utils/get_default_model.py CHANGED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fixed MatAnyone Model Interface
3
+ Simplified and reliable model loading
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Union, Optional
9
+ from pathlib import Path
10
+
11
+
12
class SimpleMatteModel(nn.Module):
    """
    Simplified encoder-decoder matting model with strict tensor handling.

    The encoder downsamples by a factor of 8 (three stride-2 convolutions)
    and the decoder restores resolution with three transposed convolutions,
    so inputs whose H and W are divisible by 8 round-trip to their exact
    original spatial size.

    Args:
        backbone_channels: Number of input channels. Use 3 for RGB-only
            input, or 4 when ``forward_with_prob`` (RGB + probability
            mask) will be used.
    """

    def __init__(self, backbone_channels: int = 3):
        super().__init__()

        # Remember the expected channel count so forward() can validate
        # inputs with a clear error instead of an opaque conv shape error.
        self.in_channels = backbone_channels

        # Simple encoder-decoder architecture
        self.encoder = nn.Sequential(
            # Initial conv (stem)
            nn.Conv2d(backbone_channels, 64, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            # Downsampling blocks (each halves H and W)
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),

            # Bottleneck
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.decoder = nn.Sequential(
            # Upsampling blocks (each doubles H and W)
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            # Final per-pixel alpha prediction, squashed to [0, 1]
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass ensuring tensor operations.

        Args:
            x: Input tensor (B, C, H, W) with C == backbone_channels.

        Returns:
            torch.Tensor: Alpha matte (B, 1, H, W).

        Raises:
            TypeError: If ``x`` is not a torch.Tensor.
            ValueError: If ``x`` is not 4-D or has the wrong channel count.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input must be torch.Tensor, got {type(x)}")
        # Validate early: a mismatched channel count would otherwise fail
        # deep inside the first Conv2d with a hard-to-read error.
        if x.dim() != 4 or x.shape[1] != self.in_channels:
            raise ValueError(
                f"Expected input of shape (B, {self.in_channels}, H, W), "
                f"got {tuple(x.shape)}"
            )

        # Encode
        features = self.encoder(x)

        # Decode
        alpha = self.decoder(features)

        return alpha

    def forward_with_prob(self, image: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with probability guidance.

        Args:
            image: Input image (B, 3, H, W).
            prob: Probability mask (B, 1, H, W).

        Returns:
            torch.Tensor: Alpha matte (B, 1, H, W).

        Raises:
            TypeError: If either input is not a torch.Tensor.
            ValueError: If the model was not built with 4 input channels
                (the concatenated input has image channels + 1).
        """
        if not isinstance(image, torch.Tensor) or not isinstance(prob, torch.Tensor):
            raise TypeError("Both inputs must be torch.Tensor")
        # BUG FIX: the original crashed with an opaque conv error when the
        # model was built with the default 3 channels; fail fast instead.
        if self.in_channels != 4:
            raise ValueError(
                "forward_with_prob requires backbone_channels == 4, "
                f"but this model was built with {self.in_channels}"
            )

        # Concatenate image and probability as input
        x = torch.cat([image, prob], dim=1)  # (B, 4, H, W)

        # Forward pass
        return self.forward(x)
108
+
109
+
110
def load_pretrained_weights(model: nn.Module, checkpoint_path: Union[str, Path]) -> nn.Module:
    """
    Load pretrained weights with error handling.

    Keys are matched flexibly: a leading ``module.`` prefix (left by
    DataParallel/DDP checkpoints) is stripped, and a weight is copied only
    when both the key and the tensor shape match the target model.

    Args:
        model: Model to load weights into (updated in place).
        checkpoint_path: Path to checkpoint file.

    Returns:
        nn.Module: The same model. On a missing file or any load error the
        model is returned unchanged (randomly initialized) after printing
        a warning.
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        print("Using randomly initialized weights")
        return model

    try:
        # Load checkpoint.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Extract state dict: checkpoints may nest it under 'state_dict'
        # or 'model', or be a bare state dict.
        if isinstance(checkpoint, dict):
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
        else:
            state_dict = checkpoint

        # Load weights with flexible key matching
        model_dict = model.state_dict()
        matched_dict = {}

        for key, value in state_dict.items():
            # BUG FIX: strip only a LEADING 'module.' prefix. The original
            # key.replace('module.', '') also mangled keys that
            # legitimately contain 'module.' in the middle of the name.
            clean_key = key[len('module.'):] if key.startswith('module.') else key

            if clean_key in model_dict:
                if model_dict[clean_key].shape == value.shape:
                    matched_dict[clean_key] = value
                else:
                    print(f"Shape mismatch for {clean_key}: model {model_dict[clean_key].shape} vs checkpoint {value.shape}")
            else:
                print(f"Key not found in model: {clean_key}")

        # Load matched weights
        model_dict.update(matched_dict)
        model.load_state_dict(model_dict)

        print(f"Loaded {len(matched_dict)} weights from {checkpoint_path}")

    except Exception as e:
        # Deliberate best-effort: loading failures fall back to random init
        # rather than crashing the caller.
        print(f"Error loading checkpoint: {e}")
        print("Using randomly initialized weights")

    return model
170
+
171
+
172
def get_matanyone_model(checkpoint_path: Union[str, Path],
                        device: Union[str, torch.device] = 'cpu',
                        backbone_channels: int = 4) -> nn.Module:
    """
    FIXED MODEL LOADING: Create and load MatAnyone model.

    Args:
        checkpoint_path: Path to model checkpoint.
        device: Device to load model on.
        backbone_channels: Number of input channels. Defaults to 4
            (RGB + probability mask, supporting ``forward_with_prob``);
            pass 3 for an RGB-only model.

    Returns:
        nn.Module: Loaded model, in eval mode, on ``device``.
    """
    # BUG FIX: backbone_channels used to be accepted but silently ignored
    # (the channel count was hard-coded to 4). The default now encodes
    # that value, so default callers see identical behavior, and an
    # explicit argument is honored.
    model = SimpleMatteModel(backbone_channels=backbone_channels)

    # Load pretrained weights if available (falls back to random init
    # with a printed warning when the checkpoint is missing or bad).
    model = load_pretrained_weights(model, checkpoint_path)

    # Move to device and freeze into inference mode.
    device = torch.device(device)
    model = model.to(device)
    model.eval()

    print(f"MatAnyone model loaded on {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model
206
+
207
+
208
+ # Fallback for compatibility with original MatAnyone interface
209
def build_model(*args, **kwargs):
    """Compatibility shim for the original MatAnyone interface.

    Forwards all arguments unchanged to :func:`get_matanyone_model`.
    """
    return get_matanyone_model(*args, **kwargs)
212
+
213
+
214
class ModelWrapper:
    """
    Thin adapter matching the original MatAnyone model interface.

    Calling, train/eval switching, parameter access, and state-dict
    handling are all delegated to the wrapped ``nn.Module``. The device
    of the wrapped model's parameters is cached on ``self.device`` at
    construction time.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Cache the parameter device (assumes the model has parameters).
        self.device = next(model.parameters()).device

    def __call__(self, *args, **kwargs):
        # Delegate straight to the underlying module's forward.
        return self.model(*args, **kwargs)

    def eval(self):
        return self.model.eval()

    def train(self, mode=True):
        return self.model.train(mode)

    def to(self, device):
        # nn.Module.to moves in place; wrapping the result in a fresh
        # ModelWrapper keeps the cached device attribute accurate.
        moved = self.model.to(device)
        return ModelWrapper(moved)

    def parameters(self):
        return self.model.parameters()

    def state_dict(self):
        return self.model.state_dict()

    def load_state_dict(self, state_dict):
        return self.model.load_state_dict(state_dict)