Clinical-Demo /
dgrant6's picture
Upload 60 files
67fdc2e verified
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure

from utils import (
    load_mimic_data, train_predictive_model, load_mednist_image,
    create_mock_segmentation, apply_threshold,
    extract_entities, visualize_model_performance, get_clinical_text_examples,
)
# show mimic data function
def display_mimic_data():
    """Return a five-row preview of each table produced by ``load_mimic_data``.

    Returns:
        A 3-tuple ``(diagnoses, procedures, prescriptions)`` where each item
        is the result of calling ``.head(5)`` on the corresponding table.
        The tables are presumably pandas DataFrames -- confirm against
        ``utils.load_mimic_data``.
    """
    diagnoses_table, procedures_table, prescriptions_table = load_mimic_data()
    return (
        diagnoses_table.head(5),
        procedures_table.head(5),
        prescriptions_table.head(5),
    )
# predictive model function
def train_model():
    """Train the predictive model and report its metrics.

    Returns:
        A 2-tuple ``(summary, figure)``: ``summary`` is a three-line text
        report with the mean squared error and R² score formatted to two
        decimals, and ``figure`` is whatever
        ``visualize_model_performance`` returns for the held-out data.
    """
    fitted_model, mse, r2, features_test, target_test = train_predictive_model()
    performance_figure = visualize_model_performance(
        fitted_model, features_test, target_test
    )
    summary = (
        "Model Performance (trained on a medium-sized sample):\n"
        f"Mean Squared Error: {mse:.2f}\n"
        f"R² Score: {r2:.2f}"
    )
    return summary, performance_figure
# show image segmentation function
def process_image(image_index, threshold):
    """Render an image, its mock segmentation and the thresholded mask.

    Args:
        image_index: Number of the image to load; it is zero-padded to six
            digits to form the file name under ``data/Mednist/AbdomenCT/``.
            Integer-valued floats (e.g. ``3.0``) are accepted.
        threshold: Cut-off passed to ``apply_threshold`` and shown (two
            decimals) in the title of the third panel.

    Returns:
        A matplotlib ``Figure`` with three side-by-side panels: the original
        image, the segmentation heatmap, and the thresholded segmentation.
    """
    # The ``d`` format code raises ValueError for floats, so coerce first in
    # case the UI hands the slider value over as a float.
    image_path = f'data/Mednist/AbdomenCT/{int(image_index):06d}.jpeg'
    image = load_mednist_image(image_path)
    segmented_output = create_mock_segmentation(image)
    thresholded_output = apply_threshold(segmented_output, threshold)

    # Build the figure directly instead of through pyplot: pyplot keeps a
    # global reference to every figure it creates, so in a long-running app
    # each call would leak one figure unless it were explicitly closed.
    fig = Figure(figsize=(18, 6))
    axes = fig.subplots(1, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    panels = (
        (image, 'gray', "Original MedNIST Image"),
        (segmented_output, 'jet', "Segmentation Heatmap"),
        (
            thresholded_output,
            'gray',
            f"Thresholded Segmentation\n(Threshold: {threshold:.2f})",
        ),
    )
    for axis, (pixels, colormap, title) in zip(axes, panels):
        axis.imshow(pixels, cmap=colormap)
        axis.set_title(title, fontsize=16, pad=20)

    return fig
# clinical text analysis functions
def analyze_clinical_text(text):
    """Extract entities from ``text`` and return them as an HTML table.

    Args:
        text: Clinical free text handed to ``extract_entities``.

    Returns:
        An HTML ``<table>`` string (row index omitted) with the columns
        ``Entity``, ``Clinical Category`` and ``Original Category``.
    """
    column_names = ['Entity', 'Clinical Category', 'Original Category']
    entity_table = pd.DataFrame(extract_entities(text), columns=column_names)
    return entity_table.to_html(index=False)
def update_text_input(example):
    """Return ``example`` unchanged.

    Used as an event callback: the selected dropdown value is passed in and
    the same value is handed back as the new content of the text box.
    """
    return example
# gradio
# Builds the four-tab demo UI. Each tab declares its output components, a
# trigger button, and wires the button's click event to the matching handler
# defined above.
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal Clinical Data Analysis")

    with gr.Tab("1. MIMIC-IV Data Analysis"):
        gr.Markdown("displaying a sample of mimic-iv data (1000 rows) for demonstration purposes.")
        diagnoses_output = gr.DataFrame(label="Diagnoses")
        procedures_output = gr.DataFrame(label="Procedures")
        prescriptions_output = gr.DataFrame(label="Prescriptions")
        mimic_button = gr.Button("Load MIMIC Data")
        # display_mimic_data returns three tables, one per output component.
        mimic_button.click(
            display_mimic_data,
            inputs=None,
            outputs=[diagnoses_output, procedures_output, prescriptions_output],
        )

    with gr.Tab("2. Predictive Model for Length of Stay"):
        gr.Markdown("training a model on a medium-sized sample (2000 rows) for a balance of performance and speed.")
        model_output = gr.Textbox(label="Model Performance")
        model_plot = gr.Plot()
        model_button = gr.Button("Train Model")
        # train_model returns (summary text, figure).
        model_button.click(
            train_model,
            inputs=None,
            outputs=[model_output, model_plot],
        )

    with gr.Tab("3. Image Segmentation"):
        gr.Markdown("explore mock segmentation on mednist abdomenct images. this demo uses image processing techniques to create a visually appealing segmentation effect.")
        image_index = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Image Index")
        threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Segmentation Threshold")
        image_output = gr.Plot()
        image_button = gr.Button("Process Image")
        image_button.click(
            process_image,
            inputs=[image_index, threshold_slider],
            outputs=image_output,
        )

    with gr.Tab("4. Clinical Text Analysis"):
        gr.Markdown("extract named entities from clinical text using an improved clinical ner model.")
        # NOTE(review): the extent of this Row was lost with the file's
        # indentation; grouping the text box with the example picker is a
        # reconstruction -- confirm the intended layout.
        with gr.Row():
            text_input = gr.Textbox(
                label="Enter clinical text",
                value="Patient shows symptoms of COVID-19, including mild respiratory distress and fever. The X-ray indicates possible lung opacities.",
            )
            # NOTE(review): the dropdown had no choices as written;
            # get_clinical_text_examples is imported and otherwise unused, so
            # it is presumably the intended source -- confirm it returns a
            # list of example strings.
            example_dropdown = gr.Dropdown(
                choices=get_clinical_text_examples(),
                label="Select an example",
            )
        text_output = gr.HTML()
        text_button = gr.Button("Extract Entities")
        text_button.click(
            analyze_clinical_text,
            inputs=text_input,
            outputs=text_output,
        )
        # Picking an example copies it into the text box.
        example_dropdown.change(update_text_input, inputs=example_dropdown, outputs=text_input)