Artyom
commited on
MiAlgo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- MiAlgo/Dockerfile +9 -0
- MiAlgo/README.txt +16 -0
- MiAlgo/__pycache__/net_torch.cpython-36.pyc +0 -0
- MiAlgo/__pycache__/network.cpython-36.pyc +0 -0
- MiAlgo/__pycache__/network.cpython-37.pyc +0 -0
- MiAlgo/__pycache__/network.cpython-38.pyc +0 -0
- MiAlgo/__pycache__/network_raw_denoise.cpython-36.pyc +0 -0
- MiAlgo/__pycache__/network_raw_denoise.cpython-38.pyc +0 -0
- MiAlgo/__pycache__/tf2onnx.cpython-36.pyc +0 -0
- MiAlgo/__pycache__/unet.cpython-36.pyc +0 -0
- MiAlgo/__pycache__/unet.cpython-37.pyc +0 -0
- MiAlgo/__pycache__/unet.cpython-38.pyc +0 -0
- MiAlgo/__pycache__/utils.cpython-36.pyc +0 -0
- MiAlgo/__pycache__/utils.cpython-37.pyc +0 -0
- MiAlgo/__pycache__/utils.cpython-38.pyc +0 -0
- MiAlgo/assets/pretrained/hub/checkpoints/squeezenet1_1-f364aa15.pth +3 -0
- MiAlgo/auxiliary/__pycache__/settings.cpython-36.pyc +0 -0
- MiAlgo/auxiliary/__pycache__/settings.cpython-37.pyc +0 -0
- MiAlgo/auxiliary/__pycache__/settings.cpython-38.pyc +0 -0
- MiAlgo/auxiliary/__pycache__/utils.cpython-36.pyc +0 -0
- MiAlgo/auxiliary/__pycache__/utils.cpython-37.pyc +0 -0
- MiAlgo/auxiliary/__pycache__/utils.cpython-38.pyc +0 -0
- MiAlgo/auxiliary/settings.py +47 -0
- MiAlgo/auxiliary/utils.py +107 -0
- MiAlgo/checkpoint/nn_enhance.pth +3 -0
- MiAlgo/checkpoint/raw_denoise.pth +3 -0
- MiAlgo/classes/core/Evaluator.py +49 -0
- MiAlgo/classes/core/Loss.py +13 -0
- MiAlgo/classes/core/LossTracker.py +16 -0
- MiAlgo/classes/core/Model.py +43 -0
- MiAlgo/classes/core/__pycache__/Evaluator.cpython-36.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Evaluator.cpython-37.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Evaluator.cpython-38.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Loss.cpython-36.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Loss.cpython-37.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Loss.cpython-38.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Model.cpython-36.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Model.cpython-37.pyc +0 -0
- MiAlgo/classes/core/__pycache__/Model.cpython-38.pyc +0 -0
- MiAlgo/classes/data/ColorCheckerDataset.py +52 -0
- MiAlgo/classes/data/DataAugmenter.py +152 -0
- MiAlgo/classes/data/__pycache__/ColorCheckerDataset.cpython-36.pyc +0 -0
- MiAlgo/classes/data/__pycache__/ColorCheckerDataset.cpython-37.pyc +0 -0
- MiAlgo/classes/data/__pycache__/ColorCheckerDataset.cpython-38.pyc +0 -0
- MiAlgo/classes/data/__pycache__/DataAugmenter.cpython-36.pyc +0 -0
- MiAlgo/classes/data/__pycache__/DataAugmenter.cpython-37.pyc +0 -0
- MiAlgo/classes/data/__pycache__/DataAugmenter.cpython-38.pyc +0 -0
- MiAlgo/classes/fc4/FC4.py +63 -0
- MiAlgo/classes/fc4/ModelFC4.py +44 -0
- MiAlgo/classes/fc4/__pycache__/FC4.cpython-36.pyc +0 -0
MiAlgo/Dockerfile
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM q935970314/mialgo:2022
|
2 |
+
|
3 |
+
COPY . /workdir/
|
4 |
+
|
5 |
+
RUN export PATH=/opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin && pip install grayness_index_python
|
6 |
+
|
7 |
+
WORKDIR /workdir/
|
8 |
+
|
9 |
+
|
MiAlgo/README.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Run the following commands to build the Docker image and process the night images:
|
2 |
+
|
3 |
+
docker build -t mialgo24 .
|
4 |
+
|
5 |
+
cd DATA_FOLDER_PATH
|
6 |
+
docker run -it --rm --gpus all -v $(pwd)/data:/data mialgo24 ./run.sh
|
7 |
+
|
8 |
+
|
9 |
+
Please note that the "--gpus all" parameter is used; make sure that the model is actually running on the GPU.
|
10 |
+
We first load all images into memory. If you run out of memory, please process fewer than 50 images at a time.
|
11 |
+
|
12 |
+
We will open source the code after the competition ends.
|
13 |
+
|
14 |
+
If you have any questions, please contact me promptly. Thank you.
|
15 |
+
|
16 |
+
Email: liushuai21@xiaomi.com
|
MiAlgo/__pycache__/net_torch.cpython-36.pyc
ADDED
Binary file (4.68 kB). View file
|
|
MiAlgo/__pycache__/network.cpython-36.pyc
ADDED
Binary file (16.4 kB). View file
|
|
MiAlgo/__pycache__/network.cpython-37.pyc
ADDED
Binary file (16.1 kB). View file
|
|
MiAlgo/__pycache__/network.cpython-38.pyc
ADDED
Binary file (16.7 kB). View file
|
|
MiAlgo/__pycache__/network_raw_denoise.cpython-36.pyc
ADDED
Binary file (5.28 kB). View file
|
|
MiAlgo/__pycache__/network_raw_denoise.cpython-38.pyc
ADDED
Binary file (5.31 kB). View file
|
|
MiAlgo/__pycache__/tf2onnx.cpython-36.pyc
ADDED
Binary file (1.12 kB). View file
|
|
MiAlgo/__pycache__/unet.cpython-36.pyc
ADDED
Binary file (3.68 kB). View file
|
|
MiAlgo/__pycache__/unet.cpython-37.pyc
ADDED
Binary file (3.69 kB). View file
|
|
MiAlgo/__pycache__/unet.cpython-38.pyc
ADDED
Binary file (3.65 kB). View file
|
|
MiAlgo/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (8.4 kB). View file
|
|
MiAlgo/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (8.29 kB). View file
|
|
MiAlgo/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (9.09 kB). View file
|
|
MiAlgo/assets/pretrained/hub/checkpoints/squeezenet1_1-f364aa15.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f364aa15cc776cd43e679ca0859f479db11ef4852a4e79bb237519f9d16617c5
|
3 |
+
size 4966400
|
MiAlgo/auxiliary/__pycache__/settings.cpython-36.pyc
ADDED
Binary file (1.22 kB). View file
|
|
MiAlgo/auxiliary/__pycache__/settings.cpython-37.pyc
ADDED
Binary file (1.2 kB). View file
|
|
MiAlgo/auxiliary/__pycache__/settings.cpython-38.pyc
ADDED
Binary file (1.19 kB). View file
|
|
MiAlgo/auxiliary/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (4.6 kB). View file
|
|
MiAlgo/auxiliary/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (4.57 kB). View file
|
|
MiAlgo/auxiliary/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.59 kB). View file
|
|
MiAlgo/auxiliary/settings.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
# --- Determinism (for reproducibility) ---
|
8 |
+
|
9 |
+
def make_deterministic(seed: int):
    """
    Seeds the random number generators used by the project so that runs can be repeated
    @param seed: the value used to seed Python's `random`, NumPy and PyTorch
    """
    # Imported locally to keep the file's import block untouched. Python's `random` must be seeded too:
    # the data augmentation draws its crops, rotations, flips and colour gains from it
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Disable the cuDNN auto-tuner so that the convolution algorithm is not picked by benchmarking at runtime
    torch.backends.cudnn.benchmark = False
|
13 |
+
|
14 |
+
|
15 |
+
# --- Device (cpu or cuda:n) ---
|
16 |
+
|
17 |
+
DEVICE_TYPE = "cuda:0"
|
18 |
+
|
19 |
+
|
20 |
+
def get_device() -> torch.device:
    """
    Resolves the module-level DEVICE_TYPE string into a torch device
    @return: the CPU device if DEVICE_TYPE is 'cpu' or if a CUDA device is requested while CUDA is not available,
    the requested CUDA device otherwise
    @raise ValueError: if DEVICE_TYPE is neither 'cpu' nor of the form 'cuda:n'
    """
    if DEVICE_TYPE == "cpu":
        print("\n Running on device 'cpu' \n")
        return torch.device("cpu")

    # fullmatch (not match) so that names with trailing garbage such as 'cuda:0x' are rejected as well
    if re.fullmatch(r"cuda:\d+", DEVICE_TYPE):
        if not torch.cuda.is_available():
            print("\n WARNING: running on cpu since device {} is not available \n".format(DEVICE_TYPE))
            return torch.device("cpu")

        return torch.device(DEVICE_TYPE)

    raise ValueError("ERROR: {} is not a valid device! Supported device are 'cpu' and 'cuda:n'".format(DEVICE_TYPE))
|
34 |
+
|
35 |
+
|
36 |
+
DEVICE = get_device()
|
37 |
+
|
38 |
+
# --- Model ---
|
39 |
+
|
40 |
+
# When this flag is False, a simpler summation pooling is used instead
USE_CONFIDENCE_WEIGHTED_POOLING = True

if not USE_CONFIDENCE_WEIGHTED_POOLING:
    print("\n WARN: confidence-weighted pooling option is set to False \n")

# Input size (width and height) of the training images
TRAIN_IMG_W = 512
TRAIN_IMG_H = 512

# Input size (width and height) of the test images
TEST_IMG_W = 0
TEST_IMG_H = 0
|
MiAlgo/auxiliary/utils.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import Union, List, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms.functional as F
|
8 |
+
from PIL.Image import Image
|
9 |
+
from scipy.spatial.distance import jensenshannon
|
10 |
+
from torch import Tensor
|
11 |
+
from torch.nn.functional import interpolate
|
12 |
+
|
13 |
+
from auxiliary.settings import DEVICE
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def print_metrics(current_metrics: dict, best_metrics: dict):
    """
    Prints every monitored metric next to the corresponding best value
    @param current_metrics: the current values, keyed by "mean", "median", "trimean", "bst25", "wst25" and "wst5"
    @param best_metrics: the best values, with the same keys
    """
    rows = (
        (" Mean ......... : {:.4f} (Best: {:.4f})", "mean"),
        (" Median ....... : {:.4f} (Best: {:.4f})", "median"),
        (" Trimean ...... : {:.4f} (Best: {:.4f})", "trimean"),
        (" Best 25% ..... : {:.4f} (Best: {:.4f})", "bst25"),
        (" Worst 25% .... : {:.4f} (Best: {:.4f})", "wst25"),
        (" Worst 5% ..... : {:.4f} (Best: {:.4f})", "wst5"),
    )
    for template, key in rows:
        print(template.format(current_metrics[key], best_metrics[key]))
|
24 |
+
|
25 |
+
|
26 |
+
def correct(img: Image, illuminant: Tensor) -> Image:
    """
    Removes the colour cast of a linear image by dividing it by an estimated (linear) illuminant
    @param img: a linear image
    @param illuminant: a linear illuminant
    @return: a non-linear color-corrected version of the input image
    """
    img_tensor = F.to_tensor(img).to(DEVICE)

    # Divide the image by the illuminant scaled by sqrt(3); the epsilon guards against a division by zero
    sqrt_three = torch.sqrt(Tensor([3])).to(DEVICE)
    divisor = illuminant.unsqueeze(2).unsqueeze(3) * sqrt_three
    balanced = torch.div(img_tensor, divisor + 1e-10)

    # Normalize by the maximum taken over the three trailing dimensions
    peak = balanced
    for _ in range(3):
        peak = torch.max(peak, dim=1)[0]
    peak = (peak + 1e-10).unsqueeze(1).unsqueeze(1).unsqueeze(1)
    normalized = torch.div(balanced, peak)

    return F.to_pil_image(linear_to_nonlinear(normalized).squeeze(), mode="RGB")
|
45 |
+
|
46 |
+
|
47 |
+
def linear_to_nonlinear(img: Union[np.array, Image, Tensor]) -> Union[np.array, Image, Tensor]:
    """
    Applies a gamma of 1 / 2.2 to a linear image
    @param img: a NumPy array, a torch Tensor or any other image accepted by torchvision's to_tensor
    @return: an array for array input, a Tensor for Tensor input, an RGB PIL image otherwise
    """
    gamma = 1.0 / 2.2

    if isinstance(img, np.ndarray):
        return np.power(img, gamma)

    if isinstance(img, Tensor):
        return torch.pow(img, gamma)

    # Any other input goes through a tensor round trip and comes back as an RGB PIL image
    as_tensor = F.to_tensor(img)
    return F.to_pil_image(torch.pow(as_tensor, gamma).squeeze(), mode="RGB")
|
53 |
+
|
54 |
+
|
55 |
+
def normalize(img: np.ndarray) -> np.ndarray:
    """ Clips the values of an image to [0, 65535] and rescales them to [0, 1] """
    max_int = 65535.0
    clipped = np.clip(img, 0.0, max_int)
    return clipped * (1.0 / max_int)
|
58 |
+
|
59 |
+
|
60 |
+
def rgb_to_bgr(x: np.ndarray) -> np.ndarray:
    """ Reverses an array along its first axis (e.g. a 3-element RGB colour becomes BGR) """
    return np.flip(x, axis=0)
|
62 |
+
|
63 |
+
|
64 |
+
def bgr_to_rgb(x: np.ndarray) -> np.ndarray:
    """ Reverses the order of the entries along the third axis (the channels of a height x width x channels image) """
    return np.flip(x, axis=2)
|
66 |
+
|
67 |
+
|
68 |
+
def hwc_to_chw(x: np.ndarray) -> np.ndarray:
    """ Moves the channel axis first: height x width x channels becomes channels x height x width """
    return np.transpose(x, (2, 0, 1))
|
71 |
+
|
72 |
+
|
73 |
+
def scale(x: Tensor) -> Tensor:
    """
    Scales all values of a tensor between 0 and 1
    @param x: the tensor to be scaled
    @return: the tensor shifted so that its minimum is 0 and divided by its new maximum. A tensor whose values
    are all equal is returned as all zeros
    """
    shifted = x - x.min()
    peak = shifted.max()

    # A constant input has a zero range: dividing would turn every value into NaN
    if peak == 0:
        return shifted

    return shifted / peak
|
78 |
+
|
79 |
+
|
80 |
+
def rescale(x: Tensor, size: Tuple) -> Tensor:
    """
    Resizes a tensor to an image size with bilinear interpolation, for better visualization
    @param x: the tensor to be resized
    @param size: the target size handed to torch's interpolate
    @return: the resized tensor
    """
    resized = interpolate(x, size=size, mode="bilinear")
    return resized
|
83 |
+
|
84 |
+
|
85 |
+
def angular_error(x: Tensor, y: Tensor, safe_v: float = 0.999999) -> float:
    """
    Mean angle, in degrees, between the vectors of x and the corresponding vectors of y
    @param x: the first batch of vectors (normalized along dimension 1)
    @param y: the second batch of vectors (normalized along dimension 1)
    @param safe_v: the cosine is clamped to [-safe_v, safe_v] before the arc cosine is taken
    @return: the mean angular error as a Python float
    """
    x_unit = torch.nn.functional.normalize(x, dim=1)
    y_unit = torch.nn.functional.normalize(y, dim=1)

    cosine = torch.clamp(torch.sum(x_unit * y_unit, dim=1), -safe_v, safe_v)
    degrees = torch.acos(cosine) * (180 / math.pi)

    # .item() turns the mean into a plain Python number, hence the float return type
    return torch.mean(degrees).item()
|
90 |
+
|
91 |
+
|
92 |
+
def tvd(pred: Tensor, label: Tensor) -> Tensor:
    """
    Total Variation Distance (TVD) is a distance measure for probability distributions
    https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures
    @param pred: the predicted distribution
    @param label: the reference distribution
    @return: half of the sum of the absolute differences, as a scalar tensor
    """
    # A plain Python scalar keeps the result on the device of the inputs, whereas a freshly built
    # Tensor([0.5]) lives on the CPU and cannot be multiplied with CUDA tensors
    return 0.5 * torch.abs(pred - label).sum()
|
98 |
+
|
99 |
+
|
100 |
+
def jsd(p: List, q: List) -> float:
    """
    Jensen-Shannon Divergence (JSD) between two probability distributions, computed by squaring the
    Jensen-Shannon distance returned by scipy. Refs:
    - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html
    - https://stackoverflow.com/questions/15880133/jensen-shannon-divergence
    """
    js_distance = jensenshannon(p, q)
    return js_distance ** 2
|
107 |
+
|
MiAlgo/checkpoint/nn_enhance.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37c37648922f149ebefcec8b4c32c687bbafacd1efd99e288b235aab6e834ad5
|
3 |
+
size 152303133
|
MiAlgo/checkpoint/raw_denoise.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a987213026b43e214196154494e5b54045721dea11dcb8e6f1f2d8e632445699
|
3 |
+
size 2072446
|
MiAlgo/classes/core/Evaluator.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class Evaluator:
    """ Collects per-sample errors and summarises them into the monitored metrics """

    def __init__(self):
        monitored_metrics = ["mean", "median", "trimean", "bst25", "wst25", "wst5"]
        self.__metrics = {}
        # Every best metric starts at 100.0 until update_best_metrics overwrites it
        self.__best_metrics = {m: 100.0 for m in monitored_metrics}
        self.__errors = []

    def add_error(self, error: float):
        """ Records the error of one more sample """
        self.__errors.append(error)

    def reset_errors(self):
        """ Discards all the recorded errors """
        self.__errors = []

    def get_errors(self) -> list:
        """ @return: the recorded errors (sorted in ascending order once compute_metrics has been called) """
        return self.__errors

    def get_metrics(self) -> dict:
        """ @return: the metrics computed by the last call to compute_metrics """
        return self.__metrics

    def get_best_metrics(self) -> dict:
        """ @return: the metrics stored by the last call to update_best_metrics """
        return self.__best_metrics

    def compute_metrics(self) -> dict:
        """
        Sorts the recorded errors and computes the monitored metrics from them
        @return: a dict with the mean, the median, the trimean, the mean of the best and of the worst 25% of the
        errors, and the 95th percentile (stored as "wst5")
        """
        self.__errors = sorted(self.__errors)
        num_errors = len(self.__errors)

        # Keep at least one sample in the best quartile: with fewer than four errors int(0.25 * n) is 0 and the
        # mean of the resulting empty slice would be NaN
        num_best = max(1, int(0.25 * num_errors))

        self.__metrics = {
            "mean": np.mean(self.__errors),
            "median": self.__g(0.5),
            "trimean": 0.25 * (self.__g(0.25) + 2 * self.__g(0.5) + self.__g(0.75)),
            "bst25": np.mean(self.__errors[:num_best]),
            "wst25": np.mean(self.__errors[int(0.75 * num_errors):]),
            "wst5": self.__g(0.95)
        }
        return self.__metrics

    def update_best_metrics(self) -> dict:
        """
        Overwrites every best metric with the corresponding current metric
        @return: the updated best metrics
        """
        self.__best_metrics.update({m: self.__metrics[m] for m in self.__best_metrics})
        return self.__best_metrics

    def __g(self, f: float) -> float:
        """ @return: the percentile of the recorded errors for a fraction f in [0, 1] """
        return np.percentile(self.__errors, f * 100)
|
MiAlgo/classes/core/Loss.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
|
5 |
+
class Loss:
    """ Base class of the losses: subclasses implement _compute, calling an instance runs it """

    def __init__(self, device: torch.device):
        """ @param device: the device the computed loss is moved to """
        self._device = device

    def _compute(self, *args, **kwargs) -> Tensor:
        """ Computes the loss. To be overridden by the subclasses """
        raise NotImplementedError("Subclasses of Loss must implement _compute")

    def __call__(self, *args, **kwargs):
        """ @return: the loss computed from the given arguments, moved to the device of this loss """
        # Keyword arguments are forwarded too, they were previously dropped silently
        return self._compute(*args, **kwargs).to(self._device)
|
MiAlgo/classes/core/LossTracker.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class LossTracker(object):
    """ Keeps the last value and the running weighted average of a loss """

    def __init__(self):
        self.reset()

    def reset(self):
        """ Clears the last value, the average, the weighted sum and the sample count """
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val: float, n: int = 1):
        """
        Records a new loss value
        @param val: the loss value
        @param n: the number of samples the value accounts for
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # The count is still zero if only updates with n = 0 were seen: keep the average at 0 instead of
        # raising a ZeroDivisionError
        self.avg = self.sum / self.count if self.count else 0

    def get_loss(self) -> float:
        """ @return: the running average of the recorded values """
        return self.avg
|
MiAlgo/classes/core/Model.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
from auxiliary.settings import DEVICE
|
7 |
+
from classes.losses.AngularLoss import AngularLoss
|
8 |
+
|
9 |
+
|
10 |
+
class Model:
    """ Base wrapper around a network, holding its device, its loss criterion and its optimizer """

    def __init__(self):
        self._device = DEVICE
        self._criterion = AngularLoss(self._device)
        # Both are set later on: the network by the subclasses, the optimizer by set_optimizer
        self._optimizer = None
        self._network = None

    def print_network(self):
        """ Prints the network between two separator lines """
        print("\n----------------------------------------------------------\n")
        print(self._network)
        print("\n----------------------------------------------------------\n")

    def log_network(self, path_to_log: str):
        """
        Appends the textual description of the network to "network.txt"
        @param path_to_log: the directory containing the log file
        """
        # The context manager closes the file handle, which used to be left open
        with open(os.path.join(path_to_log, "network.txt"), 'a+') as log_file:
            log_file.write(str(self._network))

    def get_loss(self, pred: Tensor, label: Tensor) -> Tensor:
        """ @return: the value of the criterion for the given prediction and label """
        return self._criterion(pred, label)

    def train_mode(self):
        """ Switches the network to training mode """
        self._network = self._network.train()

    def evaluation_mode(self):
        """ Switches the network to evaluation mode """
        self._network = self._network.eval()

    def save(self, path_to_log: str):
        """ Saves the state dict of the network as "model.pth" in the given directory """
        torch.save(self._network.state_dict(), os.path.join(path_to_log, "model.pth"))

    def load(self, path_to_pretrained: str):
        """ Loads the state dict stored as "model.pth" in the given directory, mapped to the device of the model """
        path_to_model = os.path.join(path_to_pretrained, "model.pth")
        self._network.load_state_dict(torch.load(path_to_model, map_location=self._device))

    def set_optimizer(self, learning_rate: float, optimizer_type: str = "adam"):
        """
        Creates the optimizer for the parameters of the network
        @param learning_rate: the learning rate of the optimizer
        @param optimizer_type: either "adam" or "rmsprop" (any other value raises a KeyError)
        """
        optimizers_map = {"adam": torch.optim.Adam, "rmsprop": torch.optim.RMSprop}
        self._optimizer = optimizers_map[optimizer_type](self._network.parameters(), lr=learning_rate)
|
MiAlgo/classes/core/__pycache__/Evaluator.cpython-36.pyc
ADDED
Binary file (2.37 kB). View file
|
|
MiAlgo/classes/core/__pycache__/Evaluator.cpython-37.pyc
ADDED
Binary file (2.34 kB). View file
|
|
MiAlgo/classes/core/__pycache__/Evaluator.cpython-38.pyc
ADDED
Binary file (2.42 kB). View file
|
|
MiAlgo/classes/core/__pycache__/Loss.cpython-36.pyc
ADDED
Binary file (780 Bytes). View file
|
|
MiAlgo/classes/core/__pycache__/Loss.cpython-37.pyc
ADDED
Binary file (763 Bytes). View file
|
|
MiAlgo/classes/core/__pycache__/Loss.cpython-38.pyc
ADDED
Binary file (803 Bytes). View file
|
|
MiAlgo/classes/core/__pycache__/Model.cpython-36.pyc
ADDED
Binary file (2.35 kB). View file
|
|
MiAlgo/classes/core/__pycache__/Model.cpython-37.pyc
ADDED
Binary file (2.34 kB). View file
|
|
MiAlgo/classes/core/__pycache__/Model.cpython-38.pyc
ADDED
Binary file (2.39 kB). View file
|
|
MiAlgo/classes/data/ColorCheckerDataset.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import scipy.io
|
6 |
+
import torch
|
7 |
+
import torch.utils.data as data
|
8 |
+
|
9 |
+
from auxiliary.utils import normalize, bgr_to_rgb, linear_to_nonlinear, hwc_to_chw
|
10 |
+
from classes.data.DataAugmenter import DataAugmenter
|
11 |
+
|
12 |
+
|
13 |
+
class ColorCheckerDataset(data.Dataset):
    """ Dataset yielding (image, illuminant, file name) triplets for one fold of the data under "dataset" """

    def __init__(self, train: bool = True, folds_num: int = 1):
        """
        @param train: whether to use the training split (with data augmentation) or the test split
        @param folds_num: the index of the fold to be loaded from "folds.mat"
        """
        self.__train = train
        self.__da = DataAugmenter()

        path_to_folds = os.path.join("dataset", "folds.mat")
        path_to_metadata = os.path.join("dataset", "metadata.txt")
        self.__path_to_data = os.path.join("dataset", "preprocessed", "numpy_data")
        self.__path_to_label = os.path.join("dataset", "preprocessed", "numpy_labels")

        folds = scipy.io.loadmat(path_to_folds)
        img_idx = folds["tr_split" if self.__train else "te_split"][0][folds_num][0]

        # Read the metadata within a context manager so that the file handle is closed
        with open(path_to_metadata, 'r') as metadata_file:
            metadata = metadata_file.readlines()

        # The fold indices are shifted by one to address the metadata lines (presumably they are 1-based)
        self.__fold_data = [metadata[i - 1] for i in img_idx]

    def __getitem__(self, index: int) -> Tuple:
        """
        @param index: the position of the item within the fold
        @return: the image tensor (channels first), the illuminant tensor and the file name of the item
        """
        # The file name is the second space-separated field of the metadata line
        file_name = self.__fold_data[index].strip().split(' ')[1]
        img = np.array(np.load(os.path.join(self.__path_to_data, file_name + '.npy')), dtype='float32')
        illuminant = np.array(np.load(os.path.join(self.__path_to_label, file_name + '.npy')), dtype='float32')

        if self.__train:
            img, illuminant = self.__da.augment(img, illuminant)
        else:
            img = self.__da.crop(img)

        img = hwc_to_chw(linear_to_nonlinear(bgr_to_rgb(normalize(img))))

        # The copies give torch contiguous arrays instead of the views produced by the steps above
        img = torch.from_numpy(img.copy())
        illuminant = torch.from_numpy(illuminant.copy())

        if not self.__train:
            img = img.type(torch.FloatTensor)

        return img, illuminant, file_name

    def __len__(self) -> int:
        """ @return: the number of items in the fold """
        return len(self.__fold_data)
|
MiAlgo/classes/data/DataAugmenter.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from auxiliary.settings import TRAIN_IMG_H, TRAIN_IMG_W, TEST_IMG_H, TEST_IMG_W
|
8 |
+
from auxiliary.utils import rgb_to_bgr
|
9 |
+
|
10 |
+
|
11 |
+
class DataAugmenter:
    """ Random crop, rotation, flip and colour rescaling of the training images; resizing of the test images """

    def __init__(self):
        # Input Size of the fully-convolutional network (SqueezeNet)
        self.__train_size = (TRAIN_IMG_H, TRAIN_IMG_W)
        self.__test_size = (TEST_IMG_H, TEST_IMG_W)

        # Rotation angle (range, in degrees), patch scale (min and max) and colour rescaling (range)
        self.__angle = 60
        self.__scale = [0.1, 1.0]
        self.__color = 0.8

    @staticmethod
    def __rotate_image(image: np.ndarray, angle: float) -> np.ndarray:
        """
        Rotates an OpenCV 2 / NumPy image about its centre by the given angle (in degrees).
        The returned image is sized from the bounding box of the rotated corners
        """
        # NumPy shapes are (rows, columns) while OpenCV sizes are (width, height)
        image_size = (image.shape[1], image.shape[0])
        center = tuple(np.array(image_size) / 2)

        # 3x3 version of the OpenCV 2x3 rotation matrix, and its 2x2 part without translation
        rotation = np.vstack([cv2.getRotationMatrix2D(center, angle, 1.0), [0, 0, 1]])
        rotation_only = np.matrix(rotation[0:2, 0:2])

        half_w, half_h = image_size[0] * 0.5, image_size[1] * 0.5

        # Rotate the four corners, expressed relative to the centre of the image
        corners = [(-half_w, half_h), (half_w, half_h), (-half_w, -half_h), (half_w, -half_h)]
        rotated = [(np.array([corner_x, corner_y]) * rotation_only).A[0] for corner_x, corner_y in corners]

        # Bounding box of the rotated corners
        xs = [pt[0] for pt in rotated]
        ys = [pt[1] for pt in rotated]
        right_bound = max([x for x in xs if x > 0])
        left_bound = min([x for x in xs if x < 0])
        top_bound = max([y for y in ys if y > 0])
        bot_bound = min([y for y in ys if y < 0])

        new_w = int(abs(right_bound - left_bound))
        new_h = int(abs(top_bound - bot_bound))

        # Translation keeping the image centred in the new canvas
        translation = np.matrix([
            [1, 0, int(new_w * 0.5 - half_w)],
            [0, 1, int(new_h * 0.5 - half_h)],
            [0, 0, 1]
        ])

        # Combined rotation and translation, cut down to the 2x3 matrix expected by warpAffine
        affine = (np.matrix(translation) * np.matrix(rotation))[0:2, :]

        return cv2.warpAffine(image, affine, (new_w, new_h), flags=cv2.INTER_LINEAR)

    @staticmethod
    def __largest_rotated_rect(w: float, h: float, angle: float) -> tuple:
        """
        Given a rectangle of size w x h that has been rotated by 'angle' (in radians), computes the width and
        height of an axis-aligned rectangle within the rotated rectangle.

        Original JS code by 'Andri' and Magnus Hoff from Stack Overflow. Converted to Python by Aaron Snoswell
        """
        quadrant = int(math.floor(angle / (math.pi / 2))) & 3
        if (quadrant & 1) == 0:
            sign_alpha = angle
        else:
            sign_alpha = math.pi - angle
        alpha = (sign_alpha % math.pi + math.pi) % math.pi

        sin_alpha, cos_alpha = math.sin(alpha), math.cos(alpha)
        bb_w = w * cos_alpha + h * sin_alpha
        bb_h = w * sin_alpha + h * cos_alpha

        longest = h if (w < h) else w
        d = longest * cos_alpha

        # NOTE(review): the original expression evaluated atan2(bb_w, bb_w) on both sides of its `w < h` test,
        # so bb_h never enters gamma. Kept as is; confirm against the referenced Stack Overflow code
        gamma = math.atan2(bb_w, bb_w)
        delta = math.pi - alpha - gamma
        a = d * sin_alpha / math.sin(delta)

        y = a * math.cos(gamma)
        x = y * math.tan(gamma)

        return bb_w - 2 * x, bb_h - 2 * y

    def __crop_around_center(self, image: np.ndarray, width: float, height: float) -> np.ndarray:
        """
        Crops a NumPy / OpenCV 2 image to the given width and height around its centre point, then resizes the
        crop to the training size
        """
        img_w, img_h = image.shape[1], image.shape[0]
        center_x, center_y = int(img_w * 0.5), int(img_h * 0.5)

        # The crop cannot be larger than the image itself
        width = min(width, img_w)
        height = min(height, img_h)

        x1 = int(center_x - width * 0.5)
        x2 = int(center_x + width * 0.5)
        y1 = int(center_y - height * 0.5)
        y2 = int(center_y + height * 0.5)

        return cv2.resize(image[y1:y2, x1:x2], self.__train_size)

    def __rotate_and_crop(self, image: np.ndarray, angle: float) -> np.ndarray:
        """ Rotates the image by the given angle (in degrees) and crops the result around its centre """
        # NOTE(review): shape[:2] is (rows, columns) but it is passed on as (w, h); augment only feeds
        # square crops to this method
        dim0, dim1 = image.shape[:2]
        target_width, target_height = self.__largest_rotated_rect(dim0, dim1, math.radians(angle))
        rotated = self.__rotate_image(image, angle)
        return self.__crop_around_center(rotated, target_width, target_height)

    @staticmethod
    def __random_flip(img: np.ndarray) -> np.ndarray:
        """ Perform random left/right flip with probability 0.5 """
        flip = random.randint(0, 1)
        if flip:
            img = img[:, ::-1]
        return img.astype(np.float32)

    def augment(self, img: np.ndarray, illumination: np.ndarray) -> tuple:
        """
        Applies a random square crop, a random rotation, a random flip and a random per-channel gain
        @param img: the image to be augmented
        @param illumination: the illuminant of the image, rescaled with the same gains
        @return: the augmented image and the corresponding illuminant
        """
        # Side of the square crop: log-uniform scale of the shortest side, at least 10 pixels
        min_scale, max_scale = self.__scale[0], self.__scale[1]
        scale = math.exp(random.random() * math.log(max_scale / min_scale)) * min_scale
        shortest = min(img.shape[:2])
        s = min(max(int(round(shortest * scale)), 10), shortest)

        start_x = random.randrange(0, img.shape[0] - s + 1)
        start_y = random.randrange(0, img.shape[1] - s + 1)
        img = img[start_x:start_x + s, start_y:start_y + s]

        img = self.__rotate_and_crop(img, angle=(random.random() - 0.5) * self.__angle)
        img = self.__random_flip(img)

        # Diagonal matrix of random gains centred on 1, one per channel
        color_aug = np.zeros(shape=(3, 3))
        for c in range(3):
            color_aug[c, c] = 1 + random.random() * self.__color - 0.5 * self.__color
        img *= np.array([[[color_aug[c][c] for c in range(3)]]], dtype=np.float32)
        new_image = np.clip(img, 0, 65535)

        # Apply the same gains to the illuminant, whose order is reversed before and after
        new_illuminant = np.zeros_like(illumination)
        illumination = rgb_to_bgr(illumination)
        for row in range(3):
            for col in range(3):
                new_illuminant[row] += illumination[col] * color_aug[row, col]
        new_illuminant = rgb_to_bgr(np.clip(new_illuminant, 0.01, 100))

        return new_image, new_illuminant

    def crop(self, img: np.ndarray, scale: float = 0.5) -> np.ndarray:
        """ Resizes a test image with OpenCV, passing both the test size and the scale factor """
        return cv2.resize(img, self.__test_size, fx=scale, fy=scale)
|
MiAlgo/classes/data/__pycache__/ColorCheckerDataset.cpython-36.pyc
ADDED
Binary file (2.34 kB). View file
|
|
MiAlgo/classes/data/__pycache__/ColorCheckerDataset.cpython-37.pyc
ADDED
Binary file (2.31 kB). View file
|
|
MiAlgo/classes/data/__pycache__/ColorCheckerDataset.cpython-38.pyc
ADDED
Binary file (2.36 kB). View file
|
|
MiAlgo/classes/data/__pycache__/DataAugmenter.cpython-36.pyc
ADDED
Binary file (6.39 kB). View file
|
|
MiAlgo/classes/data/__pycache__/DataAugmenter.cpython-37.pyc
ADDED
Binary file (6.36 kB). View file
|
|
MiAlgo/classes/data/__pycache__/DataAugmenter.cpython-38.pyc
ADDED
Binary file (6.41 kB). View file
|
|
MiAlgo/classes/fc4/FC4.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn, Tensor
|
5 |
+
from torch.nn.functional import normalize
|
6 |
+
|
7 |
+
from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING
|
8 |
+
from classes.fc4.squeezenet.SqueezeNetLoader import SqueezeNetLoader
|
9 |
+
|
10 |
+
"""
|
11 |
+
FC4: Fully Convolutional Color Constancy with Confidence-weighted Pooling
|
12 |
+
* Original code: https://github.com/yuanming-hu/fc4
|
13 |
+
* Paper: https://www.microsoft.com/en-us/research/publication/fully-convolutional-color-constancy-confidence-weighted-pooling/
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
class FC4(torch.nn.Module):
    """ SqueezeNet backbone followed by two convolutions producing per-patch illuminant estimates """

    def __init__(self, squeezenet_version: float = 1.1):
        super().__init__()

        # SqueezeNet backbone (conv1-fire8) for extracting semantic features
        squeezenet = SqueezeNetLoader(squeezenet_version).load(pretrained=True)
        feature_layers = list(squeezenet.children())[0]
        self.backbone = nn.Sequential(*feature_layers[:12])

        # One extra output channel holds the confidence when confidence-weighted pooling is used
        out_channels = 4 if USE_CONFIDENCE_WEIGHTED_POOLING else 3

        # Final convolutional layers (conv6 and conv7) to extract semi-dense feature maps
        self.final_convs = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True),
            nn.Conv2d(512, 64, kernel_size=6, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, out_channels, kernel_size=1, stride=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x: Tensor) -> Union[tuple, Tensor]:
        """
        Estimate an RGB colour for the illuminant of the input image
        @param x: the image for which the colour of the illuminant has to be estimated
        @return: the colour estimate as a Tensor. If confidence-weighted pooling is used, the per-patch colour
        estimates and the confidence weights are returned as well (used for visualizations)
        """
        features = self.backbone(x)
        out = self.final_convs(features)

        # Summation pooling: sum the feature maps over both spatial dimensions
        if not USE_CONFIDENCE_WEIGHTED_POOLING:
            return normalize(torch.sum(torch.sum(out, 2), 2), dim=1)

        # Confidence-weighted pooling: the first 3 channels are the per-patch colour estimates,
        # the last channel is the confidence
        rgb = normalize(out[:, :3, :, :], dim=1)
        confidence = out[:, 3:4, :, :]

        weighted = rgb * confidence
        pred = normalize(torch.sum(torch.sum(weighted, 2), 2), dim=1)

        return pred, rgb, confidence
|
MiAlgo/classes/fc4/ModelFC4.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Union, Tuple
|
3 |
+
|
4 |
+
import torchvision.transforms.functional as F
|
5 |
+
from torch import Tensor
|
6 |
+
from torchvision.transforms import transforms
|
7 |
+
|
8 |
+
from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING
|
9 |
+
from auxiliary.utils import correct, rescale, scale
|
10 |
+
from classes.core.Model import Model
|
11 |
+
from classes.fc4.FC4 import FC4
|
12 |
+
|
13 |
+
|
14 |
+
class ModelFC4(Model):
    """ Model wrapping the FC4 network """

    def __init__(self):
        super().__init__()
        self._network = FC4().to(self._device)

    def predict(self, img: Tensor, return_steps: bool = False) -> Union[Tensor, Tuple]:
        """
        Performs inference on the input image using the FC4 method.
        @param img: the image for which an illuminant colour has to be estimated
        @param return_steps: whether or not to also return the per-patch estimates and confidence weights. The
        flag only has an effect when confidence-weighted pooling is active
        @return: the colour estimate as a Tensor. If "return_steps" is set to true, the per-patch colour estimates
        and the confidence weights are also returned (used for visualizations)
        """
        output = self._network(img)

        # Without confidence-weighted pooling the network returns the estimate only
        if not USE_CONFIDENCE_WEIGHTED_POOLING:
            return output

        pred, rgb, confidence = output
        if return_steps:
            return pred, rgb, confidence
        return pred

    def optimize(self, img: Tensor, label: Tensor) -> float:
        """
        Performs one optimization step on the given batch
        @param img: the input images
        @param label: the ground truth illuminants
        @return: the value of the loss for the batch
        """
        self._optimizer.zero_grad()
        estimate = self.predict(img)
        batch_loss = self.get_loss(estimate, label)
        batch_loss.backward()
        self._optimizer.step()
        return batch_loss.item()
|
43 |
+
|
44 |
+
|
MiAlgo/classes/fc4/__pycache__/FC4.cpython-36.pyc
ADDED
Binary file (2.05 kB). View file
|
|