import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure

from utils import (
    load_mimic_data,
    train_predictive_model,
    load_mednist_image,
    create_mock_segmentation,
    apply_threshold,
    extract_entities,
    visualize_model_performance,
    get_clinical_text_examples,
)

# Directory holding the MedNIST AbdomenCT images; files are looked up by a
# zero-padded six-digit index (e.g. 000003.jpeg).
MEDNIST_ABDOMEN_DIR = "data/Mednist/AbdomenCT"

# Column headers for the entity table rendered in the clinical-text tab.
ENTITY_COLUMNS = ["Entity", "Clinical Category", "Original Category"]


# --- MIMIC data -------------------------------------------------------------
def display_mimic_data():
    """Load the MIMIC tables and return the first five rows of each.

    Returns:
        A ``(diagnoses, procedures, prescriptions)`` tuple, each trimmed with
        ``.head(5)``. These are presumably pandas DataFrames produced by
        ``load_mimic_data`` — confirm in ``utils``.
    """
    diagnoses, procedures, prescriptions = load_mimic_data()
    return diagnoses.head(5), procedures.head(5), prescriptions.head(5)


# --- Predictive model -------------------------------------------------------
def train_model():
    """Train the length-of-stay model and summarise its test performance.

    Returns:
        A ``(summary, figure)`` pair: a text summary with the mean squared
        error and R² score (both to two decimals), and the performance plot
        produced by ``visualize_model_performance``.
    """
    model, mse, r2, X_test, y_test = train_predictive_model()
    model_performance_fig = visualize_model_performance(model, X_test, y_test)
    summary = (
        "Model Performance (trained on a medium-sized sample):\n"
        f"Mean Squared Error: {mse:.2f}\n"
        f"R² Score: {r2:.2f}"
    )
    return summary, model_performance_fig


# --- Image segmentation -----------------------------------------------------
def process_image(image_index, threshold):
    """Render an AbdomenCT image next to its mock segmentation and thresholded mask.

    Args:
        image_index: Index of the image inside ``MEDNIST_ABDOMEN_DIR``. Slider
            values may arrive as floats, so the value is rounded to an int
            before being formatted into the zero-padded file name.
        threshold: Cut-off passed to ``apply_threshold``; also shown in the
            third panel's title.

    Returns:
        A matplotlib ``Figure`` with three side-by-side panels: the original
        image, the segmentation heatmap and the thresholded segmentation.
    """
    index = int(round(image_index))
    image_path = f"{MEDNIST_ABDOMEN_DIR}/{index:06d}.jpeg"
    image = load_mednist_image(image_path)
    segmented_output = create_mock_segmentation(image)
    thresholded_output = apply_threshold(segmented_output, threshold)

    panels = (
        (image, "gray", "Original MedNIST Image"),
        (segmented_output, "jet", "Segmentation Heatmap"),
        (
            thresholded_output,
            "gray",
            f"Thresholded Segmentation\n(Threshold: {threshold:.2f})",
        ),
    )

    # Build the figure without pyplot: pyplot keeps every figure it opens in a
    # global registry until it is explicitly closed, so creating one per click
    # would accumulate figures for the lifetime of the server.
    fig = Figure(figsize=(18, 6))
    axes = fig.subplots(1, 3)
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title, fontsize=16, pad=20)
        ax.axis("off")
    # tight_layout computes its own spacing, so no manual subplots_adjust.
    fig.tight_layout()
    return fig


# --- Clinical text analysis -------------------------------------------------
def analyze_clinical_text(text):
    """Extract entities from *text* and render them as an HTML table.

    Args:
        text: Free-text clinical note to analyse.

    Returns:
        An HTML table string (without the DataFrame index) whose columns are
        ``ENTITY_COLUMNS``. ``DataFrame.to_html`` escapes cell contents by
        default, so entity text cannot inject markup.
    """
    # assumes extract_entities returns rows of three values matching
    # ENTITY_COLUMNS — confirm in utils.
    entities = extract_entities(text)
    df = pd.DataFrame(entities, columns=ENTITY_COLUMNS)
    return df.to_html(index=False)


def update_text_input(example):
    """Return the selected dropdown example unchanged so it fills the textbox."""
    return example


# --- Gradio UI --------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal Clinical Data Analysis")

    with gr.Tab("1. MIMIC-IV Data Analysis"):
        gr.Markdown("displaying a sample of mimic-iv data (1000 rows) for demonstration purposes.")
        diagnoses_output = gr.DataFrame(label="Diagnoses")
        procedures_output = gr.DataFrame(label="Procedures")
        prescriptions_output = gr.DataFrame(label="Prescriptions")
        mimic_button = gr.Button("Load MIMIC Data")
        mimic_button.click(
            display_mimic_data,
            inputs=None,
            outputs=[diagnoses_output, procedures_output, prescriptions_output],
        )

    with gr.Tab("2. Predictive Model for Length of Stay"):
        gr.Markdown("training a model on a medium-sized sample (2000 rows) for a balance of performance and speed.")
        model_output = gr.Textbox(label="Model Performance")
        model_plot = gr.Plot()
        model_button = gr.Button("Train Model")
        model_button.click(train_model, inputs=None, outputs=[model_output, model_plot])

    with gr.Tab("3. Image Segmentation"):
        gr.Markdown("explore mock segmentation on mednist abdomenct images. this demo uses image processing techniques to create a visually appealing segmentation effect.")
        image_index = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Image Index")
        threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Segmentation Threshold")
        image_output = gr.Plot()
        image_button = gr.Button("Process Image")
        image_button.click(process_image, inputs=[image_index, threshold_slider], outputs=image_output)

    with gr.Tab("4. Clinical Text Analysis"):
        gr.Markdown("extract named entities from clinical text using an improved clinical ner model.")
        with gr.Row():
            text_input = gr.Textbox(
                label="Enter clinical text",
                value="Patient shows symptoms of COVID-19, including mild respiratory distress and fever. The X-ray indicates possible lung opacities.",
            )
            example_dropdown = gr.Dropdown(
                choices=get_clinical_text_examples(),
                label="Select an example",
            )
        text_output = gr.HTML()
        text_button = gr.Button("Extract Entities")
        text_button.click(analyze_clinical_text, inputs=text_input, outputs=text_output)
        example_dropdown.change(update_text_input, inputs=example_dropdown, outputs=text_input)


if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. for
    # Gradio's reload mode or tests) does not start a server.
    demo.launch(share=True)