Spaces:
Sleeping
Sleeping
import random | |
import torch | |
def load_patches(patch_batch_size, batch_size, patch_size, num_patch, diff_patch, index, data, transforms, return_dict):
    """Crop ``num_patch`` aligned random patches from an A/B image pair into *return_dict*.

    Does nothing when ``patch_size <= 0``. Otherwise it picks a source image
    pair, crops ``num_patch`` square windows at the same random offsets from
    both images, concatenates the crops along dim 0 and stores the results in
    ``return_dict['patch_A']`` and ``return_dict['patch_B']``.

    Args:
        patch_batch_size: Must be a multiple of ``batch_size``; only validated here.
        batch_size: Divisor checked against ``patch_batch_size``.
        patch_size: Side length of each square crop; ``<= 0`` disables patching.
        num_patch: Number of crops taken from each image.
        diff_patch: If true, crop from the next sample's images (loaded through
            ``transforms``) instead of the tensors already in ``return_dict``.
        index: Index of the current sample; ``(index + 1)`` wrapped by the
            length of the path list selects the other sample.
        data: Mapping holding ``'<kind>_A_path'`` / ``'<kind>_B_path'`` sequences,
            where ``<kind>`` is ``'paired'`` or ``'unpaired'``.
        transforms: Mapping with ``'paired'`` / ``'unpaired'`` callables that
            take ``(path_a, path_b)`` and return a pair of image tensors.
        return_dict: Sample dict. It is read for ``'<kind>_A'`` / ``'<kind>_B'``
            tensors and updated in place with the patch tensors.

    Raises:
        AssertionError: If ``patch_batch_size`` is not divisible by ``batch_size``.
        ValueError: If image A is smaller than ``patch_size`` in either
            spatial dimension.
    """
    if patch_size <= 0:
        return

    assert patch_batch_size % batch_size == 0, \
        "patch_batch_size is not divisible by batch_size."

    # Paired tensors, if present, decide which data/transform keys are used.
    if 'paired_A' in return_dict or 'paired_B' in return_dict:
        kind = 'paired'
    else:
        kind = 'unpaired'

    if not diff_patch:
        # Crop from the current image. No clone is needed because the
        # torch.cat calls below always allocate new tensors.
        image_a = return_dict[f'{kind}_A']
        image_b = return_dict[f'{kind}_B']
    else:
        # Crop from a different image: the next sample, wrapping around.
        paths_a = data[f'{kind}_A_path']
        paths_b = data[f'{kind}_B_path']
        image_a, image_b = transforms[kind](paths_a[(index + 1) % len(paths_a)],
                                            paths_b[(index + 1) % len(paths_b)])

    # Image tensors are expected to be 3-D; dims 1 and 2 are cropped below.
    _, h, w = image_a.size()
    if h < patch_size or w < patch_size:
        raise ValueError(
            f"patch_size ({patch_size}) is larger than the image ({h}x{w}).")

    # random.randint includes both bounds, so the largest valid top-left
    # offset is (dim - patch_size). This keeps the last row and column
    # reachable and also works when the image is exactly patch_size.
    max_r = h - patch_size
    max_c = w - patch_size

    patches_a = []
    patches_b = []
    for _ in range(num_patch):
        r = random.randint(0, max_r)
        c = random.randint(0, max_c)
        # Crop both images at the same offsets so A/B patches stay aligned.
        patches_a.append(image_a[:, r:r + patch_size, c:c + patch_size])
        patches_b.append(image_b[:, r:r + patch_size, c:c + patch_size])

    # Patches are concatenated along dim 0 rather than stacked into a new dim.
    return_dict['patch_A'] = torch.cat(patches_a, 0)
    return_dict['patch_B'] = torch.cat(patches_b, 0)