Spaces:
Sleeping
Sleeping
File size: 1,951 Bytes
7e2a2a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import random
import torch
def load_patches(patch_batch_size, batch_size, patch_size, num_patch, diff_patch, index, data, transforms, return_dict):
    """Crop aligned random patches from an A/B image pair into ``return_dict``.

    Does nothing when ``patch_size`` is not positive. Otherwise ``num_patch``
    square windows of side ``patch_size`` are sampled at random; every window
    is cut from both images at the same location, and the crops are
    concatenated along dimension 0 and stored under ``return_dict['patch_A']``
    and ``return_dict['patch_B']``.

    The "paired" keys are used when ``return_dict`` contains ``'paired_A'`` or
    ``'paired_B'``; otherwise the "unpaired" keys are used.

    Args:
        patch_batch_size: Only checked for divisibility by ``batch_size``.
        batch_size: Divisor used in that check.
        patch_size: Side length of each square crop; ``<= 0`` disables
            patch loading entirely.
        num_patch: Number of crops to take.
        diff_patch: If false, crop from clones of the images already stored
            in ``return_dict``. If true, load the entry that follows
            ``index`` (wrapping around) through ``transforms``.
        index: Position of the current sample in the path lists of ``data``.
        data: Mapping holding ``'<kind>_A_path'`` / ``'<kind>_B_path'``
            sequences; read only when ``diff_patch`` is true.
        transforms: Mapping from ``'paired'`` / ``'unpaired'`` to a callable
            taking ``(pathA, pathB)`` and returning the two images; read only
            when ``diff_patch`` is true.
        return_dict: Mapping that is updated in place.

    Raises:
        AssertionError: If ``patch_batch_size`` is not a multiple of
            ``batch_size``.
        ValueError: If ``patch_size`` exceeds the image height or width.
    """
    if patch_size <= 0:
        return

    assert (patch_batch_size % batch_size == 0), \
        "patch_batch_size is not divisible by batch_size."

    if 'paired_A' in return_dict or 'paired_B' in return_dict:
        kind = 'paired'
    else:
        kind = 'unpaired'

    if not diff_patch:
        # Load patch from the current image; clone so the crops do not alias
        # the tensors already stored in return_dict.
        image_a = return_dict[kind + '_A'].clone()
        image_b = return_dict[kind + '_B'].clone()
    else:
        # Load patch from a different image: the next entry, wrapping around.
        paths_a = data[kind + '_A_path']
        paths_b = data[kind + '_B_path']
        path_a = paths_a[(index + 1) % len(paths_a)]
        path_b = paths_b[(index + 1) % len(paths_b)]
        image_a, image_b = transforms[kind](path_a, path_b)

    # Images are 3-D; presumably (channels, height, width) -- TODO confirm.
    # Only image A's size is inspected, so image B is assumed to match it.
    _, height, width = image_a.size()
    if patch_size > height or patch_size > width:
        raise ValueError(
            "patch_size (%d) exceeds image size (%d x %d)."
            % (patch_size, height, width)
        )

    # random.randint includes both endpoints, so the largest valid offset is
    # exactly (size - patch_size); this also permits a full-size patch.
    max_top = height - patch_size
    max_left = width - patch_size

    crops_a = []
    crops_b = []
    for _ in range(num_patch):
        top = random.randint(0, max_top)
        left = random.randint(0, max_left)
        # Same window for both images keeps the A/B crops spatially aligned.
        crops_a.append(image_a[:, top:top + patch_size, left:left + patch_size])
        crops_b.append(image_b[:, top:top + patch_size, left:left + patch_size])

    return_dict['patch_A'] = torch.cat(crops_a, 0)
    return_dict['patch_B'] = torch.cat(crops_b, 0)
|