Hannes Kuchelmeister commited on
Commit
a490849
1 Parent(s): c3f6f40

add mean laplacian and mdct

Browse files
notebooks/5.0-hfk-comparing-to-traditional.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:733ce4b03b6efa6a4a0996783c8d88dcd64e9e269b16bfcb747e81dfd78b5743
3
- size 11339
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4e1c4becea54e3d35081c4e2802d466cd0ef6f752ce5a49eeebad3091ef7262
3
+ size 15942
requirements.txt CHANGED
@@ -8,6 +8,10 @@ torchmetrics>=0.7.0
8
  scikit-image
9
  pandas
10
 
 
 
 
 
11
  # --------- hydra --------- #
12
  hydra-core>=1.1.0
13
  hydra-colorlog>=1.1.0
 
8
  scikit-image
9
  pandas
10
 
11
+ # --------- libraries for image filters ---------#
12
+ kornia
13
+
14
+
15
  # --------- hydra --------- #
16
  hydra-core>=1.1.0
17
  hydra-colorlog>=1.1.0
src/models/focus_traditional.py CHANGED
@@ -7,6 +7,7 @@ from pytorch_lightning import LightningModule
7
  from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
8
  from torchmetrics.classification.accuracy import Accuracy
9
  import torchvision.models as models
 
10
 
11
 
12
  def vol4(img):
@@ -16,6 +17,31 @@ def vol4(img):
16
  )
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class TraditionalLitModule(LightningModule):
20
  def __init__(
21
  self,
@@ -26,6 +52,7 @@ class TraditionalLitModule(LightningModule):
26
 
27
  Args:
28
  method (str, optional): The method to use for predicting focus. Defaults to "vol4".
 
29
 
30
  Raises:
31
  Exception: raises exception if method parameter is not known
@@ -34,6 +61,10 @@ class TraditionalLitModule(LightningModule):
34
 
35
  if method == "vol4":
36
  self.function = vol4
 
 
 
 
37
 
38
  def forward(self, x):
39
  return self.function(x)
 
7
  from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
8
  from torchmetrics.classification.accuracy import Accuracy
9
  import torchvision.models as models
10
+ import kornia
11
 
12
 
13
  def vol4(img):
 
17
  )
18
 
19
 
20
+ def laplacian(img):
21
+ img_grey = torch.mean(img, dim=0).unsqueeze(0)
22
+ filtered = kornia.filters.laplacian(img_grey, 3)
23
+ mean = torch.mean(filtered)
24
+ return 100 / mean # invert mean to fit metric of lower = better
25
+
26
+
27
+ def midfrequency_dct(img):
28
+ kernel = torch.tensor(
29
+ [
30
+ [
31
+ [1, 1, -1, -1],
32
+ [1, 1, -1, -1],
33
+ [-1, -1, 1, 1],
34
+ [-1, -1, 1, 1],
35
+ ]
36
+ ]
37
+ )
38
+
39
+ img_grey = torch.mean(img, dim=0).unsqueeze(0)
40
+ filtered = kornia.filters.filter2d(img_grey, kernel)
41
+ sum = torch.sum(filtered)
42
+ return 100 / sum
43
+
44
+
45
  class TraditionalLitModule(LightningModule):
46
  def __init__(
47
  self,
 
52
 
53
  Args:
54
  method (str, optional): The method to use for predicting focus. Defaults to "vol4".
55
+ Possible values are: vol4, mean_laplacian, midfrequency_dct
56
 
57
  Raises:
58
  Exception: raises exception if method parameter is not known
 
61
 
62
  if method == "vol4":
63
  self.function = vol4
64
+ if method == "mean_laplacian":
65
+ self.function = laplacian
66
+ if method == "midfrequency_dct":
67
+ self.function = midfrequency_dct
68
 
69
  def forward(self, x):
70
  return self.function(x)