|
from typing import Any, List |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from pytorch_lightning import LightningModule |
|
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric |
|
from torchmetrics.classification.accuracy import Accuracy |
|
import torchvision.models as models |
|
import kornia |
|
|
|
|
|
def vol4(img):
    """Vollath F4 autocorrelation focus measure, returned as ``100 / F4``.

    ``F4 = sum(G[:-1] * G[1:]) - sum(G[:-2] * G[2:])``, where ``G`` is the grey
    image and the products pair it with copies of itself shifted by one and
    two positions along its first axis.

    Args:
        img: image tensor. The grey image is the mean over dim 0, so this
            presumably expects a single (C, H, W) image — TODO confirm
            against callers.

    Returns:
        Scalar tensor ``100 / F4``.
    """
    img_grey = torch.mean(img, dim=0)
    shift1 = torch.sum(torch.mul(img_grey[:-1, :], img_grey[1:, :]))
    shift2 = torch.sum(torch.mul(img_grey[:-2, :], img_grey[2:, :]))
    # Both sums form the denominator; without the parentheses `/` would bind
    # only to the first sum and the second would be subtracted afterwards.
    return 100 / (shift1 - shift2)
|
|
|
|
|
def laplacian(img):
    """Laplacian-based focus measure, returned as ``100 / mean``.

    ``mean`` is the average of the 3x3 Laplacian response
    (``kornia.filters.laplacian``) of the grey image.

    Args:
        img: image tensor. The grey image is the mean over dim 0, so this
            presumably expects a single (C, H, W) image — TODO confirm
            against callers.

    Returns:
        Scalar tensor ``100 / mean``.
    """
    img_grey = torch.mean(img, dim=0)
    # kornia filters require a 4-D (B, C, H, W) tensor. For a (C, H, W) input
    # the grey image is 2-D, so a single unsqueeze(0) is not enough; pad
    # leading singleton dims until the tensor is 4-D.
    while img_grey.dim() < 4:
        img_grey = img_grey.unsqueeze(0)

    filtered = kornia.filters.laplacian(img_grey, 3)

    # NOTE(review): a Laplacian kernel sums to zero, so the signed mean of its
    # response is close to 0 for most images and 100 / mean can be very large
    # or negative. A mean of the absolute response may have been intended —
    # confirm before relying on this measure.
    mean = torch.mean(filtered)
    return 100 / mean
|
|
|
|
|
def midfrequency_dct(img):
    """Mid-frequency DCT focus measure, returned as ``100 / energy``.

    The grey image is filtered with a fixed 4x4 +/-1 kernel and the squared
    responses are summed to give ``energy``.

    Args:
        img: image tensor. The grey image is the mean over dim 0, so this
            presumably expects a single (C, H, W) image — TODO confirm
            against callers.

    Returns:
        Scalar tensor ``100 / energy``.
    """
    img_grey = torch.mean(img, dim=0)
    # kornia.filters.filter2d requires a 4-D (B, C, H, W) tensor. For a
    # (C, H, W) input the grey image is 2-D, so pad leading singleton dims.
    while img_grey.dim() < 4:
        img_grey = img_grey.unsqueeze(0)

    # Kernel shape is (1, kH, kW) as filter2d expects; build it directly in
    # the image's dtype/device.
    kernel = torch.tensor(
        [
            [
                [1, 1, -1, -1],
                [1, 1, -1, -1],
                [-1, -1, 1, 1],
                [-1, -1, 1, 1],
            ]
        ],
        dtype=img_grey.dtype,
        device=img_grey.device,
    )

    filtered = torch.square(kornia.filters.filter2d(img_grey, kernel))
    energy = torch.sum(filtered)
    return 100 / energy
|
|
|
|
|
class TraditionalLitModule(LightningModule):
    """LightningModule wrapper around a fixed, parameter-free focus measure."""

    def __init__(
        self,
        method: str = "vol4",
    ):
        """Initialize function for a traditional focus measurement `model`. It cannot be trained.

        Args:
            method (str, optional): The method to use for predicting focus. Defaults to "vol4".
                Possible values are: vol4, mean_laplacian, midfrequency_dct

        Raises:
            ValueError: if ``method`` is not one of the known names.
        """
        super().__init__()

        functions = {
            "vol4": vol4,
            "mean_laplacian": laplacian,
            "midfrequency_dct": midfrequency_dct,
        }
        # Fail fast here; otherwise `self.function` would be unset and the
        # error would only surface later as an AttributeError in forward().
        if method not in functions:
            raise ValueError(
                f"Unknown focus method {method!r}; "
                f"expected one of {sorted(functions)}"
            )
        self.function = functions[method]

    def forward(self, x):
        """Apply the selected focus measure to ``x`` and return its result."""
        return self.function(x)
|
|