Commit
·
ae29148
1
Parent(s):
c9d2859
Improve weight loading verification and skip configs with too many missing keys
Browse files- app/pytorch_colorizer.py +18 -9
app/pytorch_colorizer.py
CHANGED
|
@@ -78,10 +78,11 @@ class ResNetGenerator(nn.Module):
|
|
| 78 |
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
| 79 |
model += [nn.Tanh()]
|
| 80 |
|
| 81 |
-
|
|
|
|
| 82 |
|
| 83 |
def forward(self, input):
|
| 84 |
-
return self.
|
| 85 |
|
| 86 |
|
| 87 |
class UNetGenerator(nn.Module):
|
|
@@ -211,15 +212,18 @@ class PyTorchColorizer:
|
|
| 211 |
|
| 212 |
# Log state dict keys to understand model structure
|
| 213 |
if isinstance(state_dict, dict):
|
| 214 |
-
keys = list(state_dict.keys())[:
|
| 215 |
logger.info(f"Model state_dict keys (sample): {keys}")
|
| 216 |
logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
|
| 217 |
|
| 218 |
# Try to infer architecture from key names
|
|
|
|
|
|
|
| 219 |
if any('down' in k.lower() or 'up' in k.lower() for k in keys):
|
| 220 |
logger.info("Detected U-Net style architecture")
|
| 221 |
if any('resnet' in k.lower() for k in keys):
|
| 222 |
logger.info("Detected ResNet style architecture")
|
|
|
|
| 223 |
|
| 224 |
except Exception as e:
|
| 225 |
logger.error(f"Failed to load model file: {e}")
|
|
@@ -250,12 +254,17 @@ class PyTorchColorizer:
|
|
| 250 |
|
| 251 |
# Try strict loading first
|
| 252 |
try:
|
| 253 |
-
model.load_state_dict(state_dict, strict=
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
model.eval()
|
| 261 |
model.to(self.device)
|
|
|
|
| 78 |
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
| 79 |
model += [nn.Tanh()]
|
| 80 |
|
| 81 |
+
# Wrap in Sequential with 'layers' attribute to match state_dict structure
|
| 82 |
+
self.layers = nn.Sequential(*model)
|
| 83 |
|
| 84 |
def forward(self, input):
|
| 85 |
+
return self.layers(input)
|
| 86 |
|
| 87 |
|
| 88 |
class UNetGenerator(nn.Module):
|
|
|
|
| 212 |
|
| 213 |
# Log state dict keys to understand model structure
|
| 214 |
if isinstance(state_dict, dict):
|
| 215 |
+
keys = list(state_dict.keys())[:30] # First 30 keys
|
| 216 |
logger.info(f"Model state_dict keys (sample): {keys}")
|
| 217 |
logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
|
| 218 |
|
| 219 |
# Try to infer architecture from key names
|
| 220 |
+
if any('layers' in k.lower() for k in keys):
|
| 221 |
+
logger.info("Detected sequential 'layers' structure")
|
| 222 |
if any('down' in k.lower() or 'up' in k.lower() for k in keys):
|
| 223 |
logger.info("Detected U-Net style architecture")
|
| 224 |
if any('resnet' in k.lower() for k in keys):
|
| 225 |
logger.info("Detected ResNet style architecture")
|
| 226 |
+
|
| 227 |
|
| 228 |
except Exception as e:
|
| 229 |
logger.error(f"Failed to load model file: {e}")
|
|
|
|
| 254 |
|
| 255 |
# Try strict loading first
|
| 256 |
try:
|
| 257 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
| 258 |
+
if not missing_keys and not unexpected_keys:
|
| 259 |
+
logger.info(f"✅ Successfully loaded {model_type} model with perfect matching: {config_copy}")
|
| 260 |
+
else:
|
| 261 |
+
logger.warning(f"⚠️ Loaded {model_type} model with mismatches - Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
|
| 262 |
+
if len(missing_keys) > len(state_dict) * 0.5: # If more than 50% missing, skip
|
| 263 |
+
logger.warning(f"Skipping this config - too many missing keys ({len(missing_keys)}/{len(state_dict)})")
|
| 264 |
+
continue
|
| 265 |
+
except Exception as e:
|
| 266 |
+
logger.debug(f"Failed to load state_dict: {e}")
|
| 267 |
+
continue
|
| 268 |
|
| 269 |
model.eval()
|
| 270 |
model.to(self.device)
|