Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import pytorch_lightning as pl | |
| import timm | |
| # from hydra.utils import instantiate | |
| from scipy.stats import circmean, circstd | |
| from scipy import ndimage | |
| from skimage.transform import resize | |
| from sampling import get_crop_batch | |
| from granum_utils import get_circle_mask | |
| import image_transforms | |
| from envelope_correction import calculate_best_angle_from_mask | |
| ## loss | |
class ConfidenceScaler:
    """Map a raw score to its empirical CDF value over a reference sample.

    The reference ``data`` is stored sorted. Calling the scaler with ``x``
    returns ``np.searchsorted(data, x) / len(data)``, i.e. the fraction of
    reference values strictly smaller than ``x`` (``side='left'``), a value
    in ``[0, 1]``.
    """

    def __init__(self, data: np.ndarray):
        # np.sort returns a sorted copy, so the caller's array is not
        # reordered as a side effect (an in-place ``.sort()`` would do that).
        self.data = np.sort(np.asarray(data))

    def __call__(self, x):
        """Return the fraction of reference values below ``x`` (scalar or array)."""
        return np.searchsorted(self.data, x) / len(self.data)
class PatchedPredictor:
    """Estimate the dominant strip angle of an image by test-time augmentation.

    Crops are sampled from ``img``/``mask`` at several scales and rotations by
    ``get_crop_batch``. The model predicts a direction per crop, the crop
    rotation is subtracted, and the per-crop estimates are aggregated with
    circular statistics over a 180-degree period. The circular standard
    deviation is mapped to a confidence in (0, 1].
    """

    def __init__(self,
                 model,
                 crop_size=96,
                 normalization=dict(mean=0, std=1),
                 n_samples=32,
                 mask=None,  # 'circle', None
                 filter_outliers=True,
                 apply_radon=False,  # apply Radon transform
                 radon_size=(128, 128),  # (int, int) reshape radon transformed image to this shape
                 angle_confidence_threshold=0,
                 use_envelope_correction=True
                 ):
        """Store the inference configuration.

        Args:
            model: callable taking the preprocessed batch (crops with a
                channel axis inserted at dim 1) and returning a 3-tuple
                ``(directions, periods, lumen_widths)``; only the directions
                are used.
            crop_size: crop side length passed to ``get_crop_batch``.
            normalization: dict with ``mean`` and ``std``, applied after
                dividing pixel values by 255. It is only read, never modified.
            n_samples: crops per scale (``samples_per_scale``).
            mask: ``None`` or ``'circle'``. With ``'circle'`` the pixels
                selected by ``get_circle_mask`` are zeroed in every crop.
            filter_outliers: stored on the instance.
                NOTE(review): ``__call__`` does not currently use it, nor
                ``_filter_outliers``.
            apply_radon: replace each crop with its resized Radon transform.
            radon_size: ``(H, W)`` the Radon images are resized to.
            angle_confidence_threshold: estimates with a lower confidence are
                reported as NaN with confidence 0.
            use_envelope_correction: refine the estimate with
                ``calculate_best_angle_from_mask`` applied to the rotated mask.

        Raises:
            ValueError: if ``mask`` is neither ``None`` nor ``'circle'``.
        """
        self.model = model
        self.crop_size = crop_size
        self.normalization = normalization
        self.n_samples = n_samples
        if mask not in [None, 'circle']:
            raise ValueError(f'unknown mask {mask}')
        self.mask = mask
        self.filter_outliers = filter_outliers
        self.apply_radon = apply_radon
        self.radon_size = radon_size
        self.angle_confidence_threshold = angle_confidence_threshold
        self.use_envelope_correction = use_envelope_correction

    def __call__(self, img: np.ndarray, mask: np.ndarray):
        """Estimate the strip angle of ``img`` inside ``mask``.

        Returns:
            dict with ``est_angle`` (degrees) and ``est_angle_confidence``.
            ``est_angle`` is NaN with confidence 0 when no crops could be
            sampled or when the confidence falls below
            ``angle_confidence_threshold``.
        """
        # Fixed global seed so the random crop sampling is reproducible.
        pl.seed_everything(44)
        # Crops at different scales and rotations; the per-crop scales are not used.
        crops, angles_tta, _scales_tta = get_crop_batch(
            img, mask,
            crop_size=self.crop_size,
            samples_per_scale=self.n_samples,
            use_variance_threshold=True,
        )
        if len(crops) == 0:
            return dict(
                est_angle=np.nan,
                est_angle_confidence=0.,
            )

        # Normalize, mask and optionally Radon-transform the crops.
        batch = self._preprocess_batch(crops)

        # Pure inference. Without no_grad the outputs would carry an autograd
        # graph (wasted memory), and the implicit NumPy conversion inside
        # circmean/circstd raises on tensors that require grad.
        with torch.no_grad():
            # The period and lumen-width predictions are not used here.
            preds_direction, _preds_period, _preds_lumen_width = self.model(batch)

        # Undo each crop's rotation and fold into [0, 180).
        est_angles = (preds_direction - angles_tta) % 180
        # Circular statistics with a 180-degree period. circmean returns a
        # value in [-90, 90); the +90 shifts the estimate by 90 degrees.
        # NOTE(review): confirm this offset is the intended angle convention.
        est_angle = circmean(est_angles, low=-90, high=90) + 90
        est_angle_std = circstd(est_angles, low=-90, high=90)
        # Confidence is 0.5 at std = 10 degrees.
        est_angle_confidence = self._std_to_confidence(est_angle_std, 10)

        if est_angle_confidence < self.angle_confidence_threshold:
            est_angle = np.nan
            est_angle_confidence = 0.

        if self.use_envelope_correction and (not np.isnan(est_angle)):
            # Rotate the mask by the estimate (nearest neighbour, enlarged
            # canvas) and add the negated residual angle found on it.
            angle_correction = -calculate_best_angle_from_mask(
                ndimage.rotate(mask, -est_angle, reshape=True, order=0)
            )
            est_angle += angle_correction

        return dict(
            est_angle=est_angle,
            est_angle_confidence=est_angle_confidence,
        )

    def _apply_radon(self, batch):  # may require circle mask
        """Return the Radon transform of every crop, resized to ``radon_size``.

        ``batch`` is a tensor (converted with ``.numpy()``); the result is a
        new tensor built from the resized NumPy array.
        """
        crops_radon = image_transforms.batched_radon(batch.numpy())
        # resize keeps trailing axes not covered by radon_size, so the crop
        # axis is moved last for the resize and moved back afterwards.
        crops_radon = np.transpose(resize(np.transpose(crops_radon, (1, 2, 0)), self.radon_size), (2, 0, 1))
        return torch.tensor(crops_radon)

    def _preprocess_batch(self, batch):
        """Mask, optionally Radon-transform, and normalize a batch of crops.

        With ``mask='circle'`` the masked pixels are zeroed in place in
        ``batch``. Returns the normalized batch with a channel axis inserted
        at dim 1.
        """
        if self.mask == 'circle':
            # NOTE(review): assumes square crops; the mask size comes from dim 1 only.
            circle = get_circle_mask(batch.shape[1])
            batch[:, circle] = 0
        if self.apply_radon:
            batch = self._apply_radon(batch)
        batch = ((batch / 255) - self.normalization['mean']) / self.normalization['std']
        return batch.unsqueeze(1)  # add channel dimension

    def _filter_outliers(self, x, qmin=0.25, qmax=0.75):
        """Keep only the elements of ``x`` between its ``qmin`` and ``qmax`` quantiles (inclusive)."""
        x_min, x_max = np.quantile(x, [qmin, qmax])
        return x[(x >= x_min) & (x <= x_max)]

    def _std_to_confidence(self, x, x_thr, y_thr=0.5):
        """Map ``x`` in [0, inf) to (0, 1], decreasing, with f(0)=1 and f(x_thr)=y_thr."""
        return 1 / (1 + x * (1 - y_thr) / (x_thr * y_thr))
class CosineLoss(torch.nn.Module):
    """Angular loss ``scale * mean((1 - cos(x - y)) ** p)``.

    Args:
        p: exponent applied to every ``1 - cos`` term.
        degrees: if True, ``x`` and ``y`` are in degrees and are converted to
            radians first; otherwise they are taken as radians.
        scale: constant multiplier of the final mean.
    """

    def __init__(self, p=1, degrees=False, scale=1):
        super().__init__()
        self.p = p
        self.degrees = degrees
        self.scale = scale

    def forward(self, x, y):
        """Return the scaled mean of ``(1 - cos(x - y)) ** p`` over all elements."""
        delta = torch.deg2rad(x) - torch.deg2rad(y) if self.degrees else x - y
        per_element = (1 - torch.cos(delta)) ** self.p
        return torch.mean(per_element) * self.scale
| ## model | |
class AngleParser2d(torch.nn.Module):
    """Decode the first two raw network outputs into a direction angle in degrees.

    Columns 0 and 1 are squashed with ``sigmoid(.) - 0.5`` into y- and
    x-projections. Their ``atan2``, in degrees, is rescaled by
    ``angle_range / 360``, so the result lies within ``±angle_range / 2``.
    """

    def __init__(self, angle_range=180):
        super().__init__()
        self.angle_range = angle_range

    def forward(self, batch):
        """Return one direction per row; columns 0/1 of ``batch`` hold the y/x logits."""
        y_logit, x_logit = batch[:, 0], batch[:, 1]
        y_proj = torch.sigmoid(y_logit) - 0.5
        x_proj = torch.sigmoid(x_logit) - 0.5
        full_turn_degrees = torch.rad2deg(torch.arctan2(y_proj, x_proj))
        return full_turn_degrees * (self.angle_range / 360.)
class AngleRegularizer(torch.nn.Module):
    """Penalty ``strength * || norm(row_i) - scale ||_p`` on the row norms of a batch.

    Each row of ``batch`` is reduced to its Euclidean norm over dim 1. The
    deviations of these norms from ``scale`` are combined with a p-norm.
    """

    def __init__(self, strength=1.0, scale=1.0, p=2):
        super().__init__()
        self.strength = strength
        self.scale = scale
        self.p = p

    def forward(self, batch):
        """Return the regularization value for ``batch`` (norms taken over dim 1)."""
        deviations = torch.linalg.norm(batch, dim=1) - self.scale
        return torch.norm(deviations, p=self.p) * self.strength
class AngleRegularizerLog(torch.nn.Module):
    """Penalty ``strength * || log(norm(row_i) / scale) ||_p`` on the row norms of a batch.

    Same structure as ``AngleRegularizer``, but the deviation of each row norm
    from ``scale`` is measured on a log scale. A zero-norm row gives
    ``log(0) = -inf``.
    """

    def __init__(self, strength=1.0, scale=1.0, p=2):
        super().__init__()
        self.strength = strength
        self.scale = scale
        self.p = p

    def forward(self, batch):
        """Return the log-scale regularization value for ``batch`` (norms over dim 1)."""
        log_ratio = torch.log(torch.linalg.norm(batch, dim=1) / self.scale)
        return torch.norm(log_ratio, p=self.p) * self.strength
class StripsModel(pl.LightningModule):
    """Lightning module regressing strip direction, period and lumen fraction.

    The backbone is ``timm.create_model(model_name, in_chans=1, num_classes=4)``.
    Its outputs are decoded as follows:

    * columns 0-1 -> direction, via ``AngleParser2d``;
    * column -2 -> period;
    * column -1 -> lumen fraction, via ``sigmoid(x * sigmoid_smoother)``.
    """

    def __init__(self,
                 model_name='resnet18',
                 lr=0.001,
                 optimizer_hparams=dict(),
                 lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
                 loss_hparams=dict(rotation_weight=10., lumen_fraction_weight=50.),
                 angle_hparams=dict(angle_range=180.),
                 regularizer_hparams=None,
                 sigmoid_smoother=10.
                 ):
        """
        Args:
            model_name: ``timm`` architecture name.
            lr: AdamW learning rate.
            optimizer_hparams: extra keyword arguments for ``torch.optim.AdamW``.
            lr_hparams: scheduler config; see ``configure_optimizers``.
            loss_hparams: ``rotation_weight`` and ``lumen_fraction_weight``
                (required); ``regularization_weight`` (optional, default 0).
            angle_hparams: keyword arguments for ``AngleParser2d``.
            regularizer_hparams: config passed to ``instantiate``, or ``None``
                for no regularizer.
            sigmoid_smoother: multiplier applied to the lumen-fraction logit
                before the sigmoid.
        """
        super().__init__()
        # Store the init arguments in ``self.hparams`` (Lightning also saves them with checkpoints).
        self.save_hyperparameters()
        self.model = timm.create_model(model_name, in_chans=1, num_classes=4)
        self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
        self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
        # Plain dict: CosineLoss holds no parameters, so it does not need to be a registered submodule.
        self.losses = {
            'direction': CosineLoss(2., True),  # p=2, inputs in degrees
            'period': torch.nn.functional.mse_loss,
            'lumen_fraction': torch.nn.functional.mse_loss,
        }
        self.losses_weights = {
            'direction': self.hparams.loss_hparams['rotation_weight'],
            'period': 1,
            'lumen_fraction': self.hparams.loss_hparams['lumen_fraction_weight'],
            'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.),
        }

    def _get_regularizer(self, regularizer_params):
        """Build the regularizer from its config, or return ``None`` if no config is given."""
        if regularizer_params is None:
            return None
        # NOTE(review): `instantiate` (hydra.utils) is not imported in this
        # file (its import is commented out at the top), so this line raises
        # NameError unless that import is restored.
        return instantiate(regularizer_params)

    def forward(self, x, return_raw=False):
        """Get predictions from an image batch.

        Returns:
            ``(direction, period, lumen_fraction)``. If ``return_raw`` is set,
            the raw backbone output is appended as a fourth element.
        """
        # Raw outputs: logit y-projection, logit x-projection, period, logit lumen fraction.
        preds = self.model(x)
        preds_direction = self.angle_parser(preds)
        preds_period = preds[:, -2]
        # The lumen fraction lies in (0, 1), hence the (sharpened) sigmoid.
        preds_lumen_fraction = torch.sigmoid(preds[:, -1] * self.hparams.sigmoid_smoother)
        outputs = [preds_direction, preds_period, preds_lumen_fraction]
        if return_raw:
            outputs.append(preds)
        return tuple(outputs)

    def configure_optimizers(self):
        """Create an AdamW optimizer and the learning-rate scheduler described by ``lr_hparams``.

        ``lr_hparams`` may be either

        * a Hydra-style config containing ``_target_``, which is instantiated
          as a partial and called with the optimizer; or
        * ``{'classname': <name in torch.optim.lr_scheduler>, 'kwargs': {...}}``,
          the format of the default value.
        """
        # AdamW is Adam with a correct implementation of weight decay (see here
        # for details: https://arxiv.org/pdf/1711.05101.pdf)
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
        lr_hparams = self.hparams.lr_hparams
        if '_target_' in lr_hparams:
            # NOTE(review): requires `instantiate` from hydra.utils, whose
            # import is commented out at the top of this file.
            scheduler = instantiate({**lr_hparams, '_partial_': True})(optimizer)
        else:
            # The default config has no `_target_` for instantiate to build
            # from, so resolve the scheduler class by name instead.
            scheduler_cls = getattr(torch.optim.lr_scheduler, lr_hparams['classname'])
            scheduler = scheduler_cls(optimizer, **lr_hparams.get('kwargs', {}))
        return [optimizer], [scheduler]

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to the CPU and return it as a NumPy array."""
        return tensor.detach().cpu().numpy()

    def process_batch_supervised(self, batch):
        """Get predictions, losses and mean absolute errors (MAE) for a labelled batch.

        ``batch`` must provide ``image``, ``direction``, ``period`` and
        ``lumen_fraction`` entries.

        Returns:
            ``(preds, losses, mae)`` dicts. ``losses['final']`` is the weighted
            sum used for optimization.
        """
        preds = {}
        preds['direction'], preds['period'], preds['lumen_fraction'], preds_raw = \
            self.forward(batch['image'], return_raw=True)

        # Directions are doubled so the cosine loss has a 180-degree period.
        losses = {
            'direction': self.losses['direction'](2 * batch['direction'], 2 * preds['direction']),
            'period': self.losses['period'](batch['period'], preds['period']),
            'lumen_fraction': self.losses['lumen_fraction'](batch['lumen_fraction'], preds['lumen_fraction']),
        }
        if self.regularizer is not None:
            # Only the two direction logits are regularized.
            losses['regularization'] = self.regularizer(preds_raw[:, :2])
        losses['final'] = (
            losses['direction'] * self.losses_weights['direction']
            + losses['period'] * self.losses_weights['period']
            + losses['lumen_fraction'] * self.losses_weights['lumen_fraction']
            + losses.get('regularization', 0.) * self.losses_weights.get('regularization', 0.)
        )

        to_np = self._to_numpy
        period_difference = np.mean(abs(to_np(batch['period']) - to_np(preds['period'])))
        a1 = to_np(batch['direction'])
        a2 = to_np(preds['direction'])
        # Angular difference modulo 180 degrees, in [0, 90].
        angle_difference = np.mean(0.5 * np.degrees(np.arccos(np.cos(2 * np.radians(a2 - a1)))))
        lumen_fraction_difference = np.mean(abs(to_np(preds['lumen_fraction']) - to_np(batch['lumen_fraction'])))
        mae = {
            'period': period_difference,
            'direction': angle_difference,
            'lumen_fraction': lumen_fraction_difference,
        }
        return preds, losses, mae

    def log_all(self, losses, mae, prefix=''):
        """Log all losses and MAE metrics, each name prefixed with ``prefix``."""
        self.log(f"{prefix}angle_loss", losses['direction'].item())
        self.log(f"{prefix}period_loss", losses['period'].item())
        self.log(f"{prefix}lumen_fraction_loss", losses['lumen_fraction'].item())
        self.log(f"{prefix}period_difference", mae['period'])
        self.log(f"{prefix}angle_difference", mae['direction'])
        self.log(f"{prefix}lumen_fraction_difference", mae['lumen_fraction'])
        self.log(f"{prefix}loss", losses['final'])
        if 'regularization' in losses:
            self.log(f"{prefix}regularization_loss", losses['regularization'].item())

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the weighted total loss."""
        # "batch" is the output of the training data loader.
        preds, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='train_')
        return losses['final']

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses and errors."""
        preds, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='val_')

    def test_step(self, batch, batch_idx):
        """Compute and log the test losses and errors."""
        preds, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='test_')
class StripsModelLumenWidth(pl.LightningModule):
    """Lightning module regressing strip direction, period and lumen width.

    The backbone is ``timm.create_model(model_name, in_chans=1, num_classes=4)``.
    Its outputs are decoded as follows:

    * columns 0-1 -> direction, via ``AngleParser2d``;
    * column -2 -> period;
    * column -1 -> lumen width, returned unchanged.

    A lumen fraction (width / period) is derived only for reporting.
    """

    def __init__(self,
                 model_name='resnet18',
                 lr=0.001,
                 optimizer_hparams=dict(),
                 lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
                 loss_hparams=dict(rotation_weight=10., lumen_width_weight=50.),
                 angle_hparams=dict(angle_range=180.),
                 regularizer_hparams=None,
                 sigmoid_smoother=10.
                 ):
        """
        Args:
            model_name: ``timm`` architecture name.
            lr: AdamW learning rate.
            optimizer_hparams: extra keyword arguments for ``torch.optim.AdamW``.
            lr_hparams: scheduler config; see ``configure_optimizers``.
            loss_hparams: ``rotation_weight`` and ``lumen_width_weight``
                (required); ``regularization_weight`` (optional, default 0).
            angle_hparams: keyword arguments for ``AngleParser2d``.
            regularizer_hparams: config passed to ``instantiate``, or ``None``
                for no regularizer.
            sigmoid_smoother: saved in ``hparams`` only; this class applies no
                sigmoid to its outputs.
        """
        super().__init__()
        # Store the init arguments in ``self.hparams`` (Lightning also saves them with checkpoints).
        self.save_hyperparameters()
        self.model = timm.create_model(model_name, in_chans=1, num_classes=4)
        self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
        self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
        # Plain dict: CosineLoss holds no parameters, so it does not need to be a registered submodule.
        self.losses = {
            'direction': CosineLoss(2., True),  # p=2, inputs in degrees
            'period': torch.nn.functional.mse_loss,
            'lumen_width': torch.nn.functional.mse_loss,
        }
        self.losses_weights = {
            'direction': self.hparams.loss_hparams['rotation_weight'],
            'period': 1,
            'lumen_width': self.hparams.loss_hparams['lumen_width_weight'],
            'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.),
        }

    def _get_regularizer(self, regularizer_params):
        """Build the regularizer from its config, or return ``None`` if no config is given."""
        if regularizer_params is None:
            return None
        # NOTE(review): `instantiate` (hydra.utils) is not imported in this
        # file (its import is commented out at the top), so this line raises
        # NameError unless that import is restored.
        return instantiate(regularizer_params)

    def forward(self, x, return_raw=False):
        """Get predictions from an image batch.

        Returns:
            ``(direction, period, lumen_width)``. If ``return_raw`` is set,
            the raw backbone output is appended as a fourth element.
        """
        # Raw outputs: logit y-projection, logit x-projection, period, lumen width.
        preds = self.model(x)
        preds_direction = self.angle_parser(preds)
        preds_period = preds[:, -2]
        # The lumen width is an unbounded regression output, so no sigmoid is applied.
        preds_lumen_width = preds[:, -1]
        outputs = [preds_direction, preds_period, preds_lumen_width]
        if return_raw:
            outputs.append(preds)
        return tuple(outputs)

    def configure_optimizers(self):
        """Create an AdamW optimizer and the learning-rate scheduler described by ``lr_hparams``.

        ``lr_hparams`` may be either

        * a Hydra-style config containing ``_target_``, which is instantiated
          as a partial and called with the optimizer; or
        * ``{'classname': <name in torch.optim.lr_scheduler>, 'kwargs': {...}}``,
          the format of the default value.
        """
        # AdamW is Adam with a correct implementation of weight decay (see here
        # for details: https://arxiv.org/pdf/1711.05101.pdf)
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
        lr_hparams = self.hparams.lr_hparams
        if '_target_' in lr_hparams:
            # NOTE(review): requires `instantiate` from hydra.utils, whose
            # import is commented out at the top of this file.
            scheduler = instantiate({**lr_hparams, '_partial_': True})(optimizer)
        else:
            # The default config has no `_target_` for instantiate to build
            # from, so resolve the scheduler class by name instead.
            scheduler_cls = getattr(torch.optim.lr_scheduler, lr_hparams['classname'])
            scheduler = scheduler_cls(optimizer, **lr_hparams.get('kwargs', {}))
        return [optimizer], [scheduler]

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to the CPU and return it as a NumPy array."""
        return tensor.detach().cpu().numpy()

    def process_batch_supervised(self, batch):
        """Get predictions, losses and mean absolute errors (MAE) for a labelled batch.

        ``batch`` must provide ``image``, ``direction``, ``period`` and
        ``lumen_width`` entries.

        Returns:
            ``(preds, losses, mae)`` dicts. ``losses['final']`` is the weighted
            sum used for optimization. ``mae['lumen_fraction']`` compares
            ``lumen_width / period`` for predictions and targets; a zero period
            yields inf/nan there.
        """
        preds = {}
        preds['direction'], preds['period'], preds['lumen_width'], preds_raw = \
            self.forward(batch['image'], return_raw=True)

        # Directions are doubled so the cosine loss has a 180-degree period.
        losses = {
            'direction': self.losses['direction'](2 * batch['direction'], 2 * preds['direction']),
            'period': self.losses['period'](batch['period'], preds['period']),
            'lumen_width': self.losses['lumen_width'](batch['lumen_width'], preds['lumen_width']),
        }
        if self.regularizer is not None:
            # Only the two direction logits are regularized.
            losses['regularization'] = self.regularizer(preds_raw[:, :2])
        losses['final'] = (
            losses['direction'] * self.losses_weights['direction']
            + losses['period'] * self.losses_weights['period']
            + losses['lumen_width'] * self.losses_weights['lumen_width']
            + losses.get('regularization', 0.) * self.losses_weights.get('regularization', 0.)
        )

        to_np = self._to_numpy
        period_gt, period_pred = to_np(batch['period']), to_np(preds['period'])
        width_gt, width_pred = to_np(batch['lumen_width']), to_np(preds['lumen_width'])
        period_difference = np.mean(abs(period_gt - period_pred))
        a1 = to_np(batch['direction'])
        a2 = to_np(preds['direction'])
        # Angular difference modulo 180 degrees, in [0, 90].
        angle_difference = np.mean(0.5 * np.degrees(np.arccos(np.cos(2 * np.radians(a2 - a1)))))
        lumen_width_difference = np.mean(abs(width_pred - width_gt))
        lumen_fraction_pred = width_pred / period_pred
        lumen_fraction_gt = width_gt / period_gt
        lumen_fraction_difference = np.mean(abs(lumen_fraction_pred - lumen_fraction_gt))
        mae = {
            'period': period_difference,
            'direction': angle_difference,
            'lumen_width': lumen_width_difference,
            'lumen_fraction': lumen_fraction_difference,
        }
        return preds, losses, mae

    def log_all(self, losses, mae, prefix=''):
        """Log every loss as ``<prefix><name>_loss`` and every MAE as ``<prefix><name>_difference``."""
        for k, v in losses.items():
            self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
        for k, v in mae.items():
            self.log(f'{prefix}{k}_difference', v.item() if isinstance(v, torch.Tensor) else v)

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the weighted total loss."""
        # "batch" is the output of the training data loader.
        preds, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='train_')
        return losses['final']

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses and errors."""
        preds, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='val_')

    def test_step(self, batch, batch_idx):
        """Compute and log the test losses and errors."""
        preds, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='test_')
| # class StripsModel(StripsModelGeneral): | |
| # def __init__(self, model_name, *args, **kwargs): | |
| # super().__init__( *args, **kwargs) | |
| # self.model = timm.create_model(model_name, in_chans=1, num_classes=4) | |
| # def forward(self, x): | |
| # """get predictions from image batch""" | |
| # preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction | |
| # preds_sin = 1. - 2*torch.sigmoid(preds[:,0]) | |
| # preds_cos = 1. - 2*torch.sigmoid(preds[:,1]) | |
| # preds_direction = 0.5*torch.rad2deg(torch.arctan2(preds_sin, preds_cos)) | |
| # preds_period = preds[:,2] | |
| # preds_lumen_fraction = torch.sigmoid(preds[:,3]) #lumen fraction is between 0 and 1, so we take sigmoid fo this | |
| # return preds_direction, preds_period, preds_lumen_fraction | |
| # class StripsModelAngle1(StripsModelGeneral): | |
| # def __init__(self, model_name, *args, **kwargs): | |
| # super().__init__( *args, **kwargs) | |
| # self.model = timm.create_model(model_name, in_chans=1, num_classes=3) | |
| # def forward(self, x): | |
| # """get predictions from image batch""" | |
| # preds = self.model(x) # preds: logit angle_sin, logit angle | |
| # preds_direction = torch.pi * torch.sigmoid(preds[:,0]) | |
| # preds_period = preds[:,1] | |
| # preds_lumen_fraction = torch.sigmoid(preds[:,2]) #lumen fraction is between 0 and 1, so we take sigmoid fo this | |
| # return preds_direction, preds_period, preds_lumen_fraction | |