0713-1459
Browse files- diffusion.py +1 -1
diffusion.py
CHANGED
|
@@ -532,7 +532,7 @@ class DDPM21CM:
|
|
| 532 |
self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 533 |
else:
|
| 534 |
self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 535 |
-
print(f"
|
| 536 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 537 |
# nn_model.train()
|
| 538 |
# self.nn_model.to(self.ddpm.device)
|
|
|
|
| 532 |
self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 533 |
else:
|
| 534 |
self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 535 |
+
print(f"device {torch.cuda.current_device()} resumed nn_model from {file}")
|
| 536 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 537 |
# nn_model.train()
|
| 538 |
# self.nn_model.to(self.ddpm.device)
|