import random
import streamlit as st
import pandas as pd
import numpy as np
import wfdb
import ast
import time
import os.path
import altair as alt
from streamlit_javascript import st_javascript as st_js
from csscolor import parse
import subprocess
import matplotlib.pyplot as plt
import urllib.parse
import psutil
# Define constants
path = 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/'
sampling_rate = 100
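# The 100 Hz records are used for the on-screen ECG chart; the 500 Hz records are loaded separately for the VCG further below.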
# Configure libraries
st.set_page_config(
page_title="ECG Database",
page_icon="🫀",
layout="wide",
initial_sidebar_state="collapsed"
)
pd.set_option('display.max_columns', None)
# Initialize session state
if "expander_state" not in st.session_state:
st.session_state["expander_state"] = True
if "theme" not in st.session_state:
st.session_state["theme"] = "dark"
if "history" not in st.session_state:
st.session_state["history"] = []
if "forceload" not in st.session_state:
st.session_state["forceload"] = False
# Show title and site check
st.markdown('Med › ECG',
unsafe_allow_html=True)
st.write("""
# ECG Database
Filter and view the ECG, VCG and diagnosis data from the PTB-XL ECG Database.
**Points to note**
- The ECG analysis may not be 100% accurate.
- The VCG is an estimate derived from the ECG; it may differ from a VCG actually recorded from the same patient.
- The ECG is rendered from the recorded signal data; it is not a perfect replica of the original report printed by the ECG machine.
""")
site_link = 'https://lysine-ecg-db.hf.space/'
st_cloud = os.path.isdir('/home/appuser')
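# Streamlit Cloud containers run as the 'appuser' user, so the presence of its home directory is used as a heuristic for the old deployment.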
if st_cloud:
st.warning(f"""
**ecg-db has moved to a new, more stable home. Please access ecg-db from the new link below:**
Link to the new site: [{site_link}]({site_link}?{urllib.parse.urlencode(st.experimental_get_query_params(), doseq=True)})
""", icon='✨')
if not st.session_state["forceload"]:
with st.expander("If the new site is not working for you"):
if st.button("Load the app here"):
st.session_state["forceload"] = True
st.experimental_rerun()
st.stop()
# Download data from kaggle
if not os.path.isfile(path + 'ptbxl_database.csv'):
placeholder = st.empty()
already_downloading = False
for proc in psutil.process_iter():
try:
# Check whether this process looks like a kaggle download already in progress.
if "kaggle" in proc.name().lower():
already_downloading = True
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
if not already_downloading:
try:
placeholder.info(
"**Downloading data.**\nThis may take a minute, but it only needs to be done once.", icon="⏳")
subprocess.run(['pip', 'uninstall', '-y', 'kaggle'])
subprocess.run(['pip', 'install', '--user', 'kaggle'])
try:
# Streamlit cloud
subprocess.run(['/home/appuser/.local/bin/kaggle', 'datasets', 'download',
'khyeh0719/ptb-xl-dataset', '--unzip'])
except FileNotFoundError:
# Hugging Face
subprocess.run(['/home/user/.local/bin/kaggle', 'datasets', 'download',
'khyeh0719/ptb-xl-dataset', '--unzip'])
placeholder.empty()
except Exception as error:
placeholder.warning(
"An error occurred while downloading the data. Please take a screenshot of the whole page and send it to the developer.")
st.write(error)
st.stop()
else:
placeholder.info(
"**Downloading data.**\nPlease refresh the page after a few minutes.", icon="⏳")
st.stop()
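# Mapping between URL query parameters and filter keys:
#   id -> record_index (1-based in the URL, 0-based internally)
#   validated -> validated_by_human    second_opinion -> second_opinion
#   axis -> heart_axis                 clean -> no_artifacts
#   condition -> scp_code              d_class -> diagnostic_class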
def query_to_filters():
filters = {}
query_params = st.experimental_get_query_params()
if "id" in query_params:
filters["record_index"] = int(query_params["id"][0]) - 1
if "validated" in query_params:
filters["validated_by_human"] = query_params["validated"][0].lower() == "true"
if "second_opinion" in query_params:
filters["second_opinion"] = query_params["second_opinion"][0].lower() == "true"
if "axis" in query_params:
filters["heart_axis"] = query_params["axis"][0].lower() == "true"
if "clean" in query_params:
filters["no_artifacts"] = query_params["clean"][0].lower() == "true"
if "condition" in query_params:
filters["scp_code"] = query_params["condition"]
if "d_class" in query_params:
filters["diagnostic_class"] = query_params["d_class"]
return filters
def filters_to_query():
query_params = {}
if "record_index" in filters:
query_params["id"] = filters["record_index"] + 1
if "validated_by_human" in filters:
query_params["validated"] = filters["validated_by_human"]
if "second_opinion" in filters:
query_params["second_opinion"] = filters["second_opinion"]
if "heart_axis" in filters:
query_params["axis"] = filters["heart_axis"]
if "no_artifacts" in filters:
query_params["clean"] = filters["no_artifacts"]
if "scp_code" in filters:
query_params["condition"] = filters["scp_code"]
if "diagnostic_class" in filters:
query_params["d_class"] = filters["diagnostic_class"]
st.experimental_set_query_params(**query_params)
filters = query_to_filters()
@st.cache_data(ttl=60 * 60)
def load_records():
"""
Load and convert the ECG records to a DataFrame.
One record for each ECG taken.
"""
def optional_int(x): return pd.NA if x == '' else int(float(x))
def optional_float(x): return pd.NA if x == '' else float(x)
def optional_string(x): return pd.NA if x == '' else x
record_df = pd.read_csv(
path+'ptbxl_database.csv',
index_col='ecg_id',
converters={
'patient_id': optional_int,
'age': optional_int,
'sex': lambda x: 'M' if x == '0' else 'F',
'height': optional_float,
'weight': optional_float,
'nurse': optional_int,
'site': optional_int,
'scp_codes': lambda x: ast.literal_eval(x),
'heart_axis': optional_string,
'infarction_stadium1': optional_string,
'infarction_stadium2': optional_string,
'validated_by': optional_int,
'baseline_drift': optional_string,
'static_noise': optional_string,
'burst_noise': optional_string,
'electrodes_problems': optional_string,
'extra_beats': optional_string,
'pacemaker': optional_string,
}
)
return record_df.reset_index()
try:
record_df = load_records()
except Exception as error:
st.warning(
"An error occurred while loading data. Please refresh the page to try again.")
st.write(error)
st.stop()
@st.cache_data(ttl=60 * 60)
def load_annotations():
"""
Load and convert the ECG annotations to a DataFrame.
One row for each condition in SCP code.
"""
def int_bool(x): return False if x == '' else True
def optional_int(x): return pd.NA if x == '' else int(float(x))
def optional_string(x): return pd.NA if x == '' else x
def mandatory_string(x): return 'OTHER' if x == '' else x
annotation_df = pd.read_csv(
path+'scp_statements.csv',
index_col=0,
converters={
'diagnostic': int_bool,
'form': int_bool,
'rhythm': int_bool,
'diagnostic_class': mandatory_string,
'diagnostic_subclass': mandatory_string,
'AHA code': optional_int,
'aECG REFID': optional_string,
'CDISC Code': optional_string,
'DICOM Code': optional_string,
}
)
annotation_df.index.name = 'scp_code'
annotation_df.sort_values('description', inplace=True)
return annotation_df
annotation_df = load_annotations()
# ===============================
# Browsing history
# ===============================
with st.sidebar:
st.write("**Browsing history:**")
if len(st.session_state['history']) == 0:
st.write('No ECGs viewed yet.')
else:
for history in st.session_state['history']:
st.write(
f"""{history + 1} - {', '.join(record_df.iloc[history].scp_codes.keys())}""", unsafe_allow_html=True)
# ===============================
# ECG Filters
# ===============================
with st.form("filter_form"):
col1, col2, col3, col4 = st.columns(4)
with col1:
filters['validated_by_human'] = st.checkbox("Human-validated", key='new_ecg1', value=filters['validated_by_human'] if 'validated_by_human' in filters else True,
help='Filter ECGs with results validated by a human')
with col2:
filters['second_opinion'] = st.checkbox("Double-validated", key='new_ecg2', value=filters['second_opinion'] if 'second_opinion' in filters else False,
help='Filter ECGs with results validated twice')
with col3:
filters['heart_axis'] = st.checkbox("Heart axis", key='new_ecg3', value=filters['heart_axis'] if 'heart_axis' in filters else False,
help='Filter ECGs with heart axis data')
with col4:
filters['no_artifacts'] = st.checkbox("No artifacts", key='new_ecg4', value=filters['no_artifacts'] if 'no_artifacts' in filters else True,
help='Filter ECGs with no artifacts (e.g. baseline drift and noises)')
with st.expander("Filter by class"):
if 'diagnostic_class' in filters:
del filters['diagnostic_class']
cols = st.columns(2)
class_df = annotation_df.groupby(['diagnostic_class'])[
'Statement Category'].apply(set).reset_index()
for i in range(len(class_df)):
key = class_df.iloc[i]['diagnostic_class']
description = 'Other conditions' if key == 'OTHER' else ', '.join(
class_df.iloc[i]['Statement Category'])
selected_class = cols[i % 2].checkbox(
description, key=f'filter_class_{i}', value=key in filters['diagnostic_class'] if 'diagnostic_class' in filters else False)
if selected_class:
if 'diagnostic_class' not in filters:
filters['diagnostic_class'] = [key]
elif key not in filters['diagnostic_class']:
filters['diagnostic_class'].append(key)
with st.expander("Filter by condition"):
if 'scp_code' in filters:
del filters['scp_code']
cols = st.columns(4)
for i in range(len(annotation_df)):
key = annotation_df.iloc[i].name
description = annotation_df.iloc[i]['description']
selected_code = cols[i % 4].checkbox(
description, key=f'filter_condition_{i}', value=key in filters['scp_code'] if 'scp_code' in filters else False)
if selected_code:
if 'scp_code' not in filters:
filters['scp_code'] = [key]
elif key not in filters['scp_code']:
filters['scp_code'].append(key)
submitted = st.form_submit_button(
label='Random ECG', help='Find a new ECG with the selected filters')
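# On submit, drop any pinned record id so a fresh random record is drawn below, and push the new filters into the URL.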
if submitted:
if 'record_index' in filters:
del filters['record_index']
filters_to_query()
st.session_state["expander_state"] = True
def apply_filters():
"""
Filter the records according to the module-level `filters` dict (built from
the query parameters and the filter form above).
"""
filtered_record_df = record_df
if "validated_by_human" in filters and filters['validated_by_human']:
filtered_record_df = filtered_record_df[filtered_record_df.validated_by_human]
if "second_opinion" in filters and filters['second_opinion']:
filtered_record_df = filtered_record_df[filtered_record_df.second_opinion]
if "heart_axis" in filters and filters['heart_axis']:
filtered_record_df = filtered_record_df[pd.isna(
filtered_record_df.heart_axis) == False]
if "no_artifacts" in filters and filters['no_artifacts']:
filtered_record_df = filtered_record_df[pd.isna(filtered_record_df.baseline_drift) & pd.isna(
filtered_record_df.static_noise) & pd.isna(filtered_record_df.burst_noise) & pd.isna(filtered_record_df.electrodes_problems)]
if "scp_code" in filters and not "diagnostic_class" in filters:
filtered_record_df = filtered_record_df[filtered_record_df.scp_codes.apply(
lambda x: any(code in filters["scp_code"] for code in x))]
elif not "scp_code" in filters and "diagnostic_class" in filters:
filtered_codes = annotation_df[annotation_df['diagnostic_class'].isin(
filters["diagnostic_class"])].reset_index()['scp_code'].values
filtered_record_df = filtered_record_df[filtered_record_df.scp_codes.apply(
lambda x: any(code in filtered_codes for code in x))]
elif "scp_code" in filters and "diagnostic_class" in filters:
filtered_codes = annotation_df[annotation_df['diagnostic_class'].isin(
filters["diagnostic_class"])].reset_index()['scp_code'].values
filtered_record_df = filtered_record_df[filtered_record_df.scp_codes.apply(
lambda x: any(code in filters["scp_code"] or code in filtered_codes for code in x))]
return filtered_record_df
filtered_record_df = apply_filters()
if len(filtered_record_df) == 0:
st.error('No ECGs found with the selected filters.')
st.stop()
# Select a random ECG record
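# If the URL pins a record id, show that record; otherwise pick one at random
# from the filtered set and write its index back to the URL so the view can be shared.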
if "record_index" not in filters or filters["record_index"] == None:
record = filtered_record_df.iloc[random.randint(
0, len(filtered_record_df) - 1)]
filters["record_index"] = record.name
filters_to_query()
else:
record = record_df.iloc[filters["record_index"]]
if filters["record_index"] in st.session_state['history']:
st.session_state['history'].remove(filters["record_index"])
st.session_state['history'].insert(0, filters["record_index"])
st.write(f'*{len(filtered_record_df)} ECGs with the selected filters*')
st.write("----------------------------")
# ===============================
# ECG Verification Status
# ===============================
box = st.warning
if record.validated_by_human:
box = st.info
if record.second_opinion:
box = st.success
box(f"""
**Autogenerated report:** {'Yes' if record.initial_autogenerated_report else 'No'}
**Human validated:** {'Yes' if record.validated_by_human else 'No'}
**Second opinion:** {'Yes' if record.second_opinion else 'No'}
""")
# ===============================
# Patient Info
# ===============================
col1, col2, col3, col4 = st.columns(4)
with col1:
st.write(f"**Patient ID:** {record.patient_id}")
st.write(f"**ECG ID:** {record.ecg_id}")
with col2:
st.write(f"**Age:** {record.age}")
st.write(f"**Sex:** {record.sex}")
with col3:
st.write(f"**Height:** {record.height}")
st.write(f"**Weight:** {record.weight}")
with col4:
st.write(f"**Date:** {record.recording_date}")
st.write(f"**ECG Device:** {record.device}")
# ===============================
# ECG Chart
# ===============================
@st.cache_data(max_entries=2)
def load_raw_data(record, sampling_rate, path):
"""
Load the ECG signals of a single record from the raw data files.
"""
if sampling_rate == 100:
data = wfdb.rdsamp(path + record.filename_lr)
else:
data = wfdb.rdsamp(path + record.filename_hr)
data = pd.DataFrame(data[0], columns=data[1]['sig_name']).reset_index()
return data
@st.cache_resource(max_entries=2)
def plot_ecg(lead_signals, sampling_rate, chart_mode, theme):
"""
Draw the ECG chart.
"""
alt.renderers.set_embed_options(
padding={"left": 0, "right": 0, "bottom": 0, "top": 0}
)
if chart_mode == 'Continuous':
chart_x_min = 0
chart_x_max = 10 * sampling_rate
chart_y_min = -1.5
chart_y_max = 34.5
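# Standard ECG paper geometry: one large square is 0.2 s x 0.5 mV, i.e. every
# 20 samples at the 100 Hz rate used here on the x axis and every 0.5 units on
# the y axis; minor squares are 4 samples x 0.1 units. Each of the 12 leads is
# stacked in its own 3 mV-high band, giving the -1.5..34.5 y range.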
# Prepare DataFrames for the grid lines
grid_df = pd.DataFrame(columns=['x', 'y', 'x2', 'y2'])
for i in range(int(chart_y_min * 2), int(chart_y_max * 2), 1):
grid_df.loc[len(grid_df.index)] = [
chart_x_min, i / 2, chart_x_max, i / 2]
for i in range(chart_x_min, chart_x_max, 20):
grid_df.loc[len(grid_df.index)] = [i, chart_y_min, i, chart_y_max]
minor_grid_df = pd.DataFrame(columns=['x', 'y', 'x2', 'y2'])
for i in range(int(chart_y_min * 10), int(chart_y_max * 10), 1):
minor_grid_df.loc[len(minor_grid_df.index)] = [
chart_x_min, i / 10, chart_x_max, i / 10]
for i in range(chart_x_min, chart_x_max, 4):
minor_grid_df.loc[len(minor_grid_df.index)] = [
i, chart_y_min, i, chart_y_max]
# Prepare DataFrames for the text labels and modify the lead signals
text_df = pd.DataFrame(columns=['x', 'y', 'text'])
lead_names = lead_signals.columns.values[1:]
leads_count = len(lead_names)
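# For each lead: overwrite the last part of the trace with a 1 mV calibration
# pulse, blank the tail, shift the trace into its own band (3 units apart) and
# record where to draw its label.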
for i in range(leads_count):
lead_signals[lead_names[i]].iloc[int(
10 * sampling_rate * 48/50):int(10 * sampling_rate * 49/50)] = 1
lead_signals[lead_names[i]].iloc[int(
10 * sampling_rate * 49/50):int(10 * sampling_rate * 49.2/50)] = 0
lead_signals[lead_names[i]].iloc[int(
10 * sampling_rate * 49.2/50):] = pd.NA
lead_signals[lead_names[i]] = lead_signals[lead_names[i]
] + (leads_count - i - 1) * 3
text_df.loc[len(text_df.index)] = [
4, (leads_count - i - 1) * 3 + 1, lead_names[i]]
# Plot the grid lines
chart = alt.layer(
alt.Chart(minor_grid_df).mark_rule(clip=True, stroke=('#252525' if theme == 'dark' else '#dddddd')).encode(
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
x2=alt.X2('x2'),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
y2=alt.Y2('y2'),
tooltip=alt.value(None),
),
alt.Chart(grid_df).mark_rule(clip=True, stroke=('#555' if theme == 'dark' else '#bbb')).encode(
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
x2=alt.X2('x2'),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
y2=alt.Y2('y2'),
tooltip=alt.value(None),
)
).properties(
width=1600,
height=1600 / 50 * 72 + 20, # 20px padding
).configure_concat(
spacing=0
).configure_facet(
spacing=0
).configure_axis(
grid=False,
labels=False,
)
# Plot the ECG signals
for col in lead_signals.columns.values[1:]:
chart += alt.Chart(lead_signals).mark_line(clip=True, stroke=('#7abaed' if theme == 'dark' else '#05014a')).encode(
x=alt.X('index', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
y=alt.Y(col, type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
tooltip=alt.value(None),
)
# Plot the text labels
chart += alt.Chart(text_df).mark_text(baseline='middle', align='left', size=20, fill=('#fff' if theme == 'dark' else '#020079')).encode(
text='text',
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
tooltip=alt.value(None),
)
return chart
else:
# Duplicate lead II into a new column to serve as the full-width rhythm strip
lead_signals['II '] = lead_signals['II']
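# Report layout: four columns of three leads (I/II/III, AVR/AVL/AVF, V1-V3,
# V4-V6), each column showing a different time window of the 10 s recording,
# with the duplicated lead II running across the full width as the rhythm strip at y = 0.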
lead_config = [
{
'lead': 'II ',
'y': 0,
'start_x': 0,
'end_x': int(10 * sampling_rate),
},
{
'lead': 'I',
'y': 3,
'start_x': 0,
'end_x': int(10 * sampling_rate * 12 / 50),
},
{
'lead': 'II',
'y': 2,
'start_x': 0,
'end_x': int(10 * sampling_rate * 12 / 50),
},
{
'lead': 'III',
'y': 1,
'start_x': 0,
'end_x': int(10 * sampling_rate * 12 / 50),
},
{
'lead': 'AVR',
'y': 3,
'start_x': int(10 * sampling_rate * 12 / 50),
'end_x': int(10 * sampling_rate * 24 / 50),
},
{
'lead': 'AVL',
'y': 2,
'start_x': int(10 * sampling_rate * 12 / 50),
'end_x': int(10 * sampling_rate * 24 / 50),
},
{
'lead': 'AVF',
'y': 1,
'start_x': int(10 * sampling_rate * 12 / 50),
'end_x': int(10 * sampling_rate * 24 / 50),
},
{
'lead': 'V1',
'y': 3,
'start_x': int(10 * sampling_rate * 24 / 50),
'end_x': int(10 * sampling_rate * 36 / 50),
},
{
'lead': 'V2',
'y': 2,
'start_x': int(10 * sampling_rate * 24 / 50),
'end_x': int(10 * sampling_rate * 36 / 50),
},
{
'lead': 'V3',
'y': 1,
'start_x': int(10 * sampling_rate * 24 / 50),
'end_x': int(10 * sampling_rate * 36 / 50),
},
{
'lead': 'V4',
'y': 3,
'start_x': int(10 * sampling_rate * 36 / 50),
'end_x': int(10 * sampling_rate),
},
{
'lead': 'V5',
'y': 2,
'start_x': int(10 * sampling_rate * 36 / 50),
'end_x': int(10 * sampling_rate),
},
{
'lead': 'V6',
'y': 1,
'start_x': int(10 * sampling_rate * 36 / 50),
'end_x': int(10 * sampling_rate),
},
]
chart_x_min = 0
chart_x_max = 10 * sampling_rate
chart_y_min = -1.5
chart_y_max = 10.5
# Prepare DataFrames for the grid lines
grid_df = pd.DataFrame(columns=['x', 'y', 'x2', 'y2'])
for i in range(int(chart_y_min * 2), int(chart_y_max * 2), 1):
grid_df.loc[len(grid_df.index)] = [
chart_x_min, i / 2, chart_x_max, i / 2]
for i in range(chart_x_min, chart_x_max, 20):
grid_df.loc[len(grid_df.index)] = [i, chart_y_min, i, chart_y_max]
minor_grid_df = pd.DataFrame(columns=['x', 'y', 'x2', 'y2'])
for i in range(int(chart_y_min * 10), int(chart_y_max * 10), 1):
minor_grid_df.loc[len(minor_grid_df.index)] = [
chart_x_min, i / 10, chart_x_max, i / 10]
for i in range(chart_x_min, chart_x_max, 4):
minor_grid_df.loc[len(minor_grid_df.index)] = [
i, chart_y_min, i, chart_y_max]
# Prepare DataFrames for the text labels and lead separators
# Also modify the lead signals
text_df = pd.DataFrame(columns=['x', 'y', 'text'])
separator_df = pd.DataFrame(columns=['x', 'y', 'x2', 'y2'])
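# For each configured lead: blank the samples outside its time window, draw a
# short vertical separator at the window boundary, add the calibration pulse on
# the right-most column, shift the trace into its row and place its label.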
for config in lead_config:
if config['start_x'] > 0:
lead_signals[config['lead']].iloc[:config['start_x']] = pd.NA
separator_df.loc[len(separator_df.index)] = [
config['start_x'], config['y'] * 3 - 0.5, config['start_x'], config['y'] * 3 + 0.5]
if config['end_x'] < 10 * sampling_rate:
lead_signals[config['lead']].iloc[config['end_x']:] = pd.NA
else:
lead_signals[config['lead']].iloc[int(
10 * sampling_rate * 48/50):int(10 * sampling_rate * 49/50)] = 1
lead_signals[config['lead']].iloc[int(
10 * sampling_rate * 49/50):int(10 * sampling_rate * 49.2/50)] = 0
lead_signals[config['lead']].iloc[int(
10 * sampling_rate * 49.2/50):] = pd.NA
lead_signals[config['lead']
] = lead_signals[config['lead']] + config['y'] * 3
text_df.loc[len(text_df.index)] = [
config['start_x'] + 4, config['y'] * 3 + 1, config['lead']]
# Plot the grid lines
chart = alt.layer(
alt.Chart(minor_grid_df).mark_rule(clip=True, stroke=('#252525' if theme == 'dark' else '#dddddd')).encode(
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
x2=alt.X2('x2'),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
y2=alt.Y2('y2'),
tooltip=alt.value(None),
),
alt.Chart(grid_df).mark_rule(clip=True, stroke=('#555' if theme == 'dark' else '#bbb')).encode(
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
x2=alt.X2('x2'),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
y2=alt.Y2('y2'),
tooltip=alt.value(None),
)
).properties(
width=1600,
height=1600 / 50 * 24 + 20, # 20px padding
).configure_concat(
spacing=0
).configure_facet(
spacing=0
).configure_axis(
grid=False,
labels=False,
)
# Plot the ECG signals
for config in lead_config:
chart += alt.Chart(lead_signals).mark_line(clip=True, stroke=('#7abaed' if theme == 'dark' else '#05014a')).encode(
x=alt.X('index', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
y=alt.Y(config['lead'], type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
tooltip=alt.value(None),
)
# Plot the lead separators
chart += alt.Chart(separator_df).mark_rule(clip=True, stroke=('#7abaed' if theme == 'dark' else '#05014a')).encode(
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
x2=alt.X2('x2'),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
y2=alt.Y2('y2'),
tooltip=alt.value(None),
)
# Plot the text labels
chart += alt.Chart(text_df).mark_text(baseline='middle', align='left', size=20, fill=('#fff' if theme == 'dark' else '#020079')).encode(
text='text',
x=alt.X('x', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_x_min, chart_x_max), padding=0)),
y=alt.Y('y', type='quantitative', title=None, scale=alt.Scale(
domain=(chart_y_min, chart_y_max), padding=0)),
tooltip=alt.value(None),
)
return chart
if st.session_state["expander_state"] == False:
chart_mode = st.selectbox(
'ECG Chart Mode',
options=('Report', 'Continuous'),
)
with st.spinner('Loading ECG...'):
lead_signals = load_raw_data(record, sampling_rate, path)
fig = plot_ecg(lead_signals, sampling_rate,
chart_mode, st.session_state["theme"])
st.altair_chart(fig, use_container_width=False)
else:
st.info('**Loading ECG...**', icon='🔃')
# ===============================
# ECG Analysis
# ===============================
# Only render the expander when this is the final re-render
if st.session_state["expander_state"] == False:
with st.expander("ECG Analysis", expanded=st.session_state["expander_state"]):
for code, prob in record.scp_codes.items():
annotation = annotation_df.loc[code]
st.write(f"""
> `{f"{annotation.diagnostic_class} > {annotation.diagnostic_subclass} > {annotation.name}" if not pd.isna(annotation.diagnostic_class) and not pd.isna(annotation.diagnostic_subclass) else
f"{annotation.diagnostic_class} > {annotation.name}" if not pd.isna(annotation.diagnostic_class) else annotation.name}` - {"unknown likelihood" if prob == 0 else f"**{prob}%**"}
>
> {annotation['Statement Category']}
>
> **{annotation['SCP-ECG Statement Description']}**
""")
st.write("---------------------")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.write(f"**Heart Axis:** {record.heart_axis}")
st.write(f"**Pacemaker:** {record.pacemaker}")
st.write(f"**Extra Beats:** {record.extra_beats}")
with col2:
st.write(f"**Infarction Stadium 1:** {record.infarction_stadium1}")
st.write(f"**Infarction Stadium 2:** {record.infarction_stadium2}")
with col3:
st.write(f"**Baseline Drift:** {record.baseline_drift}")
st.write(f"**Electrode Problems:** {record.electrodes_problems}")
with col4:
st.write(f"**Static Noise:** {record.static_noise}")
st.write(f"**Burst Noise:** {record.burst_noise}")
else:
st.write('**Loading...**')
# ===============================
# Vectorcardiogram
# ===============================
kors_transform = [
{
'axis': 'X',
'leads': [
{'lead': 'V1', 'weight': -0.130},
{'lead': 'V2', 'weight': 0.050},
{'lead': 'V3', 'weight': -0.010},
{'lead': 'V4', 'weight': 0.140},
{'lead': 'V5', 'weight': 0.060},
{'lead': 'V6', 'weight': 0.540},
{'lead': 'I', 'weight': 0.380},
{'lead': 'II', 'weight': -0.070},
]
},
{
'axis': 'Y',
'leads': [
{'lead': 'V1', 'weight': 0.060},
{'lead': 'V2', 'weight': -0.020},
{'lead': 'V3', 'weight': -0.050},
{'lead': 'V4', 'weight': 0.060},
{'lead': 'V5', 'weight': -0.170},
{'lead': 'V6', 'weight': 0.130},
{'lead': 'I', 'weight': -0.070},
{'lead': 'II', 'weight': 0.930},
]
},
{
'axis': 'Z',
'leads': [
{'lead': 'V1', 'weight': -0.430},
{'lead': 'V2', 'weight': -0.060},
{'lead': 'V3', 'weight': -0.140},
{'lead': 'V4', 'weight': -0.200},
{'lead': 'V5', 'weight': -0.110},
{'lead': 'V6', 'weight': 0.310},
{'lead': 'I', 'weight': 0.110},
{'lead': 'II', 'weight': -0.230},
]
}
]
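# Kors regression coefficients mapping the eight independent ECG leads (V1-V6, I, II)
# to the orthogonal X, Y and Z vectorcardiographic axes.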
def cart2pol(x, y):
rho = np.sqrt(x**2 + y**2)
phi = np.arctan2(y, x)
return np.array([rho, phi])
def pol2cart(rho, phi):
x = rho * np.cos(phi)
y = rho * np.sin(phi)
return np.array([x, y])
@st.cache_data(max_entries=2)
def calculate_kors_transform(lead_signals):
"""
Calculate VCG data from the ECG data using the Kors regression transformation.
https://doi.org/10.3390/s19143072
"""
vector_signals = lead_signals.copy()
for axis in kors_transform:
vector_signals[axis['axis']] = vector_signals.apply(
lambda r: sum([r[lead['lead']] * lead['weight'] for lead in axis['leads']]), axis=1)
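# Project the 3D loop onto the frontal (X-Y), transverse (X / -Z) and sagittal
# (Z-Y) planes, stored as polar (rho, phi) pairs for the polar plots below.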
vector_signals['frontal'] = vector_signals.apply(
lambda r: cart2pol(r['X'], r['Y']), axis=1)
vector_signals['frontal_rho'] = vector_signals['frontal'].apply(
lambda x: x[0])
vector_signals['frontal_phi'] = vector_signals['frontal'].apply(
lambda x: x[1])
vector_signals['transverse'] = vector_signals.apply(
lambda r: cart2pol(r['X'], -r['Z']), axis=1)
vector_signals['transverse_rho'] = vector_signals['transverse'].apply(
lambda x: x[0])
vector_signals['transverse_phi'] = vector_signals['transverse'].apply(
lambda x: x[1])
vector_signals['sagittal'] = vector_signals.apply(
lambda r: cart2pol(r['Z'], r['Y']), axis=1)
vector_signals['sagittal_rho'] = vector_signals['sagittal'].apply(
lambda x: x[0])
vector_signals['sagittal_phi'] = vector_signals['sagittal'].apply(
lambda x: x[1])
return vector_signals
@st.cache_resource(max_entries=2)
def plot_vcg(lead_signals, theme):
"""
Draw the vectorcardiogram.
"""
if theme == 'dark':
plt.style.use('dark_background')
else:
plt.style.use('default')
plt.rcParams.update({'font.size': 8, 'axes.titlepad': 40})
fig, ax = plt.subplots(
1, 3, subplot_kw={'projection': 'polar'}, figsize=(15, 15))
fig.patch.set_alpha(0.0)
fig.tight_layout(pad=2.0)
ax[0].set_theta_direction(-1)
ax[0].title.set_text("Frontal Vectorcardiogram")
fig.text(0.172, 0.662, "0° at +X, positive towards +Y",
horizontalalignment="center")
ax[0].set_facecolor("none")
ax[0].plot(lead_signals['frontal_phi'],
lead_signals['frontal_rho'], linewidth=0.5, color=('#7abaed' if theme == 'dark' else '#05014a'))
ax[1].set_theta_direction(-1)
ax[1].title.set_text("Transverse Vectorcardiogram")
fig.text(0.5, 0.662, "0° at +X, positive towards -Z",
horizontalalignment="center")
ax[1].set_facecolor("none")
ax[1].plot(lead_signals['transverse_phi'],
lead_signals['transverse_rho'], linewidth=0.5, color=('#7abaed' if theme == 'dark' else '#05014a'))
ax[2].set_theta_direction(-1)
ax[2].title.set_text("Sagittal Vectorcardiogram")
fig.text(0.829, 0.662, "0° at +Z, positive towards +Y",
horizontalalignment="center")
ax[2].set_facecolor("none")
ax[2].plot(lead_signals['sagittal_phi'],
lead_signals['sagittal_rho'], linewidth=0.5, color=('#7abaed' if theme == 'dark' else '#05014a'))
return fig
@st.cache_resource(max_entries=2)
def plot_vcg_3d(lead_signals, h_angle, v_angle, theme):
"""
Draw the vectorcardiogram in 3D.
"""
if theme == 'dark':
plt.style.use('dark_background')
else:
plt.style.use('default')
plt.rcParams.update({'font.size': 8})
fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, figsize=(10, 8))
fig.patch.set_alpha(0.0)
ax.set_facecolor("none")
ax.plot(lead_signals['X'],
lead_signals['Y'], lead_signals['Z'], linewidth=0.5, color="blue")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.invert_yaxis()
ax.title.set_text("Spatial Vectorcardiogram")
ax.view_init(v_angle, h_angle, None, vertical_axis='y')
return fig
# @st.cache_resource(max_entries=2)
# def plot_vcg_interactive(lead_signals, theme):
# """
# Draw the vectorcardiogram in 3D.
# """
# # pv.global_theme.show_scalar_bar = False
# p = pv.Plotter()
# points = np.array(
# [lead_signals['X'].values, lead_signals['Y'].values, lead_signals['Z'].values])
# points = points.transpose()
# spline = pv.Spline(points)
# p.add_mesh(mesh=spline, color='blue')
# p.show_grid()
# if theme == 'dark':
# p.set_background(color='#0e1117')
# else:
# p.set_background(color='#ffffff')
# return p
if st.session_state["expander_state"] == False:
with st.expander("Vectorcardiogram (Approximation)", expanded=st.session_state["expander_state"]):
hd_lead_signals = load_raw_data(record, 500, path)
vector_signals = calculate_kors_transform(hd_lead_signals)
fig = plot_vcg(vector_signals, st.session_state["theme"])
st.pyplot(fig, use_container_width=False)
col1, col2 = st.columns(spec=[0.2, 0.8])
with st.spinner('Loading 3D plot...'):
with col1:
h_angle = st.slider("Horizontal view angle", min_value=-180,
max_value=180, value=-60, step=5)
v_angle = st.slider("Vertical view angle", min_value=-180,
max_value=180, value=30, step=5)
with col2:
fig3d = plot_vcg_3d(vector_signals, h_angle,
v_angle, st.session_state["theme"])
st.pyplot(fig3d, use_container_width=False)
# fig3d = plot_vcg_interactive(
# vector_signals, st.session_state["theme"])
# stpyvista(fig3d)
# fig_html = mpld3.fig_to_html(fig3d)
# components.html(fig_html, height=600)
else:
st.info('**Loading VCG...**', icon='🔃')
# Detect browser theme
if st.session_state["expander_state"] == True:
theme = st_js(
"""window.getComputedStyle( document.body ,null).getPropertyValue('background-color')""")
if theme != 0:
color = parse.color(theme)
if color.as_hsl_percent_triple()[2] > 50:
st.session_state["theme"] = "light"
else:
st.session_state["theme"] = "dark"
# To forcibly collapse the expanders, the whole page is rendered twice.
# In the first rerender, the expander is replaced by a placeholder markdown text.
# In the second rerender, the expander is rendered and it defaults to collapsed
# because it did not exist in the previous render.
if st.session_state["expander_state"] == True and theme != 0:
st.session_state["expander_state"] = False
# Wait for the client to sync up
time.sleep(0.05)
# Start the second re-render
st.experimental_rerun()