|
import functools
import os
import shutil

import gradio as gr
import numpy as np
import wfdb

from models.inception import *

# Kept after the wildcard import so a same-named export from it cannot shadow scipy's resample.
from scipy.signal import resample
|
|
|
|
|
def load_data(sample_data):
    """Read a WFDB record and return its lead-I signal and sampling frequency.

    Args:
        sample_data: Path to the record *without* file extension, as accepted
            by ``wfdb.rdsamp`` (e.g. ``".../sample_data/ath_001"``).

    Returns:
        ``(lead_I, sample_frequency)``: the lead-I column of the record and
        the record's sampling frequency (``meta_data["fs"]``).
    """
    ecg, meta_data = wfdb.rdsamp(sample_data)

    # Select lead I by its channel name instead of assuming it is the first
    # column; fall back to column 0 (the previous behaviour) when no channel
    # is named "I" or the header carries no channel names.
    sig_names = meta_data.get("sig_name") or []
    normalized = [str(name).strip().upper() for name in sig_names]
    lead_index = normalized.index("I") if "I" in normalized else 0

    lead_I = ecg[:, lead_index]
    sample_frequency = meta_data["fs"]

    return lead_I, sample_frequency
|
|
|
def preprocess_ecg(ecg, fs, *, target_fs=100, target_len=1000):
    """Resample a single-lead ECG and force it to a fixed number of samples.

    The networks used by ``run`` are built for an input of
    ``SAMPLE_FREQUENCY * TIME`` = 100 Hz * 10 s = 1000 samples, so the
    returned signal always has exactly ``target_len`` samples.

    Args:
        ecg: Signal to preprocess (assumed 1-D; length is taken with ``len``).
        fs: Sampling frequency of ``ecg``.
        target_fs: Sampling frequency to resample to (default 100).
        target_len: Number of samples to return (default 1000).

    Returns:
        A NumPy array of exactly ``target_len`` samples. Longer signals keep
        only their first ``target_len`` samples; shorter ones are zero-padded
        at the end.
    """
    ecg = np.asarray(ecg)

    if fs != target_fs:
        # FFT-based resampling to the sample count the new rate implies.
        ecg = resample(ecg, int(len(ecg) * (target_fs / fs)))

    if len(ecg) > target_len:
        ecg = ecg[:target_len]
    elif len(ecg) < target_len:
        # Short recordings used to be returned as-is and then failed inside the
        # fixed-size model; pad with zeros so the shape always matches.
        ecg = np.pad(ecg, (0, target_len - len(ecg)))

    return ecg
|
|
|
@functools.lru_cache(maxsize=4)
def load_age_model(sample_frequency, recording_time, num_leads):
    """Build the lead-I age model and load its pretrained weights.

    The result is cached per ``(sample_frequency, recording_time, num_leads)``
    so repeated predictions reuse the built network instead of rebuilding it
    and re-reading the ``.h5`` file on every call. The cached model object is
    therefore shared between callers.

    Args:
        sample_frequency: Input sampling frequency the network is sized for.
        recording_time: Input duration; the input length is
            ``sample_frequency * recording_time`` samples.
        num_leads: Number of input channels.

    Returns:
        The model returned by ``build_age_model`` with weights loaded.
    """
    # Weights are resolved against the current working directory, so the app
    # is expected to be launched from the directory that contains models/.
    weights = os.path.join(os.getcwd(), "models", "weights", "model_weights_leadI_age.h5")

    # NOTE(review): the trailing ``1`` is interpreted by build_age_model
    # (presumably the number of outputs) — confirm against models/inception.
    model = build_age_model((sample_frequency * recording_time, num_leads), 1)

    model.load_weights(weights)

    return model
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def load_gender_model(sample_frequency, recording_time, num_leads):
    """Build the lead-I gender model and load its pretrained weights.

    The result is cached per ``(sample_frequency, recording_time, num_leads)``
    so repeated predictions reuse the built network instead of rebuilding it
    and re-reading the ``.h5`` file on every call. The cached model object is
    therefore shared between callers.

    Args:
        sample_frequency: Input sampling frequency the network is sized for.
        recording_time: Input duration; the input length is
            ``sample_frequency * recording_time`` samples.
        num_leads: Number of input channels.

    Returns:
        The model returned by ``build_gender_model`` with weights loaded.
    """
    # Weights are resolved against the current working directory, so the app
    # is expected to be launched from the directory that contains models/.
    weights = os.path.join(os.getcwd(), "models", "weights", "model_weights_leadI_gender.h5")

    # NOTE(review): the trailing ``1`` is interpreted by build_gender_model
    # (presumably the number of outputs) — confirm against models/inception.
    model = build_gender_model((sample_frequency * recording_time, num_leads), 1)

    model.load_weights(weights)

    return model
|
|
|
def run(header_file, data_file):
    """Gradio callback: estimate age and gender from an uploaded WFDB record.

    Args:
        header_file: Uploaded ``.hea`` file object; only its ``.name`` path is used.
        data_file: Uploaded ``.dat`` file object; only its ``.name`` path is used.

    Returns:
        ``(age_text, confidences)``: the age estimate rounded to one decimal as
        a string, and ``{"Male": 1 - p, "Female": p}`` where ``p`` is the
        gender model's output.

    Raises:
        gr.Error: if either file has not been uploaded.
    """
    import tempfile  # only this callback needs a scratch directory

    SAMPLE_FREQUENCY = 100
    TIME = 10
    NUM_LEADS = 1

    if header_file is None or data_file is None:
        raise gr.Error("Please upload both a .hea header file and a .dat data file.")

    # load_data receives the record path without extension, and the data file
    # is looked up next to the header, so both are copied into one directory
    # under their original basenames. A private temporary directory (instead
    # of sample_data/) avoids overwriting the bundled samples, avoids clashes
    # between concurrent requests, avoids shutil.SameFileError when the upload
    # already lives in sample_data/, and is removed automatically.
    with tempfile.TemporaryDirectory() as work_dir:
        for uploaded in (header_file, data_file):
            target = os.path.join(work_dir, os.path.basename(uploaded.name))
            shutil.copyfile(uploaded.name, target)

        # splitext keeps dotted record names (e.g. "rec.v2") intact.
        record_name = os.path.splitext(os.path.basename(header_file.name))[0]
        data, fs = load_data(os.path.join(work_dir, record_name))

    ecg = preprocess_ecg(data, fs)
    batch = np.expand_dims(ecg, 0)  # prepend a batch axis of size 1

    age_model = load_age_model(sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS)
    gender_model = load_gender_model(sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS)

    # Convert to plain Python floats: NumPy scalars are not serialisable by
    # the stdlib json module, and the confidences are sent to the UI.
    age_estimate = float(age_model.predict(batch).ravel()[0])
    gender_prediction = float(gender_model.predict(batch).ravel()[0])

    return str(round(age_estimate, 1)), {"Male": 1 - gender_prediction, "Female": gender_prediction}
|
|
|
|
|
|
|
# Directory the app was started from; the bundled example record is
# addressed relative to it.
CWD = os.getcwd()


with gr.Blocks() as demo:
    # First row: file uploads on the left, model outputs on the right.
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
        with gr.Column(scale=1):
            output_age = gr.Textbox(label="Estimated age")
            output_gender = gr.Label(label="Predicted gender")

    # Second row: one button wired to the inference callback.
    with gr.Row():
        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=run,
            inputs=[header_file, data_file],
            outputs=[output_age, output_gender],
        )

    # Third row: a sample record that fills both upload slots when clicked.
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    f"{CWD}/sample_data/ath_001.hea",
                    f"{CWD}/sample_data/ath_001.dat",
                ],
            ],
            inputs=[header_file, data_file],
        )


if __name__ == "__main__":
    demo.launch(share=True)
|
|
|
|
|
|