master_thesis_models / src /models /focus_traditional.py
Hannes Kuchelmeister
fix mdct (add square)
1a31ea8
raw
history blame
1.97 kB
from typing import Any, List
import torch
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
from torchmetrics.classification.accuracy import Accuracy
import torchvision.models as models
import kornia
def vol4(img):
    """Return an inverted Vollath F4 focus score (lower = sharper).

    Vollath's F4 autocorrelation measure compares the correlation of each
    value with its neighbour one step further along the first axis of the
    grey plane against the neighbour two steps further:

        F4 = sum(G[i, j] * G[i + 1, j]) - sum(G[i, j] * G[i + 2, j])

    The result is inverted (``100 / F4``) so that, like the other measures
    in this module, a lower value means better focus.

    Args:
        img: image tensor; dim 0 is averaged away to form the grey plane
            ``G``. NOTE(review): presumably channel-first ``(C, H, W)`` so
            that the shift runs along image rows — confirm against caller.

    Returns:
        Scalar tensor ``100 / F4``. A zero F4 (e.g. an all-black image)
        yields ``inf``.
    """
    img_grey = torch.mean(img, dim=0)
    lag1 = torch.sum(torch.mul(img_grey[:-1, :], img_grey[1:, :]))
    lag2 = torch.sum(torch.mul(img_grey[:-2, :], img_grey[2:, :]))
    # The difference must be parenthesised: `100 / lag1 - lag2` would invert
    # only the first term instead of the whole F4 measure.
    return 100 / (lag1 - lag2)
def laplacian(img):
    """Return an inverted mean-absolute-Laplacian focus score (lower = sharper).

    The image is averaged over dim 0 into a single grey plane, filtered with
    a 3x3 Laplacian via ``kornia.filters.laplacian``, and the mean magnitude
    of the response is taken. That mean is inverted (``100 / mean``) so that
    a lower value means better focus.

    Args:
        img: image tensor whose dim 0 is averaged away. NOTE(review): kornia
            filters expect a ``(B, C, H, W)`` input, so this presumably relies
            on ``img`` having a rank that makes ``img_grey`` 4-D after the
            ``unsqueeze`` — confirm against the caller.

    Returns:
        Scalar tensor ``100 / mean(|laplacian(img_grey)|)``. A perfectly flat
        image gives a zero mean and therefore ``inf``.
    """
    img_grey = torch.mean(img, dim=0).unsqueeze(0)
    filtered = kornia.filters.laplacian(img_grey, 3)
    # A Laplacian kernel sums to zero, so the *signed* response cancels out
    # over the image (only border terms survive) and its mean does not track
    # sharpness. It can even be ~0 or negative. Average the magnitude instead.
    mean = torch.mean(torch.abs(filtered))
    return 100 / mean  # invert mean to fit metric of lower = better
def midfrequency_dct(img):
    """Return an inverted mid-frequency DCT focus score (lower = sharper).

    The image is averaged over dim 0 into a single grey plane and convolved
    with the 4x4 mid-frequency DCT operator below via
    ``kornia.filters.filter2d``. The squared responses are summed into an
    energy, which is inverted (``100 / energy``) so that, like ``laplacian``,
    a lower value means better focus.

    Args:
        img: image tensor whose dim 0 is averaged away. NOTE(review): kornia's
            ``filter2d`` expects a ``(B, C, H, W)`` input, so this presumably
            relies on ``img`` having a rank that makes ``img_grey`` 4-D after
            the ``unsqueeze`` — confirm against the caller.

    Returns:
        Scalar tensor ``100 / sum(response ** 2)``. A zero energy gives ``inf``.
    """
    # 4x4 mid-frequency operator with a leading kernel-batch dimension of 1.
    kernel = torch.tensor(
        [
            [
                [1, 1, -1, -1],
                [1, 1, -1, -1],
                [-1, -1, 1, 1],
                [-1, -1, 1, 1],
            ]
        ]
    )
    img_grey = torch.mean(img, dim=0).unsqueeze(0)
    # Square the response so positive and negative lobes both add energy
    # instead of cancelling each other out.
    filtered = torch.square(kornia.filters.filter2d(img_grey, kernel))
    # Named `energy` rather than `sum` to avoid shadowing the builtin.
    energy = torch.sum(filtered)
    return 100 / energy
class TraditionalLitModule(LightningModule):
    """Non-trainable LightningModule wrapping a classical focus measure."""

    def __init__(
        self,
        method: str = "vol4",
    ):
        """Initialize function for a traditional focus measurement `model`. It cannot be trained.

        Args:
            method (str, optional): The method to use for predicting focus. Defaults to "vol4".
                Possible values are: vol4, mean_laplacian, midfrequency_dct

        Raises:
            ValueError: if ``method`` is not one of the known method names.
        """
        super().__init__()
        # Dispatch table from the public method name to its focus-measure function.
        functions = {
            "vol4": vol4,
            "mean_laplacian": laplacian,
            "midfrequency_dct": midfrequency_dct,
        }
        try:
            self.function = functions[method]
        except KeyError:
            # Fail fast at construction time instead of leaving `self.function`
            # unset, which would only surface later as an AttributeError in forward.
            raise ValueError(
                f"Unknown focus method {method!r}; expected one of: {', '.join(functions)}"
            ) from None

    def forward(self, x):
        """Apply the selected focus measure to ``x`` and return its score."""
        return self.function(x)