from typing import Any, List

import torch
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
from torchmetrics.classification.accuracy import Accuracy
import torchvision.models as models
import kornia


def _as_kornia_batch(img_grey):
    """Prepend singleton dims until *img_grey* is 4-D, as kornia filters expect (B, C, H, W).

    A 2-D grey image becomes (1, 1, H, W). Tensors that are already 3-D become
    (1, ...), which matches the previous ``unsqueeze(0)`` behaviour.
    """
    while img_grey.dim() < 4:
        img_grey = img_grey.unsqueeze(0)
    return img_grey


def vol4(img):
    """Vollath F4 focus measure, inverted so that lower values mean sharper.

    Args:
        img: Image tensor. It is averaged over dim 0 to obtain a 2-D grey
            image, which is then correlated with itself shifted by one and by
            two rows. Presumably channels-first (C, H, W) -- TODO confirm with callers.

    Returns:
        Scalar tensor ``100 / (sum(I[y] * I[y+1]) - sum(I[y] * I[y+2]))``.
    """
    img_grey = torch.mean(img, dim=0)
    shift_one = torch.sum(torch.mul(img_grey[:-1, :], img_grey[1:, :]))
    shift_two = torch.sum(torch.mul(img_grey[:-2, :], img_grey[2:, :]))
    # Invert the whole Vollath difference, not just its first term.
    return 100 / (shift_one - shift_two)


def laplacian(img):
    """Mean-Laplacian focus measure, inverted so that lower values mean sharper.

    Args:
        img: Image tensor; averaged over dim 0 to a grey image, then padded
            with leading singleton dims to the 4-D shape kornia requires.

    Returns:
        Scalar tensor ``100 / mean(laplacian(grey, kernel_size=3))``.
    """
    img_grey = _as_kornia_batch(torch.mean(img, dim=0))
    filtered = kornia.filters.laplacian(img_grey, 3)
    # NOTE(review): the Laplacian kernel sums to zero, so the mean of its signed
    # response is likely close to zero; focus-measure literature usually takes
    # abs() or the variance instead. Kept unchanged to preserve the metric -- verify.
    mean = torch.mean(filtered)
    return 100 / mean  # invert mean to fit metric of lower = better


def midfrequency_dct(img):
    """Mid-frequency DCT-like focus measure, inverted so that lower values mean sharper.

    The grey image is filtered with a fixed 4x4 checkerboard-of-2x2-blocks
    kernel, and the squared responses are summed.

    Args:
        img: Image tensor; averaged over dim 0 to a grey image, then padded
            with leading singleton dims to the 4-D shape kornia requires.

    Returns:
        Scalar tensor ``100 / sum(filter2d(grey, kernel) ** 2)``.
    """
    # Build the kernel in the image's dtype/device instead of as an int64 CPU tensor.
    kernel = torch.tensor(
        [
            [
                [1, 1, -1, -1],
                [1, 1, -1, -1],
                [-1, -1, 1, 1],
                [-1, -1, 1, 1],
            ]
        ],
        dtype=img.dtype,
        device=img.device,
    )
    img_grey = _as_kornia_batch(torch.mean(img, dim=0))
    filtered = torch.square(kornia.filters.filter2d(img_grey, kernel))
    energy = torch.sum(filtered)
    return 100 / energy


class TraditionalLitModule(LightningModule):
    """LightningModule wrapper around a hand-crafted focus measure (no trainable parameters)."""

    # Public `method` names accepted by __init__, mapped to the measure functions above.
    _METHODS = {
        "vol4": vol4,
        "mean_laplacian": laplacian,
        "midfrequency_dct": midfrequency_dct,
    }

    def __init__(
        self,
        method: str = "vol4",
    ):
        """Initialize function for a traditional focus measurement `model`. It cannot be trained.

        Args:
            method (str, optional): The method to use for predicting focus. Defaults to "vol4".
                Possible values are: vol4, mean_laplacian, midfrequency_dct

        Raises:
            ValueError: if the method parameter is not one of the known names.
        """
        super().__init__()
        try:
            self.function = self._METHODS[method]
        except KeyError:
            raise ValueError(
                f"Unknown focus method {method!r}; expected one of {sorted(self._METHODS)}"
            ) from None

    def forward(self, x):
        """Return the configured focus measure for *x* (lower = better)."""
        return self.function(x)