|
import random |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
def crop_4_patches(image):
    """Split *image* into four equally sized quadrants.

    The image only needs to expose a ``size`` attribute of the form
    ``(width, height)`` and a ``crop((left, upper, right, lower))`` method
    (the PIL ``Image`` interface).

    Width and height are halved independently, so non-square images are
    split into proper quadrants instead of reusing the half-width as the
    vertical extent. For square images the result is unchanged. When a
    dimension is odd, the last row/column is dropped (floor division).

    Returns:
        A 4-tuple ordered (top-left, bottom-left, top-right, bottom-right).
    """
    width, height = image.size[0], image.size[1]
    half_w = width // 2
    half_h = height // 2

    # Box layout follows the (left, upper, right, lower) convention.
    boxes = (
        (0, 0, half_w, half_h),                    # top-left
        (0, half_h, half_w, 2 * half_h),           # bottom-left
        (half_w, 0, 2 * half_w, half_h),           # top-right
        (half_w, half_h, 2 * half_w, 2 * half_h),  # bottom-right
    )
    return tuple(image.crop(box) for box in boxes)
|
|
|
|
|
def pre_processing(image, transform):
    """Build a three-level pyramid of transformed patches from *image*.

    The image is split into quadrants with ``crop_4_patches``; every
    quadrant is then split again. At each of the three following levels
    the first and last sub-patch (indices 0 and 3) are passed through
    *transform* and collected, while the two remaining sub-patches
    (indices 1 and 2) are subdivided further for the next level.

    Args:
        image: Image accepted by ``crop_4_patches``.
        transform: Callable applied to every collected patch. Its outputs
            are combined with ``torch.stack``, so they are expected to be
            tensors of identical shape.

    Returns:
        A tuple ``(high_level, middle_level, low_level)`` of stacked
        tensors holding 8, 16 and 32 transformed patches respectively.
    """
    coarse, medium, fine = [], [], []

    for quadrant in crop_4_patches(image):
        q_first, q_second, q_third, q_last = crop_4_patches(quadrant)
        for patch in (q_first, q_last):
            coarse.append(transform(patch))

        for mid_source in (q_second, q_third):
            m_first, m_second, m_third, m_last = crop_4_patches(mid_source)
            for patch in (m_first, m_last):
                medium.append(transform(patch))

            for fine_source in (m_second, m_third):
                fine_parts = crop_4_patches(fine_source)
                for patch in (fine_parts[0], fine_parts[3]):
                    fine.append(transform(patch))

    return tuple(torch.stack(level) for level in (coarse, medium, fine))
|
|