Adding gender prediction
Browse files- app.py +18 -9
- models/__pycache__/inception.cpython-39.pyc +0 -0
- models/inception.py +40 -1
- models/weights/{model_weights_leadI.h5 → model_weights_leadI_age.h5} +0 -0
- models/weights/model_weights_leadI_gender.h5 +3 -0
- sample_data/ath_008.dat +0 -0
- sample_data/ath_008.hea +15 -0
- sample_data/ath_013.dat +0 -0
- sample_data/ath_013.hea +15 -0
app.py
CHANGED
@@ -24,14 +24,21 @@ def preprocess_ecg(ecg,fs):
|
|
24 |
pass
|
25 |
return ecg
|
26 |
|
27 |
-
def
|
28 |
cwd = os.getcwd()
|
29 |
-
weights = f"{cwd}/models/weights/
|
30 |
-
model =
|
31 |
model.load_weights(weights)
|
32 |
return model
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def run(header_file, data_file):
|
36 |
SAMPLE_FREQUENCY = 100
|
37 |
TIME = 10
|
@@ -43,9 +50,11 @@ def run(header_file, data_file):
|
|
43 |
shutil.copyfile(header_file.name, f"{demo_dir}/{hdr_basename}")
|
44 |
data, fs = load_data(f"{demo_dir}/{hdr_basename.split('.')[0]}")
|
45 |
ecg = preprocess_ecg(data,fs)
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
49 |
|
50 |
# Give credit to https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py for interface
|
51 |
|
@@ -59,14 +68,14 @@ with gr.Blocks() as demo:
|
|
59 |
header_file = gr.File(label = "header_file", file_types=[".hea"],)
|
60 |
data_file = gr.File(label = "data_file", file_types=[".dat"])
|
61 |
with gr.Column(scale=1):
|
62 |
-
output_age = gr.Textbox(label = "
|
63 |
-
|
64 |
#with gr.Row():
|
65 |
# ecg_graph = gr.Plot(label = "ECG Signal Visualisation")
|
66 |
with gr.Row():
|
67 |
predict_btn = gr.Button("Predict")
|
68 |
predict_btn.click(fn= run, inputs = [#pred_type,
|
69 |
-
header_file, data_file], outputs=[output_age])
|
70 |
with gr.Row():
|
71 |
gr.Examples(examples=[[f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],\
|
72 |
# [f"{CWD}/demo_data/test/00008_lr.hea", f"{CWD}/demo_data/test/00008_lr.dat", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."], \
|
|
|
24 |
pass
|
25 |
return ecg
|
26 |
|
27 |
+
def load_age_model(sample_frequency,recording_time, num_leads):
|
28 |
cwd = os.getcwd()
|
29 |
+
weights = f"{cwd}/models/weights/model_weights_leadI_age.h5"
|
30 |
+
model = build_age_model((sample_frequency * recording_time, num_leads), 1)
|
31 |
model.load_weights(weights)
|
32 |
return model
|
33 |
|
34 |
|
35 |
+
def load_gender_model(sample_frequency,recording_time, num_leads):
|
36 |
+
cwd = os.getcwd()
|
37 |
+
weights = f"{cwd}/models/weights/model_weights_leadI_gender.h5"
|
38 |
+
model = build_gender_model((sample_frequency * recording_time, num_leads), 1)
|
39 |
+
model.load_weights(weights)
|
40 |
+
return model
|
41 |
+
|
42 |
def run(header_file, data_file):
|
43 |
SAMPLE_FREQUENCY = 100
|
44 |
TIME = 10
|
|
|
50 |
shutil.copyfile(header_file.name, f"{demo_dir}/{hdr_basename}")
|
51 |
data, fs = load_data(f"{demo_dir}/{hdr_basename.split('.')[0]}")
|
52 |
ecg = preprocess_ecg(data,fs)
|
53 |
+
age_model = load_age_model(sample_frequency=SAMPLE_FREQUENCY,recording_time=TIME,num_leads=NUM_LEADS)
|
54 |
+
gender_model = load_gender_model(sample_frequency=SAMPLE_FREQUENCY,recording_time=TIME,num_leads=NUM_LEADS)
|
55 |
+
age_estimate = age_model.predict(np.expand_dims(ecg,0)).ravel()[0]
|
56 |
+
gender_prediction = gender_model.predict(np.expand_dims(ecg,0)).ravel()[0]
|
57 |
+
return str(round(age_estimate,1)), {"Male": 1- gender_prediction, "Female": gender_prediction}
|
58 |
|
59 |
# Give credit to https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py for interface
|
60 |
|
|
|
68 |
header_file = gr.File(label = "header_file", file_types=[".hea"],)
|
69 |
data_file = gr.File(label = "data_file", file_types=[".dat"])
|
70 |
with gr.Column(scale=1):
|
71 |
+
output_age = gr.Textbox(label = "Estimated age")
|
72 |
+
output_gender = gr.Label( label = "Predicted gender")
|
73 |
#with gr.Row():
|
74 |
# ecg_graph = gr.Plot(label = "ECG Signal Visualisation")
|
75 |
with gr.Row():
|
76 |
predict_btn = gr.Button("Predict")
|
77 |
predict_btn.click(fn= run, inputs = [#pred_type,
|
78 |
+
header_file, data_file], outputs=[output_age,output_gender])
|
79 |
with gr.Row():
|
80 |
gr.Examples(examples=[[f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],\
|
81 |
# [f"{CWD}/demo_data/test/00008_lr.hea", f"{CWD}/demo_data/test/00008_lr.dat", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."], \
|
models/__pycache__/inception.cpython-39.pyc
CHANGED
Binary files a/models/__pycache__/inception.cpython-39.pyc and b/models/__pycache__/inception.cpython-39.pyc differ
|
|
models/inception.py
CHANGED
@@ -70,7 +70,7 @@ def _shortcut_layer(input_tensor, out_tensor):
|
|
70 |
return x
|
71 |
|
72 |
|
73 |
-
def
|
74 |
input_shape: Tuple[int, int],
|
75 |
nb_classes: int,
|
76 |
depth: int = 6,
|
@@ -105,4 +105,43 @@ def build_model(
|
|
105 |
metrics=[tf.keras.metrics.MeanSquaredError()],
|
106 |
)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
return model
|
|
|
70 |
return x
|
71 |
|
72 |
|
73 |
+
def build_age_model(
|
74 |
input_shape: Tuple[int, int],
|
75 |
nb_classes: int,
|
76 |
depth: int = 6,
|
|
|
105 |
metrics=[tf.keras.metrics.MeanSquaredError()],
|
106 |
)
|
107 |
|
108 |
+
return model
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
def build_gender_model(
|
113 |
+
input_shape: Tuple[int, int],
|
114 |
+
nb_classes: int,
|
115 |
+
depth: int = 6,
|
116 |
+
use_residual: bool = True,
|
117 |
+
)-> tf.keras.models.Model:
|
118 |
+
"""
|
119 |
+
Model proposed by HI Fawas et al 2019 "Finding AlexNet for Time Series Classification - InceptionTime"
|
120 |
+
"""
|
121 |
+
input_layer = tf.keras.layers.Input(input_shape)
|
122 |
+
|
123 |
+
x = input_layer
|
124 |
+
input_res = input_layer
|
125 |
+
|
126 |
+
for d in range(depth):
|
127 |
+
|
128 |
+
x = _inception_module(x)
|
129 |
+
|
130 |
+
if use_residual and d % 3 == 2:
|
131 |
+
x = _shortcut_layer(input_res, x)
|
132 |
+
input_res = x
|
133 |
+
|
134 |
+
gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
|
135 |
+
|
136 |
+
output_layer = tf.keras.layers.Dense(units=nb_classes, activation="sigmoid")(
|
137 |
+
gap_layer
|
138 |
+
)
|
139 |
+
|
140 |
+
model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
|
141 |
+
model.compile(
|
142 |
+
loss=tf.keras.losses.BinaryCrossentropy(),
|
143 |
+
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
|
144 |
+
metrics=[tf.keras.metrics.AUC(curve='ROC',name="AUROC")],
|
145 |
+
)
|
146 |
+
|
147 |
return model
|
models/weights/{model_weights_leadI.h5 → model_weights_leadI_age.h5}
RENAMED
File without changes
|
models/weights/model_weights_leadI_gender.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac1616b91eb8f740aaae4c44c08948b3b3d1469e628eb3c145b7e9670264f023
|
3 |
+
size 1833768
|
sample_data/ath_008.dat
ADDED
Binary file (120 kB). View file
|
|
sample_data/ath_008.hea
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ath_008 12 500 5000
|
2 |
+
ath_008.dat 16 50000/mV 16 0 -11527 54167 0 I
|
3 |
+
ath_008.dat 16 50000/mV 16 0 -18070 20408 0 II
|
4 |
+
ath_008.dat 16 50000/mV 16 0 -15660 38252 0 III
|
5 |
+
ath_008.dat 16 50000/mV 16 0 16820 59491 0 AVR
|
6 |
+
ath_008.dat 16 50000/mV 16 0 4368 62574 0 AVL
|
7 |
+
ath_008.dat 16 50000/mV 16 0 -17613 24434 0 AVF
|
8 |
+
ath_008.dat 16 50000/mV 16 0 11148 20294 0 V1
|
9 |
+
ath_008.dat 16 50000/mV 16 0 10557 25409 0 V2
|
10 |
+
ath_008.dat 16 50000/mV 16 0 8134 4135 0 V3
|
11 |
+
ath_008.dat 16 50000/mV 16 0 1343 20358 0 V4
|
12 |
+
ath_008.dat 16 50000/mV 16 0 -12126 42898 0 V5
|
13 |
+
ath_008.dat 16 50000/mV 16 0 -22817 49296 0 V6
|
14 |
+
#SL12: Normal sinus rhythm, RSR' or QR pattern in V1 suggests right ventricular conduction delay, Borderline ECG
|
15 |
+
#C: Normal sinus rhythm, Incomplete right bundle branch block, Normal ECG
|
sample_data/ath_013.dat
ADDED
Binary file (120 kB). View file
|
|
sample_data/ath_013.hea
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ath_013 12 500 5000
|
2 |
+
ath_013.dat 16 50000/mV 16 0 5460 51465 0 I
|
3 |
+
ath_013.dat 16 50000/mV 16 0 -18724 34405 0 II
|
4 |
+
ath_013.dat 16 50000/mV 16 0 -28034 29205 0 III
|
5 |
+
ath_013.dat 16 50000/mV 16 0 9138 13588 0 AVR
|
6 |
+
ath_013.dat 16 50000/mV 16 0 19917 20645 0 AVL
|
7 |
+
ath_013.dat 16 50000/mV 16 0 -24408 31082 0 AVF
|
8 |
+
ath_013.dat 16 50000/mV 16 0 21564 56921 0 V1
|
9 |
+
ath_013.dat 16 50000/mV 16 0 20969 29226 0 V2
|
10 |
+
ath_013.dat 16 50000/mV 16 0 18256 12273 0 V3
|
11 |
+
ath_013.dat 16 50000/mV 16 0 -456 37253 0 V4
|
12 |
+
ath_013.dat 16 50000/mV 16 0 -7056 1140 0 V5
|
13 |
+
ath_013.dat 16 50000/mV 16 0 -14084 59880 0 V6
|
14 |
+
#SL12: Marked sinus bradycardia, Right axis deviation, Abnormal ECG
|
15 |
+
#C: Sinus bradycardia, Normal ECG
|