Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| from typing import Tuple | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from scipy import ndimage | |
| import torch | |
| from torchvision.transforms import functional as tvf | |
| from pathlib import Path | |
def sliced_mean(x, slice_size):
    """Return the mean of every ``slice_size`` x ``slice_size`` window of ``x``.

    Window sums are taken as differences of cumulative sums, first along
    axis 0 and then along axis 1.  Only windows that fit entirely inside
    ``x`` are evaluated, so the result has ``slice_size - 1`` fewer rows and
    columns than ``x``.

    Parameters
    ----------
    x : np.ndarray
        Two-dimensional input array.
    slice_size : int
        Side length of the square averaging window.

    Returns
    -------
    np.ndarray
        Array of window means.
    """
    # Prefix sums down the rows; the leading zero row lets a plain
    # difference of two prefix sums give the sum over a run of rows.
    row_prefix = np.pad(np.cumsum(x, axis=0), ((1, 0), (0, 0)))
    row_means = (row_prefix[slice_size:] - row_prefix[:-slice_size]) / slice_size

    # Same trick across the columns of the row-averaged data.
    col_prefix = np.pad(np.cumsum(row_means, axis=1), ((0, 0), (1, 0)))
    return (col_prefix[:, slice_size:] - col_prefix[:, :-slice_size]) / slice_size
def sliced_var(x, slice_size):
    """Return the variance of every ``slice_size`` x ``slice_size`` window of ``x``.

    The variance is computed as ``E[x**2] - E[x]**2`` using ``sliced_mean``
    for both expectations, so the output has the same shape as
    ``sliced_mean(x, slice_size)``.

    Parameters
    ----------
    x : np.ndarray
        Two-dimensional input array; it is converted to float64 first.
    slice_size : int
        Side length of the square window.

    Returns
    -------
    np.ndarray
        Non-negative array of window variances.
    """
    x = x.astype('float64')
    variance = sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
    # The difference of two nearly equal means can round to a tiny negative
    # number; a variance is never negative, so clamp those values to zero.
    return np.maximum(variance, 0.0)
def calculate_local_variance(img, var_window):
    """Return a local variance map with the same size as the input image.

    The variance of each ``var_window`` x ``var_window`` window is computed
    with ``sliced_var`` and written at the centre pixel of that window.
    Border pixels, where no full window is centred, are filled with zeros.

    Parameters
    ----------
    img : np.ndarray
        Two-dimensional input image.
    var_window : int
        Side length of the square window used for the variance.

    Returns
    -------
    np.ndarray
        Variance map with the same shape as ``img``.
    """
    var = sliced_var(img, var_window)
    # A window starting at index i is centred at i + (var_window - 1) / 2.
    # For even windows this floors to the lower of the two middle pixels
    # (the same value as var_window // 2 - 1); for odd windows it is the
    # exact centre, and it never becomes negative for var_window == 1.
    left_pad = (var_window - 1) // 2
    right_pad = var_window - 1 - left_pad
    var_padded = np.pad(
        var,
        pad_width=(
            (left_pad, right_pad),
            (left_pad, right_pad),
        ),
    )
    return var_padded
def get_crop_batch(img: np.ndarray, mask: np.ndarray, crop_size=96, crop_scales=np.geomspace(0.5, 2, 7), samples_per_scale=32, use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate a batch of cropped images from an input image and corresponding mask, at various scales and rotations.

    For every scale, crop centres are drawn at random (without replacement)
    from the pixels whose distance to the mask border is at least the scaled
    crop size.  Each crop is cut with an extra margin, rotated by a random
    angle in [-90, 90) degrees, trimmed back to the scaled size and finally
    resized to ``crop_size`` x ``crop_size``.

    Parameters
    ----------
    img : np.ndarray
        The input image from which crops are generated.
    mask : np.ndarray
        The binary mask indicating the region of interest in the image.
        Non-zero entries are treated as foreground; any numeric or boolean
        dtype is accepted.
    crop_size : int, optional
        The size of the square crop.
    crop_scales : np.ndarray, optional
        An array of scale factors to apply to the crop size.
    samples_per_scale : int, optional
        Number of samples to generate per scale factor.
    use_variance_threshold : bool, optional
        Flag to use variance thresholding for selecting crop locations.
        When set, only centres whose local variance is at least the median
        variance inside the mask are eligible.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing the tensor of crops, their rotation angles, and scale factors.
        If no crop could be generated at any scale, three empty lists are
        returned instead of tensors.
    """
    scales = list(crop_scales)
    if not scales:
        # Nothing to sample; avoids max() failing on an empty sequence.
        return [], [], []

    # pad so that the enlarged (pre-rotation) crops never leave the array
    pad_size = int(np.ceil(0.5 * crop_size * max(scales) * (np.sqrt(2) - 1)))
    img_padded = np.pad(img, pad_size)
    mask_padded = np.pad(mask, pad_size)

    # distance of every foreground pixel to the nearest background pixel
    distance_map_padded = ndimage.distance_transform_edt(mask_padded)

    # TODO: adjust scales and samples_per_scale
    if use_variance_threshold:
        variance_window = min(crop_size // 2, min(img.shape))
        variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size)
        # Median is taken only over pixels far enough inside the mask for the
        # variance window to lie within it.
        variance_median = np.ma.median(
            np.ma.masked_where(distance_map_padded < 0.5 * variance_window, variance_map_padded)
        )
        variance_mask = variance_map_padded >= variance_median
    else:
        # Must be boolean: the mask may be stored as float, and a float
        # array cannot be combined with `&` below.
        variance_mask = np.ones(mask_padded.shape, dtype=bool)

    # initialize output
    crops_granum = []
    angles_granum = []
    scales_granum = []

    # loop over scales
    for scale in scales:
        half_crop_size_scaled = int(np.floor(scale * 0.5 * crop_size))  # half of crop size after scaling
        if half_crop_size_scaled < 1:
            # A zero-sized crop gives crop_pad == 0, and slicing with
            # [0:-0] would silently yield empty arrays.
            continue
        crop_pad = int(np.ceil((np.sqrt(2) - 1) * half_crop_size_scaled))  # pad added in order to allow rotation
        half_crop_size_external = half_crop_size_scaled + crop_pad  # size of "external crop" which will be rotated

        possible_indices = np.stack(
            np.where(variance_mask & (distance_map_padded >= 2 * half_crop_size_scaled)),
            axis=1,
        )
        if len(possible_indices) == 0:
            continue
        chosen_indices = np.random.choice(
            np.arange(len(possible_indices)),
            min(len(possible_indices), samples_per_scale),
            replace=False,
        )
        crops = [
            img_padded[
                y - half_crop_size_external:y + half_crop_size_external,
                x - half_crop_size_external:x + half_crop_size_external,
            ]
            for y, x in possible_indices[chosen_indices]
        ]

        # rotate, then trim the margin that was added for the rotation
        rotation_angles = np.random.rand(len(crops)) * 180 - 90
        crops = [
            ndimage.rotate(crop, angle, reshape=False)[crop_pad:-crop_pad, crop_pad:-crop_pad]
            for crop, angle in zip(crops, rotation_angles)
        ]

        # add to output
        crops_granum.append(
            tvf.resize(torch.tensor(np.array(crops)), (crop_size, crop_size), antialias=True)
        )  # resize crops to crop_size
        angles_granum.extend(rotation_angles.tolist())
        scales_granum.extend([scale] * len(crops))

    if len(angles_granum) == 0:
        return [], [], []

    crops_granum = torch.concat(crops_granum)
    angles_granum = torch.tensor(angles_granum, dtype=torch.float)
    scales_granum = torch.tensor(scales_granum, dtype=torch.float)
    return crops_granum, angles_granum, scales_granum
def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False):
    """
    Load an image and its mask from file paths and generate a batch of cropped images.

    Parameters
    ----------
    img_path : str
        Path to the input image.
    mask_path : str, optional
        Path to the binary mask image. If None, assumes mask path by replacing image extension with '.npy'.
    use_variance_threshold : bool, optional
        Flag to use variance thresholding for selecting crop locations.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing the tensor of crops, their rotation angles, and scale factors, obtained from the specified image path.
    """
    if mask_path is None:
        mask_path = str(Path(img_path).with_suffix('.npy'))
    mask = np.load(mask_path)

    # The context manager closes the underlying file once the pixel data
    # has been copied into the array.
    with Image.open(img_path) as image_file:
        img = np.array(image_file)

    # Multi-channel images: keep the first channel only.  Single-channel
    # images already load as a 2-D array and are used as they are.
    if img.ndim == 3:
        img = img[:, :, 0]

    return get_crop_batch(img, mask, use_variance_threshold=use_variance_threshold)