Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from scipy import signal | |
| from scipy import ndimage | |
| from scipy.fftpack import next_fast_len | |
| from skimage.transform import rotate | |
| from skimage._shared.utils import convert_to_float | |
| from skimage.transform import warp | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| from copy import deepcopy | |
def get_directional_std(image, theta=None, *, preserve_range=False):
    """Score, for each angle, how much the image varies across columns after rotation.

    The input is converted to float on a copy, centre-cropped to a square and
    zeroed outside the circle inscribed in that square. For every angle in
    ``theta`` the image is warped with a rotation matrix about pixel
    ``(center, center)``. The score is the mean over columns of the per-column
    standard deviation, restricted to the square inscribed in the circle, so the
    zeroed corners are excluded.

    Parameters
    ----------
    image : 2-D array
        Input image. It is copied, so the caller's array is not modified.
    theta : array-like of float, optional
        Angles in degrees. Defaults to ``np.arange(180)``.
    preserve_range : bool, keyword-only
        Forwarded to skimage's ``convert_to_float``.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(theta),)``, one score per angle.

    Raises
    ------
    ValueError
        If ``image`` is not 2-D.
    """
    if image.ndim != 2:
        raise ValueError('The input image must be 2-D')
    angles = np.arange(180) if theta is None else theta
    work = convert_to_float(image.copy(), preserve_range)  # TODO: needed?

    # Centre-crop the longer axis so the image becomes square.
    side = min(work.shape)
    window = []
    for extra in np.array(work.shape) - side:
        if extra > 0:
            start = int(np.ceil(extra / 2))
            window.append(slice(start, start + side))
        else:
            window.append(slice(None))
    work = work[tuple(window)]

    side = min(work.shape)
    dims = np.array(work.shape)
    radius = side // 2
    # Object array holding the two open-grid coordinate arrays; subtracting the
    # centre and summing over axis 0 broadcasts them into squared distances.
    grid = np.array(np.ogrid[:work.shape[0], :work.shape[1]], dtype=object)
    sq_dist = ((grid - dims // 2) ** 2).sum(0)
    work[sq_dist > radius ** 2] = 0

    # Index range of the square inscribed in the circle of this radius.
    inset = slice(int(np.ceil(radius * (1 - 1 / np.sqrt(2)))),
                  int(np.ceil(radius * (1 + 1 / np.sqrt(2)))))
    if work.shape[0] != work.shape[1]:
        raise ValueError('padded_image must be a square')

    pivot = work.shape[0] // 2
    scores = np.zeros(len(angles))
    for k, phi in enumerate(np.deg2rad(angles)):
        c, s = np.cos(phi), np.sin(phi)
        # Homogeneous rotation matrix about (pivot, pivot), as used by skimage's radon.
        transform = np.array([[c, s, -pivot * (c + s - 1)],
                              [-s, c, -pivot * (c - s - 1)],
                              [0, 0, 1]])
        rotated = warp(work, transform, clip=False)
        scores[k] = rotated[inset, inset].std(axis=0).mean()
    return scores
def acf2d(x, nlags=None):
    """Autocorrelation along axis 0, computed independently for every column.

    Uses an FFT of the mean-removed signal, zero-padded to at least ``2*n + 1``
    to avoid circular wrap-around. Lag ``k`` is divided by the number of
    overlapping pairs ``n - k`` (adjusted estimator) and then normalised by lag 0.

    Parameters
    ----------
    x : array-like
        Samples along axis 0. 1-D input gives a 1-D result; N-D input treats
        every trailing index as a separate series.
    nlags : int, optional
        Largest lag returned. Defaults to ``n - 1``.

    Returns
    -------
    numpy.ndarray
        Shape ``(nlags + 1,) + x.shape[1:]``. A constant series gives NaN
        because its lag-0 autocovariance is zero.
    """
    x = np.asarray(x)
    nobs = x.shape[0]
    if nlags is None:
        nlags = nobs - 1
    xo = x - x.mean(axis=0)
    # Overlapping pairs per lag: nobs, nobs-1, ..., 1. Shaped to broadcast along
    # axis 0 only. The old (2n-1, 1) column vector made 1-D input blow up to (n, n).
    n_pairs = np.arange(nobs, 0, -1).reshape((-1,) + (1,) * (x.ndim - 1))
    nfft = next_fast_len(2 * nobs + 1)
    spectrum = np.fft.fft(xo, n=nfft, axis=0)
    acov = np.fft.ifft(spectrum * np.conjugate(spectrum), axis=0)[:nobs] / n_pairs
    acov = acov.real
    return acov[:nlags + 1] / acov[:1]
def get_period(acf_table, n_samples=50):
    """Estimate a period from the first autocorrelation peak of randomly sampled columns.

    ``min(n_columns, n_samples)`` column indices are drawn with
    ``np.random.randint``, so repeats are possible. For each drawn column the
    index of the first peak found by ``scipy.signal.find_peaks`` becomes a
    candidate. Columns without any peak are skipped.

    Returns
    -------
    (period, spread)
        ``(nan, nan)`` if no candidate was found. ``(candidate, nan)`` if
        exactly one was found. Otherwise the median of all candidates and the
        standard deviation of the candidates lying within the inter-quartile range.
    """
    # TODO: use peak heights to select best candidates. use std to eliminate outliers
    n_columns = acf_table.shape[1]
    sampled_columns = np.random.randint(0, n_columns, min(n_columns, n_samples))
    first_peaks = []
    for col in sampled_columns:
        peak_positions, _ = signal.find_peaks(acf_table[:, col])
        if peak_positions.size:
            first_peaks.append(peak_positions[0])
    first_peaks = np.array(first_peaks)

    if first_peaks.size == 0:
        return np.nan, np.nan
    if first_peaks.size == 1:
        return first_peaks[0], np.nan

    lo, hi = np.quantile(first_peaks, [0.25, 0.75])
    inner = first_peaks[(first_peaks >= lo) & (first_peaks <= hi)]
    return np.median(first_peaks), np.std(inner)
def get_rotation_with_confidence(padded_image, blur_size=4, make_plots=True):
    """Find the angle with the lowest directional std and a quality score for it.

    The image is box-blurred with ``cv2.blur`` using a ``blur_size`` x
    ``blur_size`` kernel and scored by ``get_directional_std`` with its default
    angles.

    Returns
    -------
    rotation_angle
        ``argmin`` index of the per-angle scores. With the default angles
        (``np.arange(180)``) this equals the angle in degrees.
    rotation_quality : float
        ``1 - min(scores) / median(scores)``.
    """
    kernel = (blur_size, blur_size)
    profile = get_directional_std(cv2.blur(padded_image, kernel))
    best_angle = np.argmin(profile)
    quality = 1 - np.min(profile) / np.median(profile)
    if make_plots:
        plt.plot(profile)
        plt.axvline(best_angle, c='k')
        plt.title(f'quality: {quality:0.2f}')
    return best_angle, quality
def calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=True):
    """Blur the transposed image and return its per-column autocorrelation.

    ``blur_kernel`` is passed as ``ksize`` to ``cv2.blur`` on ``oriented_img.T``.
    ``acf2d`` then works along axis 0 of the transposed image, which is
    axis 1 of ``oriented_img``.

    When ``make_plots`` is set, the autocorrelation table and its row sums are
    drawn side by side.
    """
    smoothed = cv2.blur(oriented_img.T, blur_kernel)
    acf_table = acf2d(smoothed)
    if make_plots:
        _, (ax_table, ax_profile) = plt.subplots(ncols=2, figsize=(12, 6))
        ax_table.imshow(acf_table)
        ax_profile.plot(acf_table.sum(axis=1))
    return acf_table
def get_period_with_confidence(autocorrelation_tab, n_samples=30):
    """Estimate the stripe period from an autocorrelation table plus a confidence score.

    The confidence is ``period / (period + 2 * spread)``, where ``spread`` is
    the robust std returned by ``get_period``. If the spread is unknown (NaN,
    i.e. fewer than two candidate periods were found), a fixed low confidence
    of 0.001 is used.

    Returns
    -------
    (period, period_confidence)
    """
    period, period_std = get_period(autocorrelation_tab, n_samples=n_samples)
    # ``period_std == np.nan`` is always False (NaN never compares equal), which
    # made the fallback unreachable. NaN has to be tested with np.isnan.
    if np.isnan(period_std):
        period_confidence = 0.001
    else:
        period_confidence = period / (period + 2 * period_std)
    return period, period_confidence
def calculate_white_fraction(img, blur_size=4, make_plots=True):  # TODO: add mask
    """Fraction of the column profile lying above a mid-level threshold.

    The image is box-blurred (``cv2.blur``, ``blur_size`` x ``blur_size``) and
    summed over axis 0 to get one value per column. Columns are labelled
    "white" when that value exceeds the midpoint of the 15th and 85th
    percentiles. The partial runs at both ends are discarded. The kept span
    starts at the first label change and, when needed, drops the last inner
    run so that it holds an even number of complete runs. The white fraction
    is the mean label over that span.

    Returns
    -------
    float
        The white fraction, or NaN when there are too few label changes.
    """
    blurred = cv2.blur(img, (blur_size, blur_size))
    blurred_sum = blurred.sum(axis=0)
    lower, upper = np.quantile(blurred_sum, [0.15, 0.85])
    sign = blurred_sum > (lower + upper) / 2
    # Index i is reported when sign[i] != sign[i + 1], i.e. i is the last
    # sample of a run and i + 1 starts the next run.
    sign_change_indices = np.where(sign[:-1] != sign[1:])[0]
    ends_match = bool(sign[-1] == sign[0])
    # Equal end labels mean an odd number of inner runs. One more change is
    # then needed, and the last inner run is dropped.
    if len(sign_change_indices) >= 2 + int(ends_match):
        cut_first = sign_change_indices[0] + 1
        # +1 so the slice includes the final sample of the last kept run.
        # Without it the fraction was biased by one sample.
        cut_last = sign_change_indices[-2 if ends_match else -1] + 1
        white_fraction = np.mean(sign[cut_first:cut_last])
    else:
        white_fraction = np.nan
        cut_first, cut_last = None, None

    if make_plots:
        fig, axs = plt.subplots(ncols=3, figsize=(16, 6))
        # Cast to float: in-place division of an integer sum (e.g. uint8 input) raises.
        blurred_sum_normalized = blurred_sum.astype(float)
        blurred_sum_normalized -= blurred_sum_normalized.min()
        peak = blurred_sum_normalized.max()
        if peak > 0:  # a flat profile would otherwise divide by zero
            blurred_sum_normalized /= peak
        axs[0].plot(blurred_sum_normalized)
        axs[0].plot(sign)
        axs[1].plot(blurred_sum_normalized[cut_first:cut_last])
        axs[1].plot(sign[cut_first:cut_last])
        axs[2].imshow(img, cmap='gray')
        # Drawn on the image panel (the current axes after plt.subplots).
        for i, idx in enumerate(sign_change_indices):
            axs[2].axvline(idx, c=['r', 'lime'][i % 2])
        fig.suptitle(f'fraction: {white_fraction:0.2f}')
    return white_fraction
def process_img_crop(img, nm_per_px=1, make_plots=False, return_extra=False):
    """Measure stripe direction, period and bright-band width in one square crop.

    Steps:
    1. Find the dominant direction.
    2. Rotate the crop by its negative and trim the margins.
    3. Autocorrelate the result, estimate the period and measure the white fraction.

    Parameters
    ----------
    img : 2-D array
        Square image crop.
    nm_per_px : float
        Scale applied to ``'period'`` and ``'lumen width'``.
    make_plots : bool
        Forwarded to the helper functions; also prints the period.
    return_extra : bool
        If set, adds ``'extra'`` with the autocorrelation table and the
        oriented image.

    Returns
    -------
    dict
        Keys ``'direction'``, ``'direction confidence'``, ``'period'``,
        ``'period confidence'`` and ``'lumen width'``, plus ``'extra'`` when
        requested. ``'lumen width'`` is ``white_fraction * period``.
    """
    assert img.shape[0] == img.shape[1], 'process_img_crop expects a square image'
    crop_size = img.shape[0]
    # find orientation
    rotation_angle, rotation_quality = get_rotation_with_confidence(img, blur_size=4, make_plots=make_plots)
    # Rotate and trim: (1 - 1/sqrt(2)) * size / 2 is the margin outside the
    # square inscribed in the circle of diameter ``crop_size``.
    crop_margin = int((1 - 1 / np.sqrt(2)) * crop_size * 0.5)
    # Explicit end index: ``[:-crop_margin]`` is ``[:-0]``, an EMPTY slice,
    # whenever crop_margin == 0 (crops smaller than 7 px).
    end = crop_size - crop_margin
    # NOTE(review): the top edge is trimmed by 2*crop_margin and the other
    # edges by crop_margin; kept as-is, confirm the asymmetry is intended.
    oriented_img = rotate(img, -rotation_angle)[2 * crop_margin:end, crop_margin:end]
    # calculate autocorrelation
    autocorrelation = calculate_autocorrelation(oriented_img, blur_kernel=(7, 1), make_plots=make_plots)
    # find period
    period, period_confidence = get_period_with_confidence(autocorrelation)
    if make_plots:
        print(f'period: {period}, confidence: {period_confidence}')
    # find white fraction
    white_fraction = calculate_white_fraction(oriented_img, make_plots=make_plots)
    white_width = white_fraction * period
    result = {
        'direction': rotation_angle,
        'direction confidence': rotation_quality,
        'period': period * nm_per_px,
        'period confidence': period_confidence,
        'lumen width': white_width * nm_per_px,
    }
    if return_extra:
        result['extra'] = {
            'autocorrelation': autocorrelation,
            'oriented_img': oriented_img,
        }
    return result
def get_top_k(a, k):
    """Return the ``k`` largest values of 1-D array ``a``, in no particular order.

    ``k`` is clamped to ``[0, len(a)]``. ``k <= 0`` returns an empty array and
    ``k >= len(a)`` returns every element, where ``np.argpartition`` would
    otherwise raise.
    """
    a = np.asarray(a)
    k = min(int(k), a.shape[0])
    if k <= 0:
        # Guard: np.argpartition(a, -0)[-0:] would return the WHOLE array.
        return a[:0]
    ind = np.argpartition(a, -k)[-k:]
    return a[ind]
def get_crops(img, distance_map, crop_size, N_sample):
    """Yield up to ``N_sample`` random ``crop_size`` x ``crop_size`` crops of ``img``.

    Candidate centres are pixels whose ``distance_map`` value is at least
    ``crop_r = sqrt(2) * crop_size / 2``, the half-diagonal of the crop. They
    must also lie more than ``crop_r`` from every image edge. Centres are drawn
    without replacement, and each crop is yielded as a copy.
    """
    crop_r = np.sqrt(2) * crop_size / 2
    ys, xs = np.where(distance_map >= crop_r)
    inside = ((ys > crop_r) & (xs > crop_r)
              & (ys < distance_map.shape[0] - crop_r)
              & (xs < distance_map.shape[1] - crop_r))
    ys, xs = ys[inside], xs[inside]
    n_available = len(xs)
    chosen = np.random.choice(np.arange(n_available), min(N_sample, n_available), replace=False)
    half = crop_size // 2
    for idx in chosen:
        top = ys[idx] - half
        left = xs[idx] - half
        # ``top + crop_size`` rather than ``y + half``: the latter produced
        # (crop_size - 1)-sized crops for odd crop_size. Even sizes are unchanged.
        yield img[top:top + crop_size, left:left + crop_size].copy()
def sliced_mean(x, slice_size):
    """Mean of every ``slice_size`` x ``slice_size`` window of a 2-D array.

    Uses prefix sums: a running window mean along axis 0, then along axis 1 of
    that result. Element ``[i, j]`` of the output is the mean of
    ``x[i:i + slice_size, j:j + slice_size]``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(H - slice_size + 1, W - slice_size + 1)``.
    """
    def _window_mean(arr, axis):
        csum = np.cumsum(arr, axis=axis)
        # Prepend a zero row/column (same dtype) so window sums are differences.
        pad = [(0, 0)] * csum.ndim
        pad[axis] = (1, 0)
        prefix = np.pad(csum, pad)
        ahead = [slice(None)] * prefix.ndim
        behind = [slice(None)] * prefix.ndim
        ahead[axis] = slice(slice_size, None)
        behind[axis] = slice(None, -slice_size)
        return (prefix[tuple(ahead)] - prefix[tuple(behind)]) / slice_size

    return _window_mean(_window_mean(x, 0), 1)
def sliced_var(x, slice_size):
    """Population variance of every ``slice_size`` x ``slice_size`` window of ``x``.

    Computed in float64 as ``E[x**2] - E[x]**2`` with ``sliced_mean``. The
    output has the same shape as ``sliced_mean``'s.
    """
    x = x.astype('float64')
    variance = sliced_mean(x ** 2, slice_size) - sliced_mean(x, slice_size) ** 2
    # The E[x^2] - E[x]^2 form loses precision through cancellation and can
    # dip slightly below zero on flat windows. A variance is never negative,
    # and negatives break fractional powers used as sampling weights.
    return np.maximum(variance, 0.0)
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=1.0, variance_p=2):
    """Draw random square crops whose mask coverage is at least ``granum_fraction_min``.

    Candidate top-left corners are windows where ``sliced_mean(granum_mask)``
    reaches ``granum_fraction_min``. Candidates are drawn without replacement.
    Each is weighted by ``window_variance ** variance_p``, or uniformly when
    ``variance_p == 0`` or when every weight is zero or non-finite.

    Returns
    -------
    numpy.ndarray
        Stacked crops of shape ``(k, crop_size, crop_size)``. ``k`` is
        ``min(n_candidates, n_samples, n_nonzero_weights)``; ``k == 0`` gives an
        empty array of that shape.
    """
    granum_occupancy = sliced_mean(granum_mask, crop_size)
    possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
    n_possible = len(possible_indices)
    empty = np.empty((0, crop_size, crop_size), dtype=granum_image.dtype)
    if n_possible == 0:
        # np.random.choice rejects an empty population with probabilities.
        return empty
    if variance_p == 0:
        p = np.ones(n_possible)
    else:
        variance_map = sliced_var(granum_image, crop_size)
        # Clip round-off negatives so fractional powers cannot produce NaN.
        variances = np.clip(variance_map[possible_indices[:, 0], possible_indices[:, 1]], 0, None)
        p = variances ** variance_p
    total = np.sum(p)
    if not (np.isfinite(total) and total > 0):
        # Every candidate is flat (or weights are invalid): fall back to uniform.
        p = np.ones(n_possible)
        total = float(n_possible)
    p = p / total
    # replace=False raises if fewer nonzero weights exist than draws requested.
    n_draw = min(n_possible, n_samples, np.count_nonzero(p))
    chosen_indices = np.random.choice(
        np.arange(n_possible),
        n_draw,
        replace=False,
        p=p,
    )
    crops = [
        granum_image[row:row + crop_size, col:col + crop_size]
        for row, col in possible_indices[chosen_indices]
    ]
    return np.array(crops) if crops else empty
def calculate_distance_map(mask):
    """Euclidean distance from each mask pixel to the nearest background pixel.

    The mask is framed with a one-pixel ``False`` border before
    ``ndimage.distance_transform_edt``, so pixels on the image edge count as
    adjacent to background. The border is removed from the result, which has
    the same shape as ``mask``.
    """
    border = 1
    framed = np.pad(mask, pad_width=border, mode='constant', constant_values=False)
    distances = ndimage.distance_transform_edt(framed)
    return distances[border:-border, border:-border]
def measure_object(
        img, mask,
        nm_per_px=1, n_tries=3,
        direction_thr_min=0.07, direction_thr_enough=0.1,
        crop_size=200,
        **kwargs):
    """Measure stripe properties of one masked object from its best-oriented crop.

    Up to ``n_tries`` random crops inside ``mask`` go through
    ``process_img_crop``. The crop with the highest ``'direction confidence'``
    is kept, and the search stops early once confidence exceeds
    ``direction_thr_enough``. If the best confidence reaches
    ``direction_thr_min``, the mask and image are rotated by
    ``90 - direction`` (``resize=True``). They are then cropped to the mask's
    bounding box and stored as ``'mask_oriented'`` and ``'img_oriented'``.

    Parameters
    ----------
    crop_size : int
        Upper bound on the crop side. It is further capped by the
        ``n_tries``-th largest mask distance times ``1/sqrt(2)``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    dict
        The best ``process_img_crop`` result (plus oriented arrays when
        confident), or ``{}`` if no crop could be taken.
    """
    distance_map = calculate_distance_map(mask)
    crop_size = min(crop_size, int(min(get_top_k(distance_map.flatten(), n_tries) * 0.5 ** 0.5)))
    direction_confidence = 0
    best_stripes_data = {}
    for img_crop in get_crops(img, distance_map, crop_size, n_tries):
        stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
        if stripes_data['direction confidence'] >= direction_confidence:
            best_stripes_data = deepcopy(stripes_data)
            direction_confidence = stripes_data['direction confidence']
        if direction_confidence > direction_thr_enough:
            break
    result = best_stripes_data
    # ``result`` is empty when no crop was processed. Guard so that
    # direction_thr_min <= 0 cannot hit a KeyError on 'direction'.
    if result and direction_confidence >= direction_thr_min:
        angle = 90 - result['direction']
        mask_oriented = rotate(mask, angle, resize=True).astype('bool')
        cols = np.flatnonzero(np.any(mask_oriented, axis=0))
        rows = np.flatnonzero(np.any(mask_oriented, axis=1))
        # +1: the last occupied row/column is inclusive, slice ends are not.
        # Previously the bounding box lost one row and one column of the mask.
        bbox = (slice(rows[0], rows[-1] + 1), slice(cols[0], cols[-1] + 1))
        result['mask_oriented'] = mask_oriented[bbox]
        result['img_oriented'] = rotate(img, angle, resize=True)[bbox]
    # TODO: granum shape measurements (measure_granum_shape) and the layer
    # count (height / period) were disabled here.
    return result
| # def measure_object( | |
| # img, mask, | |
| # nm_per_px=1, n_tries = 3, | |
| # direction_thr_min = 0.07, direction_thr_enough = 0.1, | |
| # crop_size = 200, | |
| # **kwargs): | |
| # distance_map = calculate_distance_map(mask) | |
| # crop_size = min(crop_size, int((min(get_top_k(distance_map.flatten(), n_tries)*0.5)**0.5))) | |
| # direction_confidence = 0 | |
| # best_stripes_data = {} | |
| # for i, img_crop in enumerate(select_samples(img, mask, crop_size=crop_size, n_samples=n_tries)): | |
| # stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px) | |
| # if stripes_data['direction_confidence'] >= direction_confidence: | |
| # best_stripes_data = deepcopy(stripes_data) | |
| # direction_confidence = stripes_data['direction_confidence'] | |
| # if direction_confidence > direction_thr_enough: | |
| # break | |
| # result = best_stripes_data | |
| # if direction_confidence >= direction_thr_min: | |
| # mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool') | |
| # idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])] | |
| # idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])] | |
| # result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x] | |
| # result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x] | |
| # # measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True) | |
| # # else: | |
| # # measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False) | |
| # # result.update(**measurements) | |
| # # N_layers = result['height'] / result['period'] | |
| # # if np.isfinite(N_layers): | |
| # # N_layers = round(N_layers) | |
| # return result #{**measurements, **best_stripes_data, 'N layers': N_layers} |