print("Importing standard...")
import subprocess
import shutil
from pathlib import Path

print("Importing external...")
import torch
import numpy as np
from PIL import Image

REDUCTION = "pca"
if REDUCTION == "umap":
    from umap import UMAP
elif REDUCTION == "tsne":
    from sklearn.manifold import TSNE
elif REDUCTION == "pca":
    from sklearn.decomposition import PCA


def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1)
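
# Example (hand-checked values): symlog is sign-preserving and compresses
# large magnitudes towards zero, e.g.
#   symlog(torch.tensor([-100.0, 0.0, 100.0]))
#   -> tensor([-4.6151,  0.0000,  4.6151])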


def preprocess_masks_features(masks, features):
    """Flatten masks and features into mutually broadcastable shapes."""
    # Get shapes right
    B, M, H, W = masks.shape
    Bf, F, Hf, Wf = features.shape
    masks = masks.reshape(B, M, 1, H * W)
    features = features.reshape(B, 1, F, H * W)
    mask_areas = masks.sum(dim=3)  # B, M, 1
    # The following assertions should hold; they are commented out for speed:
    # assert H == Hf and W == Wf and B == Bf
    # assert masks.dtype == torch.bool
    # assert (mask_areas > 0).all(), "you shouldn't have empty masks"
    # Note: M is not actually reduced for empty masks; mask_areas only feeds
    # the assertion above.

    # output shapes
    #   masks:    B, M, 1, H*W
    #   features: B, 1, F, H*W
    return masks, features, M, B, H, W, F
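
# A minimal downstream sketch (an assumption, not part of the original API):
# the broadcastable shapes returned above give per-mask mean features directly:
#   masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
#   mask_feats = (masks * features).sum(dim=3) / masks.sum(dim=3).clamp(min=1)
#   # mask_feats: B, M, F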


def get_row_col(H, W, device):
    # get position of pixels in [0, 1]
    row = torch.linspace(0, 1, H, device=device)
    col = torch.linspace(0, 1, W, device=device)
    return row, col
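
# Example (a usage sketch, not part of the original module): the two vectors
# expand into a 2, H, W positional grid:
#   row, col = get_row_col(4, 4, "cpu")
#   grid = torch.stack(torch.meshgrid(row, col, indexing="ij"), dim=0)  # 2, H, W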


def get_current_git_commit():
    try:
        # Run the git command to get the current commit hash
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
        # Decode from bytes to a string
        return commit_hash.decode("utf-8")
    except (subprocess.CalledProcessError, FileNotFoundError):
        # command failed (e.g., not a git repository) or git is not installed
        print("An error occurred while trying to retrieve the git commit hash.")
        return None


def clean_dir(dirname):
    """Removes all directories in dirname that don't have a done.txt file"""
    dstdir = Path(dirname)
    dstdir.mkdir(exist_ok=True, parents=True)
    for f in dstdir.iterdir():
        # if the directory doesn't have a done.txt file remove it
        if f.is_dir() and not (f / "done.txt").exists():
            shutil.rmtree(f)


def save_tensor_as_image(tensor, dstfile, global_step):
    dstfile = Path(dstfile)
    dstfile = (dstfile.parent / (dstfile.stem + "_" + str(global_step))).with_suffix(
        ".jpg"
    )
    save(tensor, str(dstfile))


def minmaxnorm(x):
    # assumes x is not constant; x.max() == x.min() would divide by zero
    return (x - x.min()) / (x.max() - x.min())


def save(tensor, name, channel_offset=0):
    tensor = to_img(tensor, channel_offset=channel_offset)
    Image.fromarray(tensor).save(name)


def to_img(tensor, channel_offset=0):
    """Convert a C, H, W tensor to a uint8 numpy image."""
    tensor = minmaxnorm(tensor)
    tensor = (tensor * 255).to(torch.uint8)
    C = tensor.shape[0]
    if C == 1:
        tensor = tensor[0]  # grayscale: H, W
    elif C == 2:
        # map the two channels to red and blue, leaving green empty
        tensor = torch.stack([tensor[0], torch.zeros_like(tensor[0]), tensor[1]], dim=0)
        tensor = tensor.permute(1, 2, 0)
    elif C >= 3:
        # keep a 3-channel slice starting at channel_offset
        tensor = tensor[channel_offset : channel_offset + 3]
        tensor = tensor.permute(1, 2, 0)
    tensor = tensor.cpu().numpy()
    return tensor


def log_input_output(
    name,
    x,
    y_hat,
    global_step,
    img_dstdir,
    out_dstdir,
    reduce_dim=True,
    reduction=REDUCTION,
    resample_size=20000,
):
    # y_hat arrives as B, 1, F, H, W; drop the singleton dim -> B, F, H, W
    y_hat = y_hat.reshape(
        y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
    )
    if reduce_dim and y_hat.shape[1] >= 3:
        if reduction == "umap":
            reducer = UMAP(n_components=3)
        elif reduction == "tsne":
            # note: sklearn's TSNE has no transform() method; this code path
            # only works with reducers that implement fit/transform separately
            reducer = TSNE(n_components=3)
        elif reduction == "pca":
            reducer = PCA(n_components=3)
        else:
            raise ValueError(f"unknown reduction: {reduction}")
        np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
        np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, B*H*W
        np_y_hat = np_y_hat.T  # B*H*W, F
        # subsample pixels for fitting; the max() guards against a zero step
        # when there are fewer than resample_size pixels
        step = max(1, np_y_hat.shape[0] // resample_size)
        sampled_pixels = np_y_hat[::step]
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        reducer.transform(np_y_hat[:10])  # warm-up call (triggers UMAP's numba JIT)
        np_y_hat = reducer.transform(np_y_hat)  # B*H*W, 3
        # revert back to original shape
        y_hat2 = (
            torch.from_numpy(
                np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3])
            )
            .to(y_hat.device)
            .permute(1, 0, 2, 3)
        )
        print("done" + " " * 30, end="\r")
    else:
        y_hat2 = y_hat

    for i in range(min(len(x), 8)):
        save_tensor_as_image(
            x[i],
            img_dstdir / f"input_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )
        for c in range(y_hat.shape[1]):
            save_tensor_as_image(
                y_hat[i, c : c + 1],
                out_dstdir / f"pred_channel_{name}_{str(i).zfill(2)}_{c}",
                global_step=global_step,
            )
        # log color image (reduced channels, then raw first-3 channels)
        assert len(y_hat2.shape) == 4, "should be B, F, H, W"
        if reduce_dim:
            save_tensor_as_image(
                y_hat2[i][:3],
                out_dstdir / f"pred_reduced_{name}_{str(i).zfill(2)}",
                global_step=global_step,
            )
        save_tensor_as_image(
            y_hat[i][:3],
            out_dstdir / f"pred_colorchs_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )
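
# Usage sketch (hypothetical shapes and paths; assumes the directories exist
# or are created first):
#   img_dir, out_dir = Path("imgs"), Path("outs")
#   img_dir.mkdir(exist_ok=True); out_dir.mkdir(exist_ok=True)
#   x = torch.rand(2, 3, 64, 64)         # input images: B, C, H, W
#   y_hat = torch.rand(2, 1, 8, 64, 64)  # predictions:  B, 1, F, H, W
#   log_input_output("val", x, y_hat, 0, img_dir, out_dir)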


def check_for_nan(loss, model, batch):
    try:
        assert not torch.isnan(loss).any()
    except AssertionError as e:
        # print information useful for debugging
        # does the batch contain nan?
        print("img batch contains nan?", torch.isnan(batch[0]).any())
        print("mask batch contains nan?", torch.isnan(batch[1]).any())
        # do the model weights contain nan?
        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                print(name, "contains nan")
        # does the output contain nan?
        print("output contains nan?", torch.isnan(model(batch[0])).any())
        # now raise the error
        raise e


def calculate_iou(pred, label):
    """IoU between binary tensors; returns 0 when the union is empty."""
    intersection = ((label == 1) & (pred == 1)).sum()
    union = ((label == 1) | (pred == 1)).sum()
    if not union:
        return 0
    return intersection.item() / union.item()
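
# Example (hand-checked): one overlapping pixel over a union of three
#   pred = torch.tensor([1, 1, 0, 0])
#   label = torch.tensor([1, 0, 1, 0])
#   calculate_iou(pred, label)  # -> 1/3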


def load_from_ckpt(net, ckpt_path, strict=True):
    """Load network weights; returns net unchanged if ckpt_path is empty or missing."""
    if ckpt_path and Path(ckpt_path).exists():
        ckpt = torch.load(ckpt_path, map_location="cpu")
        if "MODEL_STATE" in ckpt:
            ckpt = ckpt["MODEL_STATE"]
        elif "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        net.load_state_dict(ckpt, strict=strict)
        print("Loaded checkpoint from", ckpt_path)
    return net
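
# Usage sketch (hypothetical model class and checkpoint path):
#   net = MyModel()
#   net = load_from_ckpt(net, "checkpoints/last.ckpt", strict=False)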