HaisuGuan commited on
Commit
d18bc8b
·
1 Parent(s): c6fb835

模型代码

Browse files
Files changed (1) hide show
  1. models/ddm.py +3 -1
models/ddm.py CHANGED
@@ -100,6 +100,7 @@ class DenoisingDiffusion(object):
100
  self.writer = SummaryWriter(config.data.tensorboard)
101
  self.model = DiffusionUNet(config)
102
  self.model.to(self.device)
 
103
  if test:
104
  self.model = torch.nn.DataParallel(self.model)
105
  else:
@@ -129,7 +130,8 @@ class DenoisingDiffusion(object):
129
  self.model.load_state_dict(checkpoint['state_dict'], strict=True)
130
  self.optimizer.load_state_dict(checkpoint['optimizer'])
131
  self.ema_helper.load_state_dict(checkpoint['ema_helper'])
132
- self.scheduler.load_state_dict(checkpoint['scheduler'])
 
133
  if ema:
134
  self.ema_helper.ema(self.model)
135
  print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))
 
100
  self.writer = SummaryWriter(config.data.tensorboard)
101
  self.model = DiffusionUNet(config)
102
  self.model.to(self.device)
103
+ self.test = test
104
  if test:
105
  self.model = torch.nn.DataParallel(self.model)
106
  else:
 
130
  self.model.load_state_dict(checkpoint['state_dict'], strict=True)
131
  self.optimizer.load_state_dict(checkpoint['optimizer'])
132
  self.ema_helper.load_state_dict(checkpoint['ema_helper'])
133
+ if not self.test:
134
+ self.scheduler.load_state_dict(checkpoint['scheduler'])
135
  if ema:
136
  self.ema_helper.ema(self.model)
137
  print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))