Spaces:
Running
Running
Update test.py
Browse files
test.py
CHANGED
@@ -1,103 +1,2 @@
|
|
1 |
|
2 |
-
import os
|
3 |
|
4 |
-
import numpy as np
|
5 |
-
import cv2
|
6 |
-
import pandas as pd
|
7 |
-
from glob import glob
|
8 |
-
from tqdm import tqdm
|
9 |
-
import tensorflow as tf
|
10 |
-
from tensorflow.keras.utils import CustomObjectScope
|
11 |
-
from sklearn.metrics import f1_score, jaccard_score, precision_score, recall_score
|
12 |
-
from sklearn.model_selection import train_test_split
|
13 |
-
from metrics import dice_loss, dice_coef
|
14 |
-
from train import load_dataset
|
15 |
-
from unet import build_unet
|
16 |
-
|
17 |
-
""" Global parameters """
|
18 |
-
H = 256
|
19 |
-
W = 256
|
20 |
-
|
21 |
-
""" Creating a directory """
|
22 |
-
def create_dir(path):
|
23 |
-
if not os.path.exists(path):
|
24 |
-
os.makedirs(path)
|
25 |
-
|
26 |
-
def save_results(image, mask, y_pred, save_image_path):
|
27 |
-
mask = np.expand_dims(mask, axis=-1)
|
28 |
-
mask = np.concatenate([mask, mask, mask], axis=-1)
|
29 |
-
|
30 |
-
y_pred = np.expand_dims(y_pred, axis=-1)
|
31 |
-
y_pred = np.concatenate([y_pred, y_pred, y_pred], axis=-1)
|
32 |
-
y_pred = y_pred * 255
|
33 |
-
|
34 |
-
line = np.ones((H, 10, 3)) * 255
|
35 |
-
|
36 |
-
cat_images = np.concatenate([image, line, mask, line, y_pred], axis=1)
|
37 |
-
cv2.imwrite(save_image_path, cat_images)
|
38 |
-
|
39 |
-
|
40 |
-
if __name__ == "__main__":
|
41 |
-
""" Seeding """
|
42 |
-
np.random.seed(42)
|
43 |
-
tf.random.set_seed(42)
|
44 |
-
|
45 |
-
""" Directory for storing files """
|
46 |
-
create_dir("results")
|
47 |
-
|
48 |
-
""" Load the model """
|
49 |
-
with CustomObjectScope({"dice_coef": dice_coef, "dice_loss": dice_loss}):
|
50 |
-
model = tf.keras.models.load_model(os.path.join("files", "model.h5"))
|
51 |
-
|
52 |
-
""" Dataset """
|
53 |
-
dataset_path = "/media/nikhil/Seagate Backup Plus Drive/ML_DATASET/brain_tumor_dataset/data"
|
54 |
-
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_dataset(dataset_path)
|
55 |
-
|
56 |
-
""" Prediction and Evaluation """
|
57 |
-
SCORE = []
|
58 |
-
for x, y in tqdm(zip(test_x, test_y), total=len(test_y)):
|
59 |
-
""" Extracting the name """
|
60 |
-
name = x.split("/")[-1]
|
61 |
-
|
62 |
-
""" Reading the image """
|
63 |
-
image = cv2.imread(x, cv2.IMREAD_COLOR) ## [H, w, 3]
|
64 |
-
image = cv2.resize(image, (W, H)) ## [H, w, 3]
|
65 |
-
x = image/255.0 ## [H, w, 3]
|
66 |
-
x = np.expand_dims(x, axis=0) ## [1, H, w, 3]
|
67 |
-
|
68 |
-
""" Reading the mask """
|
69 |
-
mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
|
70 |
-
mask = cv2.resize(mask, (W, H))
|
71 |
-
|
72 |
-
""" Prediction """
|
73 |
-
y_pred = model.predict(x, verbose=0)[0]
|
74 |
-
y_pred = np.squeeze(y_pred, axis=-1)
|
75 |
-
y_pred = y_pred >= 0.5
|
76 |
-
y_pred = y_pred.astype(np.int32)
|
77 |
-
|
78 |
-
""" Saving the prediction """
|
79 |
-
save_image_path = os.path.join("results", name)
|
80 |
-
save_results(image, mask, y_pred, save_image_path)
|
81 |
-
|
82 |
-
""" Flatten the array """
|
83 |
-
mask = mask/255.0
|
84 |
-
mask = (mask > 0.5).astype(np.int32).flatten()
|
85 |
-
y_pred = y_pred.flatten()
|
86 |
-
|
87 |
-
""" Calculating the metrics values """
|
88 |
-
f1_value = f1_score(mask, y_pred, labels=[0, 1], average="binary")
|
89 |
-
jac_value = jaccard_score(mask, y_pred, labels=[0, 1], average="binary")
|
90 |
-
recall_value = recall_score(mask, y_pred, labels=[0, 1], average="binary", zero_division=0)
|
91 |
-
precision_value = precision_score(mask, y_pred, labels=[0, 1], average="binary", zero_division=0)
|
92 |
-
SCORE.append([name, f1_value, jac_value, recall_value, precision_value])
|
93 |
-
|
94 |
-
""" Metrics values """
|
95 |
-
score = [s[1:]for s in SCORE]
|
96 |
-
score = np.mean(score, axis=0)
|
97 |
-
print(f"F1: {score[0]:0.5f}")
|
98 |
-
print(f"Jaccard: {score[1]:0.5f}")
|
99 |
-
print(f"Recall: {score[2]:0.5f}")
|
100 |
-
print(f"Precision: {score[3]:0.5f}")
|
101 |
-
|
102 |
-
df = pd.DataFrame(SCORE, columns=["Image", "F1", "Jaccard", "Recall", "Precision"])
|
103 |
-
df.to_csv("files/score.csv")
|
|
|
1 |
|
|
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|