模型代码
Browse files- models/ddm.py +3 -1
models/ddm.py
CHANGED
@@ -100,6 +100,7 @@ class DenoisingDiffusion(object):
|
|
100 |
self.writer = SummaryWriter(config.data.tensorboard)
|
101 |
self.model = DiffusionUNet(config)
|
102 |
self.model.to(self.device)
|
|
|
103 |
if test:
|
104 |
self.model = torch.nn.DataParallel(self.model)
|
105 |
else:
|
@@ -129,7 +130,8 @@ class DenoisingDiffusion(object):
|
|
129 |
self.model.load_state_dict(checkpoint['state_dict'], strict=True)
|
130 |
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
131 |
self.ema_helper.load_state_dict(checkpoint['ema_helper'])
|
132 |
-
self.
|
|
|
133 |
if ema:
|
134 |
self.ema_helper.ema(self.model)
|
135 |
print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))
|
|
|
100 |
self.writer = SummaryWriter(config.data.tensorboard)
|
101 |
self.model = DiffusionUNet(config)
|
102 |
self.model.to(self.device)
|
103 |
+
self.test = test
|
104 |
if test:
|
105 |
self.model = torch.nn.DataParallel(self.model)
|
106 |
else:
|
|
|
130 |
self.model.load_state_dict(checkpoint['state_dict'], strict=True)
|
131 |
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
132 |
self.ema_helper.load_state_dict(checkpoint['ema_helper'])
|
133 |
+
if not self.test:
|
134 |
+
self.scheduler.load_state_dict(checkpoint['scheduler'])
|
135 |
if ema:
|
136 |
self.ema_helper.ema(self.model)
|
137 |
print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))
|