0713-1700
Browse files- diffusion.py +4 -0
diffusion.py
CHANGED
|
@@ -631,9 +631,13 @@ def sample(rank, world_size, config, num_new_img, max_num_img_per_gpu, return_di
|
|
| 631 |
|
| 632 |
samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
|
| 633 |
|
|
|
|
|
|
|
| 634 |
if rank == 1:
|
| 635 |
return_dict['samples'] = samples
|
| 636 |
|
|
|
|
|
|
|
| 637 |
dist.destroy_process_group()
|
| 638 |
|
| 639 |
|
|
|
|
| 631 |
|
| 632 |
samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
|
| 633 |
|
| 634 |
+
print(f"device {torch.current_device()}, rank = {rank}, samples.shape = {samples.shape}")
|
| 635 |
+
|
| 636 |
if rank == 1:
|
| 637 |
return_dict['samples'] = samples
|
| 638 |
|
| 639 |
+
print(f"device {torch.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
|
| 640 |
+
|
| 641 |
dist.destroy_process_group()
|
| 642 |
|
| 643 |
|