import ast
import os.path
import random
import subprocess
import time
import urllib.parse

import streamlit as st
import pandas as pd
import numpy as np
import wfdb
import altair as alt
from streamlit_javascript import st_javascript as st_js
from csscolor import parse
import matplotlib.pyplot as plt
import psutil

# Define constants
# Directory of the unzipped PTB-XL dataset (relative to the working directory).
path = 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/'
# Sampling rate used for the ECG chart; 100 selects the `filename_lr`
# recordings in load_raw_data(), any other value the `filename_hr` ones.
sampling_rate = 100

# Configure libraries
st.set_page_config(
    page_title="ECG Database",
    page_icon="🫀",
    layout="wide",
    initial_sidebar_state="collapsed",
)
pd.set_option('display.max_columns', None)

# Initialize session state.
# "expander_state" is True during the first (placeholder) render pass; see the
# end of this file for the two-pass rendering trick.
if "expander_state" not in st.session_state:
    st.session_state["expander_state"] = True
if "theme" not in st.session_state:
    st.session_state["theme"] = "dark"
if "history" not in st.session_state:
    st.session_state["history"] = []
if "forceload" not in st.session_state:
    st.session_state["forceload"] = False

# Show title and site check
st.markdown('Med ECG', unsafe_allow_html=True)
st.write("""
# ECG Database

Filter and view the ECG, VCG and diagnosis data from the PTB-XL ECG Database.

**Points to note**
- The ECG analysis may not be 100% accurate.
- The VCG is a derived estimate from the ECG. It may be different from an actual VCG on the same patient.
- The ECG is virtually generated from sensor data. It is not a perfect replica of the original report printed by the ECG machine.
""")

site_link = 'https://lysine-ecg-db.hf.space/'
# The presence of /home/appuser is used as the marker of the Streamlit Cloud
# deployment (the same path is used for the kaggle binary below).
st_cloud = os.path.isdir('/home/appuser')
if st_cloud:
    current_query = urllib.parse.urlencode(
        st.experimental_get_query_params(), doseq=True)
    st.warning(f"""
**ecg-db has a new home with increased stability. Please access ecg-db from the new link below:**

Link to the new site: [{site_link}]({site_link}?{current_query})
""", icon='✨')
    if not st.session_state["forceload"]:
        with st.expander("If the new site is not working for you"):
            if st.button("Load the app here"):
                st.session_state["forceload"] = True
                st.experimental_rerun()
        st.stop()

# Download data from kaggle
if not os.path.isfile(path + 'ptbxl_database.csv'):
    placeholder = st.empty()

    # Another session may already be running the kaggle CLI; don't start a
    # second concurrent download.
    already_downloading = False
    for proc in psutil.process_iter():
        try:
            if "kaggle" in proc.name().lower():
                already_downloading = True
                break
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            # The process vanished or cannot be inspected; it is not ours.
            continue

    if not already_downloading:
        try:
            placeholder.info(
                "**Downloading data.**\nThis may take a minute, but it only needs to be done once.", icon="⏳")
            subprocess.run(['pip', 'uninstall', '-y', 'kaggle'])
            subprocess.run(['pip', 'install', '--user', 'kaggle'])
            download_args = ['datasets', 'download',
                             'khyeh0719/ptb-xl-dataset', '--unzip']
            try:
                # Streamlit cloud
                subprocess.run(
                    ['/home/appuser/.local/bin/kaggle', *download_args], check=True)
            except FileNotFoundError:
                # Hugging Face: the Streamlit Cloud binary path does not exist.
                subprocess.run(
                    ['/home/user/.local/bin/kaggle', *download_args], check=True)
            placeholder.empty()
        except Exception as error:
            placeholder.warning(
                "An error occurred while downloading the data. Please take a screenshot of the whole page and send to the developer.")
            st.write(error)
            st.stop()
    else:
        placeholder.info(
            "**Downloading data.**\nPlease refresh the page after a few minutes.", icon="⏳")
        st.stop()


# URL query-parameter name -> key in the `filters` dict, for boolean filters.
BOOLEAN_QUERY_PARAMS = {
    "validated": "validated_by_human",
    "second_opinion": "second_opinion",
    "axis": "heart_axis",
    "clean": "no_artifacts",
}
# URL query-parameter name -> key in the `filters` dict, for list filters.
LIST_QUERY_PARAMS = {
    "condition": "scp_code",
    "d_class": "diagnostic_class",
}


def query_to_filters():
    """
    Build the filter dict from the page's URL query parameters.

    `id` is the 1-based record number shown to users and is stored 0-based as
    "record_index". A non-integer `id` is ignored, so a random ECG is shown.
    Boolean parameters are True only for the (case-insensitive) text "true".
    """
    parsed = {}
    query_params = st.experimental_get_query_params()
    if "id" in query_params:
        try:
            parsed["record_index"] = int(query_params["id"][0]) - 1
        except ValueError:
            pass  # Malformed id: leave record_index unset.
    for param, key in BOOLEAN_QUERY_PARAMS.items():
        if param in query_params:
            parsed[key] = query_params[param][0].lower() == "true"
    for param, key in LIST_QUERY_PARAMS.items():
        if param in query_params:
            parsed[key] = query_params[param]
    return parsed


def filters_to_query():
    """Write the module-level `filters` dict back into the URL query parameters."""
    query_params = {}
    if "record_index" in filters:
        query_params["id"] = filters["record_index"] + 1
    for param, key in {**BOOLEAN_QUERY_PARAMS, **LIST_QUERY_PARAMS}.items():
        if key in filters:
            query_params[param] = filters[key]
    st.experimental_set_query_params(**query_params)


filters = query_to_filters()


@st.cache_data(ttl=60 * 60)
def load_records():
    """
    Load and convert the ECG records to a DataFrame. One record for each ECG taken.

    The `ecg_id` index is turned back into a column, so the returned frame has
    a 0-based positional index; `filters["record_index"]` is a position in it.
    """
    def optional_int(x):
        return pd.NA if x == '' else int(float(x))

    def optional_float(x):
        return pd.NA if x == '' else float(x)

    def optional_string(x):
        return pd.NA if x == '' else x

    record_df = pd.read_csv(
        path + 'ptbxl_database.csv',
        index_col='ecg_id',
        converters={
            'patient_id': optional_int,
            'age': optional_int,
            # '0' is mapped to 'M'; every other value to 'F'.
            'sex': lambda x: 'M' if x == '0' else 'F',
            'height': optional_float,
            'weight': optional_float,
            'nurse': optional_int,
            'site': optional_int,
            # Stored as a Python dict literal: {scp_code: likelihood, ...}.
            'scp_codes': ast.literal_eval,
            'heart_axis': optional_string,
            'infarction_stadium1': optional_string,
            'infarction_stadium2': optional_string,
            'validated_by': optional_int,
            'baseline_drift': optional_string,
            'static_noise': optional_string,
            'burst_noise': optional_string,
            'electrodes_problems': optional_string,
            'extra_beats': optional_string,
            'pacemaker': optional_string,
        }
    )
    return record_df.reset_index()


try:
    record_df = load_records()
except Exception as error:
    st.warning(
        "An error occurred while loading data. Please refresh the page to try again.")
    st.write(error)
    st.stop()


@st.cache_data(ttl=60 * 60)
def load_annotations():
    """
    Load and convert the ECG annotations to a DataFrame. One row for each condition in SCP code.

    The index is named `scp_code`; rows are sorted by `description`. Missing
    diagnostic classes/subclasses are replaced with 'OTHER'.
    """
    def int_bool(x):
        return x != ''

    def optional_int(x):
        return pd.NA if x == '' else int(float(x))

    def optional_string(x):
        return pd.NA if x == '' else x

    def mandatory_string(x):
        return 'OTHER' if x == '' else x

    annotation_df = pd.read_csv(
        path + 'scp_statements.csv',
        index_col=0,
        converters={
            'diagnostic': int_bool,
            'form': int_bool,
            'rhythm': int_bool,
            'diagnostic_class': mandatory_string,
            'diagnostic_subclass': mandatory_string,
            'AHA code': optional_int,
            'aECG REFID': optional_string,
            'CDISC Code': optional_string,
            'DICOM Code': optional_string,
        }
    )
    annotation_df.index.name = 'scp_code'
    annotation_df.sort_values('description', inplace=True)
    return annotation_df


annotation_df = load_annotations()

# ===============================
# Browsing history
# ===============================

with st.sidebar:
    st.write("**Browsing history:**")
    if not st.session_state['history']:
        st.write('No ECGs viewed yet.')
    else:
        for viewed_index in st.session_state['history']:
            st.write(
                f"""{viewed_index + 1} - {', '.join(record_df.iloc[viewed_index].scp_codes.keys())}""",
                unsafe_allow_html=True)

# ===============================
# ECG Filters
# ===============================

with st.form("filter_form"):
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        filters['validated_by_human'] = st.checkbox(
            "Human-validated", key='new_ecg1',
            value=filters.get('validated_by_human', True),
            help='Filter ECGs with results validated by a human')
    with col2:
        filters['second_opinion'] = st.checkbox(
            "Double-validated", key='new_ecg2',
            value=filters.get('second_opinion', False),
            help='Filter ECGs with results validated twice')
    with col3:
        filters['heart_axis'] = st.checkbox(
            "Heart axis", key='new_ecg3',
            value=filters.get('heart_axis', False),
            help='Filter ECGs with heart axis data')
    with col4:
        filters['no_artifacts'] = st.checkbox(
            "No artifacts", key='new_ecg4',
            value=filters.get('no_artifacts', True),
            help='Filter ECGs with no artifacts (e.g. baseline drift and noises)')

    with st.expander("Filter by class"):
        # Remember the classes requested via the URL so their checkboxes start
        # ticked, then rebuild the filter from the checkbox states below.
        requested_classes = filters.pop('diagnostic_class', [])
        cols = st.columns(2)
        class_df = annotation_df.groupby(['diagnostic_class'])[
            'Statement Category'].apply(set).reset_index()
        class_rows = zip(class_df['diagnostic_class'],
                         class_df['Statement Category'])
        for i, (key, categories) in enumerate(class_rows):
            # Sorted so the label does not depend on set iteration order.
            description = 'Other conditions' if key == 'OTHER' else ', '.join(
                sorted(categories))
            selected_class = cols[i % 2].checkbox(
                description, key=f'filter_class_{i}',
                value=key in requested_classes)
            if selected_class:
                selected = filters.setdefault('diagnostic_class', [])
                if key not in selected:
                    selected.append(key)

    with st.expander("Filter by condition"):
        # Same approach as above for individual SCP codes.
        requested_codes = filters.pop('scp_code', [])
        cols = st.columns(4)
        condition_rows = zip(annotation_df.index, annotation_df['description'])
        for i, (key, description) in enumerate(condition_rows):
            selected_code = cols[i % 4].checkbox(
                description, key=f'filter_condition_{i}',
                value=key in requested_codes)
            if selected_code:
                selected = filters.setdefault('scp_code', [])
                if key not in selected:
                    selected.append(key)

    submitted = st.form_submit_button(
        label='Random ECG', help='Find a new ECG with the selected filters')

    if submitted:
        filters.pop('record_index', None)
        filters_to_query()
        st.session_state["expander_state"] = True


def applyFilter():
    """
    Return the rows of `record_df` that satisfy the module-level `filters`.

    Condition (`scp_code`) and class (`diagnostic_class`) filters are combined
    as a union: a record matches if any of its SCP codes is a selected code or
    belongs to a selected diagnostic class.
    """
    filtered = record_df
    if filters.get('validated_by_human'):
        filtered = filtered[filtered.validated_by_human]
    if filters.get('second_opinion'):
        filtered = filtered[filtered.second_opinion]
    if filters.get('heart_axis'):
        filtered = filtered[filtered.heart_axis.notna()]
    if filters.get('no_artifacts'):
        filtered = filtered[
            filtered.baseline_drift.isna()
            & filtered.static_noise.isna()
            & filtered.burst_noise.isna()
            & filtered.electrodes_problems.isna()
        ]

    if 'scp_code' in filters or 'diagnostic_class' in filters:
        wanted_codes = set(filters.get('scp_code', []))
        if 'diagnostic_class' in filters:
            in_classes = annotation_df['diagnostic_class'].isin(
                filters['diagnostic_class'])
            wanted_codes.update(annotation_df.index[in_classes])
        filtered = filtered[filtered.scp_codes.apply(
            lambda codes: any(code in wanted_codes for code in codes))]
    return filtered


filtered_record_df = applyFilter()

if filtered_record_df.empty:
    st.error('No ECGs found with the selected filters.')
    st.stop()

# Select a random ECG record unless a valid one was requested via the URL.
requested_index = filters.get("record_index")
if requested_index is None or not 0 <= requested_index < len(record_df):
    record = filtered_record_df.iloc[random.randint(
        0, len(filtered_record_df) - 1)]
    # The index of record_df is positional (see load_records), so the label
    # of the filtered row is also its position in record_df.
    filters["record_index"] = record.name
    filters_to_query()
else:
    record = record_df.iloc[requested_index]

# Move the current record to the front of the browsing history.
history = st.session_state['history']
if filters["record_index"] in history:
    history.remove(filters["record_index"])
history.insert(0, filters["record_index"])

st.write(f'*{len(filtered_record_df)} ECGs with the selected filters*')

st.write("----------------------------")

# ===============================
# ECG Verification Status
# ===============================

if record.second_opinion:
    box = st.success
elif record.validated_by_human:
    box = st.info
else:
    box = st.warning

box(f"""
**Autogenerated report:** {'Yes' if record.initial_autogenerated_report else 'No'}

**Human validated:** {'Yes' if record.validated_by_human else 'No'}

**Second opinion:** {'Yes' if record.second_opinion else 'No'}
""")

# ===============================
# Patient Info
# ===============================

col1, col2, col3, col4 = st.columns(4)
with col1:
    st.write(f"**Patient ID:** {record.patient_id}")
    st.write(f"**ECG ID:** {record.ecg_id}")
with col2:
    st.write(f"**Age:** {record.age}")
    st.write(f"**Sex:** {record.sex}")
with col3:
    st.write(f"**Height:** {record.height}")
    st.write(f"**Weight:** {record.weight}")
with col4:
    st.write(f"**Date:** {record.recording_date}")
    st.write(f"**ECG Device:** {record.device}")

# ===============================
# ECG Chart
# ===============================


@st.cache_data(max_entries=2)
def load_raw_data(df, sampling_rate, path):
    """
    Load ECG signals from the raw data files.

    `df` is one record row; 100 Hz uses its `filename_lr`, anything else its
    `filename_hr`. Returns one column per lead (named by wfdb's `sig_name`)
    plus a leading `index` column holding the sample number.
    """
    if sampling_rate == 100:
        data = wfdb.rdsamp(path + df.filename_lr)
    else:
        data = wfdb.rdsamp(path + df.filename_hr)
    data = pd.DataFrame(data[0], columns=data[1]['sig_name']).reset_index()
    return data


def _x_encoding(field, x_domain):
    """Return an untitled quantitative x encoding with a fixed, unpadded domain."""
    return alt.X(field, type='quantitative', title=None,
                 scale=alt.Scale(domain=x_domain, padding=0))


def _y_encoding(field, y_domain):
    """Return an untitled quantitative y encoding with a fixed, unpadded domain."""
    return alt.Y(field, type='quantitative', title=None,
                 scale=alt.Scale(domain=y_domain, padding=0))


def _signal_color(theme):
    """Return the stroke colour used for ECG traces and lead separators."""
    return '#7abaed' if theme == 'dark' else '#05014a'


def _build_grid_frames(x_min, x_max, y_min, y_max):
    """
    Return (major, minor) DataFrames of grid-line segments (x, y, x2, y2).

    Major lines: every 0.5 y-units and every 20 samples.
    Minor lines: every 0.1 y-units and every 4 samples.
    """
    columns = ['x', 'y', 'x2', 'y2']
    major_rows = [[x_min, i / 2, x_max, i / 2]
                  for i in range(int(y_min * 2), int(y_max * 2))]
    major_rows += [[i, y_min, i, y_max] for i in range(x_min, x_max, 20)]
    minor_rows = [[x_min, i / 10, x_max, i / 10]
                  for i in range(int(y_min * 10), int(y_max * 10))]
    minor_rows += [[i, y_min, i, y_max] for i in range(x_min, x_max, 4)]
    return (pd.DataFrame(major_rows, columns=columns),
            pd.DataFrame(minor_rows, columns=columns))


def _rule_layer(segments_df, x_domain, y_domain, stroke):
    """Return a layer drawing each (x, y)-(x2, y2) row of `segments_df` as a line."""
    return alt.Chart(segments_df).mark_rule(clip=True, stroke=stroke).encode(
        x=_x_encoding('x', x_domain),
        x2=alt.X2('x2'),
        y=_y_encoding('y', y_domain),
        y2=alt.Y2('y2'),
        tooltip=alt.value(None),
    )


def _grid_chart(x_domain, y_domain, theme, height):
    """Return the minor+major grid layer with the chart-wide size/config applied."""
    major_df, minor_df = _build_grid_frames(*x_domain, *y_domain)
    return alt.layer(
        _rule_layer(minor_df, x_domain, y_domain,
                    '#252525' if theme == 'dark' else '#dddddd'),
        _rule_layer(major_df, x_domain, y_domain,
                    '#555' if theme == 'dark' else '#bbb'),
    ).properties(
        width=1600,
        height=height,
    ).configure_concat(
        spacing=0
    ).configure_facet(
        spacing=0
    ).configure_axis(
        grid=False,
        labels=False,
    )


def _line_layer(lead_signals, column, x_domain, y_domain, theme):
    """Return a layer drawing `column` of `lead_signals` against its `index` column."""
    return alt.Chart(lead_signals).mark_line(clip=True, stroke=_signal_color(theme)).encode(
        x=_x_encoding('index', x_domain),
        y=_y_encoding(column, y_domain),
        tooltip=alt.value(None),
    )


def _label_layer(text_df, x_domain, y_domain, theme):
    """Return a layer drawing each row of `text_df` (x, y, text) as a lead label."""
    return alt.Chart(text_df).mark_text(
        baseline='middle', align='left', size=20,
        fill=('#fff' if theme == 'dark' else '#020079')).encode(
        text='text',
        x=_x_encoding('x', x_domain),
        y=_y_encoding('y', y_domain),
        tooltip=alt.value(None),
    )


def _add_calibration_pulse(lead_signals, column, sampling_rate):
    """
    Overwrite the tail of `column` (in place) with a rectangular end marker.

    Over the 10 s strip, samples in [48/50, 49/50) are set to 1, samples in
    [49/50, 49.2/50) to 0, and the rest are blanked (NaN). This is presumably
    the usual ECG calibration pulse.
    """
    n = 10 * sampling_rate
    col = lead_signals.columns.get_loc(column)
    # Positional assignment on the frame itself (not chained `df[c].iloc[...]`)
    # so the write also takes effect under pandas Copy-on-Write.
    lead_signals.iloc[int(n * 48 / 50):int(n * 49 / 50), col] = 1
    lead_signals.iloc[int(n * 49 / 50):int(n * 49.2 / 50), col] = 0
    lead_signals.iloc[int(n * 49.2 / 50):, col] = np.nan


def _report_lead_layout(sampling_rate):
    """
    Return the Report-mode layout as a list of {lead, y, start_x, end_x} dicts.

    Row 0 holds a full-length copy of lead II (column 'II '); rows 3..1 hold
    four time windows of three leads each.
    """
    n = 10 * sampling_rate
    edges = [0, int(n * 12 / 50), int(n * 24 / 50), int(n * 36 / 50), int(n)]
    layout = [{'lead': 'II ', 'y': 0, 'start_x': 0, 'end_x': int(n)}]
    lead_columns = (('I', 'II', 'III'), ('AVR', 'AVL', 'AVF'),
                    ('V1', 'V2', 'V3'), ('V4', 'V5', 'V6'))
    for window, leads in enumerate(lead_columns):
        for lead, row in zip(leads, (3, 2, 1)):
            layout.append({
                'lead': lead,
                'y': row,
                'start_x': edges[window],
                'end_x': edges[window + 1],
            })
    return layout


@st.cache_resource(max_entries=2)
def plot_ecg(lead_signals, sampling_rate, chart_mode, theme):
    """
    Draw the ECG chart.

    `chart_mode` 'Continuous' stacks every lead over the full 10 s strip; any
    other value draws the 4-row 'Report' layout. The input frame is copied,
    so the caller's data is not modified.
    """
    alt.renderers.set_embed_options(
        padding={"left": 0, "right": 0, "bottom": 0, "top": 0}
    )
    lead_signals = lead_signals.copy()
    n = 10 * sampling_rate
    x_domain = (0, n)
    text_rows = []

    if chart_mode == 'Continuous':
        y_domain = (-1.5, 34.5)
        lead_names = lead_signals.columns.values[1:]
        leads_count = len(lead_names)
        for i, lead in enumerate(lead_names):
            # Stack leads top to bottom, 3 y-units apart.
            offset = (leads_count - i - 1) * 3
            _add_calibration_pulse(lead_signals, lead, sampling_rate)
            lead_signals[lead] = lead_signals[lead] + offset
            text_rows.append([4, offset + 1, lead])

        chart = _grid_chart(x_domain, y_domain, theme,
                            height=1600 / 50 * 72 + 20)  # 20px padding
        for lead in lead_names:
            chart += _line_layer(lead_signals, lead, x_domain, y_domain, theme)
        chart += _label_layer(pd.DataFrame(text_rows, columns=['x', 'y', 'text']),
                              x_domain, y_domain, theme)
        return chart

    # Report mode: duplicate lead II into a new column for the bottom row.
    lead_signals['II '] = lead_signals['II'].copy()
    lead_config = _report_lead_layout(sampling_rate)
    y_domain = (-1.5, 10.5)

    separator_rows = []
    for config in lead_config:
        lead, row = config['lead'], config['y']
        col = lead_signals.columns.get_loc(lead)
        if config['start_x'] > 0:
            # Blank everything before this lead's window and mark its start.
            lead_signals.iloc[:config['start_x'], col] = np.nan
            separator_rows.append([config['start_x'], row * 3 - 0.5,
                                   config['start_x'], row * 3 + 0.5])
        if config['end_x'] < n:
            lead_signals.iloc[config['end_x']:, col] = np.nan
        else:
            _add_calibration_pulse(lead_signals, lead, sampling_rate)
        lead_signals[lead] = lead_signals[lead] + row * 3
        text_rows.append([config['start_x'] + 4, row * 3 + 1, lead])

    chart = _grid_chart(x_domain, y_domain, theme,
                        height=1600 / 50 * 24 + 20)  # 20px padding
    for config in lead_config:
        chart += _line_layer(lead_signals, config['lead'],
                             x_domain, y_domain, theme)
    chart += _rule_layer(pd.DataFrame(separator_rows, columns=['x', 'y', 'x2', 'y2']),
                         x_domain, y_domain, _signal_color(theme))
    chart += _label_layer(pd.DataFrame(text_rows, columns=['x', 'y', 'text']),
                          x_domain, y_domain, theme)
    return chart


if not st.session_state["expander_state"]:
    chart_mode = st.selectbox(
        'ECG Chart Mode',
        options=('Report', 'Continuous'),
    )
    with st.spinner('Loading ECG...'):
        lead_signals = load_raw_data(record, sampling_rate, path)
        fig = plot_ecg(lead_signals, sampling_rate, chart_mode,
                       st.session_state["theme"])
        st.altair_chart(fig, use_container_width=False)
else:
    st.info('**Loading ECG...**', icon='🔃')

# ===============================
# ECG Analysis
# ===============================


def annotation_path(annotation):
    """Return 'class > subclass > code' for an annotation row, omitting missing levels."""
    has_class = not pd.isna(annotation.diagnostic_class)
    if has_class and not pd.isna(annotation.diagnostic_subclass):
        return f"{annotation.diagnostic_class} > {annotation.diagnostic_subclass} > {annotation.name}"
    if has_class:
        return f"{annotation.diagnostic_class} > {annotation.name}"
    return annotation.name


# Only render the expander when this is the final re-render
if not st.session_state["expander_state"]:
    with st.expander("ECG Analysis", expanded=st.session_state["expander_state"]):
        for code, prob in record.scp_codes.items():
            annotation = annotation_df.loc[code]
            likelihood = "unknown likelihood" if prob == 0 else f"**{prob}%**"
            st.write(f"""
> `{annotation_path(annotation)}` - {likelihood}
>
> {annotation['Statement Category']}
>
> **{annotation['SCP-ECG Statement Description']}**
""")
        st.write("---------------------")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.write(f"**Heart Axis:** {record.heart_axis}")
            st.write(f"**Pacemaker:** {record.pacemaker}")
            st.write(f"**Extra Beats:** {record.extra_beats}")
        with col2:
            st.write(f"**Infarction Stadium 1:** {record.infarction_stadium1}")
            st.write(f"**Infarction Stadium 2:** {record.infarction_stadium2}")
        with col3:
            st.write(f"**Baseline Drift:** {record.baseline_drift}")
            st.write(f"**Electrode Problems:** {record.electrodes_problems}")
        with col4:
            st.write(f"**Static Noise:** {record.static_noise}")
            st.write(f"**Burst Noise:** {record.burst_noise}")
else:
    st.write('**Loading...**')

# ===============================
# Vectorcardiogram
# ===============================

# Kors regression weights: each VCG axis is a weighted sum of these 8 leads.
_KORS_LEADS = ('V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'I', 'II')
_KORS_WEIGHTS = {
    'X': (-0.130, 0.050, -0.010, 0.140, 0.060, 0.540, 0.380, -0.070),
    'Y': (0.060, -0.020, -0.050, 0.060, -0.170, 0.130, -0.070, 0.930),
    'Z': (-0.430, -0.060, -0.140, -0.200, -0.110, 0.310, 0.110, -0.230),
}
kors_transform = [
    {
        'axis': axis,
        'leads': [{'lead': lead, 'weight': weight}
                  for lead, weight in zip(_KORS_LEADS, weights)],
    }
    for axis, weights in _KORS_WEIGHTS.items()
]


def cart2pol(x, y):
    """Convert Cartesian (x, y) to np.array([rho, phi]); phi is np.arctan2(y, x) in radians."""
    rho = np.sqrt(x**2 + y**2)
    phi = np.arctan2(y, x)
    return np.array([rho, phi])


def pol2cart(rho, phi):
    """Convert polar (rho, phi in radians) to np.array([x, y])."""
    x = rho * np.cos(phi)
    y = rho * np.sin(phi)
    return np.array([x, y])


@st.cache_data(max_entries=2)
def calculate_kors_transform(lead_signals):
    """
    Calculate VCG data from the ECG data using the Kors regression transformation.
    https://doi.org/10.3390/s19143072

    Returns a copy of `lead_signals` with added columns X, Y, Z and, for each
    plane, `<plane>_rho` / `<plane>_phi`:
      frontal    = polar(X, Y)
      transverse = polar(X, -Z)
      sagittal   = polar(Z, Y)
    """
    vector_signals = lead_signals.copy()
    for axis in kors_transform:
        # Vectorised over all samples; same summation order as a per-row sum.
        vector_signals[axis['axis']] = sum(
            vector_signals[lead['lead']] * lead['weight'] for lead in axis['leads'])

    x = vector_signals['X'].to_numpy()
    y = vector_signals['Y'].to_numpy()
    z = vector_signals['Z'].to_numpy()
    for plane, horizontal, vertical in (('frontal', x, y),
                                        ('transverse', x, -z),
                                        ('sagittal', z, y)):
        rho, phi = cart2pol(horizontal, vertical)
        vector_signals[f'{plane}_rho'] = rho
        vector_signals[f'{plane}_phi'] = phi
    return vector_signals


@st.cache_resource(max_entries=2)
def plot_vcg(lead_signals, theme):
    """
    Draw the vectorcardiogram.

    Three polar panels (frontal, transverse, sagittal) from the
    `<plane>_phi` / `<plane>_rho` columns of `lead_signals`.
    """
    if theme == 'dark':
        plt.style.use('dark_background')
    else:
        plt.style.use('default')
    plt.rcParams.update({'font.size': 8, 'axes.titlepad': 40})
    line_color = _signal_color(theme)

    fig, ax = plt.subplots(
        1, 3, subplot_kw={'projection': 'polar'}, figsize=(15, 15))
    fig.patch.set_alpha(0.0)
    fig.tight_layout(pad=2.0)

    # (title, caption x-position, caption, plane column prefix)
    panels = (
        ("Frontal Vectorcardiogram", 0.172,
         "0º at +X, positive towards +Y", 'frontal'),
        ("Transverse Vectorcardiogram", 0.5,
         "0º at +X, positive towards -Z", 'transverse'),
        ("Sagittal Vectorcardiogram", 0.829,
         "0º at +Z, positive towards +Y", 'sagittal'),
    )
    for axes, (title, caption_x, caption, plane) in zip(ax, panels):
        axes.set_theta_direction(-1)
        axes.title.set_text(title)
        fig.text(caption_x, 0.662, caption, horizontalalignment="center")
        axes.set_facecolor("none")
        axes.plot(lead_signals[f'{plane}_phi'], lead_signals[f'{plane}_rho'],
                  linewidth=0.5, color=line_color)

    # Detach from pyplot's figure manager so cached/rerun figures don't pile up;
    # the Figure object itself remains renderable.
    plt.close(fig)
    return fig


@st.cache_resource(max_entries=2)
def plot_vcg_3d(lead_signals, h_angle, v_angle, theme):
    """
    Draw the vectorcardiogram in 3D.

    Plots the X, Y, Z columns with the y axis vertical (and inverted), viewed
    from elevation `v_angle` and azimuth `h_angle` (degrees).
    """
    if theme == 'dark':
        plt.style.use('dark_background')
    else:
        plt.style.use('default')
    plt.rcParams.update({'font.size': 8})

    fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, figsize=(10, 8))
    fig.patch.set_alpha(0.0)
    ax.set_facecolor("none")
    ax.plot(lead_signals['X'], lead_signals['Y'], lead_signals['Z'],
            linewidth=0.5, color=_signal_color(theme))
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.invert_yaxis()
    ax.title.set_text("Spatial Vectorcardiogram")
    ax.view_init(v_angle, h_angle, None, vertical_axis='y')

    # See plot_vcg: release pyplot's reference to the figure.
    plt.close(fig)
    return fig


if not st.session_state["expander_state"]:
    with st.expander("Vectorcardiogram (Approximation)", expanded=st.session_state["expander_state"]):
        hd_lead_signals = load_raw_data(record, 500, path)
        vector_signals = calculate_kors_transform(hd_lead_signals)
        fig = plot_vcg(vector_signals, st.session_state["theme"])
        st.pyplot(fig, use_container_width=False)

        col1, col2 = st.columns(spec=[0.2, 0.8])
        with st.spinner('Loading 3D plot...'):
            with col1:
                h_angle = st.slider("Horizontal view angle",
                                    min_value=-180, max_value=180, value=-60, step=5)
                v_angle = st.slider("Vertical view angle",
                                    min_value=-180, max_value=180, value=30, step=5)
            with col2:
                fig3d = plot_vcg_3d(vector_signals, h_angle,
                                    v_angle, st.session_state["theme"])
                st.pyplot(fig3d, use_container_width=False)
else:
    st.info('**Loading VCG...**', icon='🔃')

# Detect browser theme, then trigger the second render pass.
#
# To forcibly collapse the expanders, the whole page is rendered twice.
# In the first render, each expander is replaced by a placeholder text.
# In the second render, the expander is rendered and it defaults to collapsed
# because it did not exist in the previous render.
if st.session_state["expander_state"]:
    theme = st_js(
        """window.getComputedStyle( document.body ,null).getPropertyValue('background-color')""")
    # A result of 0 is treated as "no value from the browser yet".
    if theme != 0:
        color = parse.color(theme)
        # HSL lightness above 50% means a light page background.
        if color.as_hsl_percent_triple()[2] > 50:
            st.session_state["theme"] = "light"
        else:
            st.session_state["theme"] = "dark"

        st.session_state["expander_state"] = False
        # Wait for the client to sync up
        time.sleep(0.05)
        # Start the second re-render
        st.experimental_rerun()