adding data loader, model loader and run func
Browse files- app.py +32 -0
- models/inception.py +108 -0
- models/weights/model_weights_leadI.h5 +3 -0
- sample_data/ath_001.dat +0 -0
- sample_data/ath_001.hea +16 -0
app.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import wfdb
|
3 |
+
import numpy as np
|
4 |
+
import gradio as gr
|
5 |
+
from models.inception import *
|
6 |
+
|
7 |
+
|
8 |
+
def load_data():
|
9 |
+
cwd = os.getcwd()
|
10 |
+
sample_data = f"{cwd}/sample_data/ath_001"
|
11 |
+
ecg = wfdb.rdsamp(sample_data)
|
12 |
+
return np.asarray(ecg)
|
13 |
+
|
14 |
+
def load_model(sample_frequency,recording_time, num_leads):
|
15 |
+
cwd = os.getcwd()
|
16 |
+
weights = f"{cwd}/models/weights/model_weights_leadI.h5"
|
17 |
+
model = build_model((sample_frequency * recording_time, num_leads), 1)
|
18 |
+
model.load_weights(weights)
|
19 |
+
return model
|
20 |
+
|
21 |
+
|
22 |
+
def run(ecg):
|
23 |
+
SAMPLE_FREQUENCY = 100
|
24 |
+
TIME = 10
|
25 |
+
NUM_LEADS = 1
|
26 |
+
data = load_data()
|
27 |
+
model = load_model(sample_frequency=SAMPLE_FREQUENCY,recording_time=TIME,num_leads=NUM_LEADS)
|
28 |
+
predicion = model.predict(data)
|
29 |
+
return predicion
|
30 |
+
|
31 |
+
iface = gr.Interface(fn=run, inputs="text", outputs="text")
|
32 |
+
iface.launch()
|
models/inception.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
def _inception_module(
|
5 |
+
input_tensor,
|
6 |
+
stride=1,
|
7 |
+
activation="linear",
|
8 |
+
use_bottleneck=True,
|
9 |
+
kernel_size=40,
|
10 |
+
bottleneck_size=32,
|
11 |
+
nb_filters=32,
|
12 |
+
):
|
13 |
+
|
14 |
+
if use_bottleneck and int(input_tensor.shape[-1]) > 1:
|
15 |
+
input_inception = tf.keras.layers.Conv1D(
|
16 |
+
filters=bottleneck_size,
|
17 |
+
kernel_size=1,
|
18 |
+
padding="same",
|
19 |
+
activation=activation,
|
20 |
+
use_bias=False,
|
21 |
+
)(input_tensor)
|
22 |
+
else:
|
23 |
+
input_inception = input_tensor
|
24 |
+
|
25 |
+
# kernel_size_s = [3, 5, 8, 11, 17]
|
26 |
+
kernel_size_s = [kernel_size // (2**i) for i in range(3)]
|
27 |
+
|
28 |
+
conv_list = []
|
29 |
+
|
30 |
+
for i in range(len(kernel_size_s)):
|
31 |
+
conv_list.append(
|
32 |
+
tf.keras.layers.Conv1D(
|
33 |
+
filters=nb_filters,
|
34 |
+
kernel_size=kernel_size_s[i],
|
35 |
+
strides=stride,
|
36 |
+
padding="same",
|
37 |
+
activation=activation,
|
38 |
+
use_bias=False,
|
39 |
+
)(input_inception)
|
40 |
+
)
|
41 |
+
|
42 |
+
max_pool_1 = tf.keras.layers.MaxPool1D(pool_size=3, strides=stride, padding="same")(
|
43 |
+
input_tensor
|
44 |
+
)
|
45 |
+
|
46 |
+
conv_6 = tf.keras.layers.Conv1D(
|
47 |
+
filters=nb_filters,
|
48 |
+
kernel_size=1,
|
49 |
+
padding="same",
|
50 |
+
activation=activation,
|
51 |
+
use_bias=False,
|
52 |
+
)(max_pool_1)
|
53 |
+
|
54 |
+
conv_list.append(conv_6)
|
55 |
+
|
56 |
+
x = tf.keras.layers.Concatenate(axis=2)(conv_list)
|
57 |
+
x = tf.keras.layers.BatchNormalization()(x)
|
58 |
+
x = tf.keras.layers.Activation(activation="relu")(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
def _shortcut_layer(input_tensor, out_tensor):
|
63 |
+
shortcut_y = tf.keras.layers.Conv1D(
|
64 |
+
filters=int(out_tensor.shape[-1]), kernel_size=1, padding="same", use_bias=False
|
65 |
+
)(input_tensor)
|
66 |
+
shortcut_y = tf.keras.layers.BatchNormalization()(shortcut_y)
|
67 |
+
|
68 |
+
x = tf.keras.layers.Add()([shortcut_y, out_tensor])
|
69 |
+
x = tf.keras.layers.Activation("relu")(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
|
73 |
+
def build_model(
|
74 |
+
input_shape: Tuple[int, int],
|
75 |
+
nb_classes: int,
|
76 |
+
depth: int = 6,
|
77 |
+
use_residual: bool = True,
|
78 |
+
)-> tf.keras.models.Model:
|
79 |
+
"""
|
80 |
+
Model proposed by HI Fawas et al 2019 "Finding AlexNet for Time Series Classification - InceptionTime"
|
81 |
+
"""
|
82 |
+
input_layer = tf.keras.layers.Input(input_shape)
|
83 |
+
|
84 |
+
x = input_layer
|
85 |
+
input_res = input_layer
|
86 |
+
|
87 |
+
for d in range(depth):
|
88 |
+
|
89 |
+
x = _inception_module(x)
|
90 |
+
|
91 |
+
if use_residual and d % 3 == 2:
|
92 |
+
x = _shortcut_layer(input_res, x)
|
93 |
+
input_res = x
|
94 |
+
|
95 |
+
gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
|
96 |
+
|
97 |
+
output_layer = tf.keras.layers.Dense(units=nb_classes, activation="linear")(
|
98 |
+
gap_layer
|
99 |
+
)
|
100 |
+
|
101 |
+
model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
|
102 |
+
model.compile(
|
103 |
+
loss=tf.keras.losses.MeanAbsoluteError(),
|
104 |
+
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
|
105 |
+
metrics=[tf.keras.metrics.MeanSquaredError()],
|
106 |
+
)
|
107 |
+
|
108 |
+
return model
|
models/weights/model_weights_leadI.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:787e54a1a2338d7649e1424615f8913a90d5bd184ae76f2fbbc30b15e497c964
|
3 |
+
size 1835344
|
sample_data/ath_001.dat
ADDED
Binary file (120 kB). View file
|
|
sample_data/ath_001.hea
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ath_001 12 500 5000
|
2 |
+
ath_001.dat 16 50000/mV 16 0 10251 49595 0 I
|
3 |
+
ath_001.dat 16 50000/mV 16 0 -1096 35223 0 II
|
4 |
+
ath_001.dat 16 50000/mV 16 0 -10267 60826 0 III
|
5 |
+
ath_001.dat 16 50000/mV 16 0 -3724 3505 0 AVR
|
6 |
+
ath_001.dat 16 50000/mV 16 0 9391 26379 0 AVL
|
7 |
+
ath_001.dat 16 50000/mV 16 0 -5395 57481 0 AVF
|
8 |
+
ath_001.dat 16 50000/mV 16 0 13580 61759 0 V1
|
9 |
+
ath_001.dat 16 50000/mV 16 0 11410 33501 0 V2
|
10 |
+
ath_001.dat 16 50000/mV 16 0 14721 52508 0 V3
|
11 |
+
ath_001.dat 16 50000/mV 16 0 16103 51083 0 V4
|
12 |
+
ath_001.dat 16 50000/mV 16 0 6662 44197 0 V5
|
13 |
+
ath_001.dat 16 50000/mV 16 0 -3806 11333 0 V6
|
14 |
+
#SL12: Sinus bradycardia with marked sinus arrhythmia, Right axis deviation, Borderline ECG
|
15 |
+
#C: Sinus arrhythmia, Normal ECG
|
16 |
+
|