Update unet_wo_t/DDPM_Unet_wo_t_sample.py
Browse files
unet_wo_t/DDPM_Unet_wo_t_sample.py
CHANGED
@@ -909,17 +909,17 @@ save_each = 1
|
|
909 |
|
910 |
diffusion_model = diffusion_model.to(device)
|
911 |
|
912 |
-
last_trained_path = 'unet_wo_t
|
913 |
-
diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path))['model'])
|
914 |
|
915 |
-
sample_path = 'unet_wo_t/sample'
|
916 |
|
917 |
if not os.path.exists(sample_path):
|
918 |
os.mkdir(sample_path)
|
919 |
|
920 |
-
num_sample =
|
921 |
sample_batch = 16
|
922 |
-
count =
|
923 |
|
924 |
if num_sample % sample_batch != 0:
|
925 |
num_sample = num_sample + (sample_batch - (num_sample % sample_batch))
|
|
|
909 |
|
910 |
diffusion_model = diffusion_model.to(device)
|
911 |
|
912 |
+
last_trained_path = '/content/DDPM_ResNet_Unet/unet_wo_t/model/epoch_30.pth'
|
913 |
+
diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path), map_location=device)['model'])
|
914 |
|
915 |
+
sample_path = '/content/DDPM_ResNet_Unet/unet_wo_t/sample'
|
916 |
|
917 |
if not os.path.exists(sample_path):
|
918 |
os.mkdir(sample_path)
|
919 |
|
920 |
+
num_sample = 10000
|
921 |
sample_batch = 16
|
922 |
+
count = 0
|
923 |
|
924 |
if num_sample % sample_batch != 0:
|
925 |
num_sample = num_sample + (sample_batch - (num_sample % sample_batch))
|