Update model/Model_RGB.py
Browse files, map_location=torch.device('cpu')
- model/Model_RGB.py +2 -2
model/Model_RGB.py
CHANGED
|
@@ -324,7 +324,7 @@ class EstimateRGB(nn.Module):
|
|
| 324 |
|
| 325 |
def _make_model(self):
|
| 326 |
model = SAR(self.cfg.backbone, in_channels=self.cfg.in_channels).to(self.cfg.device)
|
| 327 |
-
checkpoint = torch.load(self.cfg.checkpoint)
|
| 328 |
if 'net' in checkpoint:
|
| 329 |
model.load_state_dict(checkpoint['net'])
|
| 330 |
else:
|
|
@@ -351,7 +351,7 @@ class EstimateRGB(nn.Module):
|
|
| 351 |
import onnxruntime as ort
|
| 352 |
print('export begin')
|
| 353 |
model = SAR(self.cfg.backbone, in_channels=self.cfg.in_channels).to(self.cfg.device)
|
| 354 |
-
checkpoint = torch.load(self.cfg.checkpoint)
|
| 355 |
if 'net' in checkpoint:
|
| 356 |
model.load_state_dict(checkpoint['net'])
|
| 357 |
else:
|
|
|
|
| 324 |
|
| 325 |
def _make_model(self):
|
| 326 |
model = SAR(self.cfg.backbone, in_channels=self.cfg.in_channels).to(self.cfg.device)
|
| 327 |
+
checkpoint = torch.load(self.cfg.checkpoint, map_location=torch.device('cpu'))
|
| 328 |
if 'net' in checkpoint:
|
| 329 |
model.load_state_dict(checkpoint['net'])
|
| 330 |
else:
|
|
|
|
| 351 |
import onnxruntime as ort
|
| 352 |
print('export begin')
|
| 353 |
model = SAR(self.cfg.backbone, in_channels=self.cfg.in_channels).to(self.cfg.device)
|
| 354 |
+
checkpoint = torch.load(self.cfg.checkpoint, map_location=torch.device('cpu'))
|
| 355 |
if 'net' in checkpoint:
|
| 356 |
model.load_state_dict(checkpoint['net'])
|
| 357 |
else:
|