from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
import os
import shutil
import tempfile


# Ground-truth annotations for the test split (HF_TOKEN must be set in the environment)
dataset = load_dataset("erceguder/histocan-test", token=os.environ["HF_TOKEN"])


COLOR_PALETTE = {
    'others': (0, 0, 0),
    't-g1': (0, 192, 0),
    't-g2': (255, 224, 32),
    't-g3': (255, 0, 0),
    'normal-mucosa': (0, 32, 255)
}
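# NOTE: dict order matters -- the position of a class in COLOR_PALETTE is the
# integer label used for it in the ground-truth masks (see `gt == i` below)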


def files_uploaded(paths):
    if len(paths) != 16:
        raise gr.Error("Exactly 16 segmentation masks are needed.")

    uploaded_file_names = [os.path.basename(p.name) for p in paths]
    for i in range(16):
        if f"test{i:04d}.png" not in uploaded_file_names:
            raise gr.Error(f"Expected mask test{i:04d}.png was not uploaded.")

def evaluate(paths):
    if paths is None:
        raise gr.Error("Upload segmentation masks first!")

    # Per-class confusion-matrix counts, accumulated over the whole test set
    metrics = {class_: {"tp": 0.0, "fp": 0.0, "tn": 0.0, "fn": 0.0} for class_ in COLOR_PALETTE}

    # Per-class recall, precision and F1 scores
    scores = {class_: {"recall": 0.0, "precision": 0.0, "f1": 0.0} for class_ in COLOR_PALETTE}

    # Move uploads into a scratch directory under their canonical file names
    # (shutil.move, unlike os.rename, also works across filesystems)
    tmpdir = tempfile.TemporaryDirectory()
    for path in paths:
        shutil.move(path.name, os.path.join(tmpdir.name, os.path.basename(path.name)))

    for item in dataset["test"]:
        pred_path = os.path.join(tmpdir.name, item["name"])
        pred = np.array(Image.open(pred_path))
        gt = np.array(item["annotation"])

        # Sanity checks: GT is a 2-D label map, the prediction an RGB image of matching size
        assert gt.ndim == 2
        assert pred.ndim == 3 and pred.shape[-1] == 3
        assert gt.shape == pred.shape[:-1]

        # Get predictions for all classes
        out = [(pred == color).all(axis=-1) for color in COLOR_PALETTE.values()]
        maps = np.stack(out)
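        # maps has shape (num_classes, H, W); maps[i] is True wherever the
        # prediction is painted with class i's palette color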

        # Calculate confusion matrix and metrics
        for i, class_ in enumerate(COLOR_PALETTE.keys()):
            class_pred = maps[i]
            class_gt = (gt == i)

            tp = np.sum(class_pred & class_gt)
            fp = np.sum(class_pred & ~class_gt)
            tn = np.sum(~class_pred & ~class_gt)
            fn = np.sum(~class_pred & class_gt)

            # Accumulate metrics for each class
            metrics[class_]['tp'] += tp
            metrics[class_]['fp'] += fp
            metrics[class_]['tn'] += tn
            metrics[class_]['fn'] += fn

    # Init mean recall, precision and F1 score
    mRecall = 0.0
    mPrecision = 0.0
    mF1 = 0.0

    # Calculate recall, precision and F1 (Dice) scores for each class,
    # treating an empty denominator as a score of zero
    def safe_div(num, denom):
        return num / denom if denom > 0 else 0.0

    for class_ in COLOR_PALETTE.keys():
        m = metrics[class_]
        scores[class_]["recall"] = safe_div(m["tp"], m["tp"] + m["fn"])
        scores[class_]["precision"] = safe_div(m["tp"], m["tp"] + m["fp"])
        scores[class_]["f1"] = safe_div(
            2 * scores[class_]["precision"] * scores[class_]["recall"],
            scores[class_]["precision"] + scores[class_]["recall"],
        )

        mRecall += scores[class_]["recall"]
        mPrecision += scores[class_]["precision"]
        mF1 += scores[class_]["f1"]

    # Calculate mean recall, precision and F1 score over all classes
    class_count = len(COLOR_PALETTE)
    mRecall /= class_count
    mPrecision /= class_count
    mF1 /= class_count

    tmpdir.cleanup()

    result = """
    <div align="center">

    # Results

    |           | Others | T-G1 | T-G2 | T-G3 | Normal mucosa |
    |-----------|--------|------|------|------|---------------|
    | Precision | {:.2f} |{:.2f}|{:.2f}|{:.2f}|    {:.2f}     |
    | Recall    | {:.2f} |{:.2f}|{:.2f}|{:.2f}|    {:.2f}     |
    | Dice      | {:.2f} |{:.2f}|{:.2f}|{:.2f}|    {:.2f}     |

    ### mPrecision: {:.4f}
    ### mRecall: {:.4f}
    ### mDice: {:.4f}

    </div>
    """

    result = result.format(
        scores["others"]["precision"],
        scores["t-g1"]["precision"],
        scores["t-g2"]["precision"],
        scores["t-g3"]["precision"],
        scores["normal-mucosa"]["precision"],
        scores["others"]["recall"],
        scores["t-g1"]["recall"],
        scores["t-g2"]["recall"],
        scores["t-g3"]["recall"],
        scores["normal-mucosa"]["recall"],
        scores["others"]["f1"],
        scores["t-g1"]["f1"],
        scores["t-g2"]["f1"],
        scores["t-g3"]["f1"],
        scores["normal-mucosa"]["f1"],
        mPrecision,
        mRecall,
        mF1
    )
    return gr.Markdown(value=result)
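

# A hypothetical helper (not part of the app) illustrating the mask format
# evaluate() expects: a (H, W) integer label map, with indices following the
# COLOR_PALETTE order, rendered as an RGB image
def labels_to_rgb(labels: np.ndarray) -> Image.Image:
    rgb = np.zeros((*labels.shape, 3), dtype=np.uint8)
    for i, color in enumerate(COLOR_PALETTE.values()):
        rgb[labels == i] = color
    return Image.fromarray(rgb)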


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("# Histocan Test Set Evaluation Page")
        files = gr.File(label="Upload your segmentation masks for the test set", file_count="multiple", file_types=["image"])
        run = gr.Button(value="Evaluate!")
        output = gr.Markdown(value="")

        files.upload(files_uploaded, files, [])
        run.click(evaluate, files, [output])

        demo.queue(max_size=1)
        demo.launch()