# Clinical-Demo / utils.py
# Uploaded by dgrant6 ("Upload 60 files", commit 67fdc2e).
import pandas as pd
import matplotlib.pyplot as plt
from monai.transforms import LoadImage, EnsureChannelFirst, Resize
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import torch
from monai.networks.nets import UNet
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from scipy import ndimage
import numpy as np
import warnings
# NOTE(review): this silences *all* warnings for the whole interpreter process,
# not only those raised by this module -- consider scoping it.
warnings.filterwarnings("ignore")

# ---- NER model initialization ----
# Tokenizer, model and pipeline are created once, when this module is imported,
# and shared by every extract_entities() call through ner_pipeline.
MODEL_NAME = "d4data/biomedical-ner-all"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    # Grouped output: merge_entities reads the entity_group/word/start/end keys.
    aggregation_strategy="simple",
)
# ************** mimic data loading functions ************
def load_mimic_data(data_dir='data/mimic-iv-clinical-database-demo', nrows=1000):
    """Load sample MIMIC-IV diagnoses, procedures and prescriptions tables.

    Args:
        data_dir: Root of the MIMIC-IV demo extract; the CSVs are read from
            ``<data_dir>/hosp/``. Defaults to the previously hard-coded path.
        nrows: Maximum number of rows read from each CSV (``None`` reads all).

    Returns:
        Tuple ``(diagnoses, procedures, prescriptions)`` of DataFrames.
    """
    hosp_dir = f"{data_dir}/hosp"
    table_names = ('diagnoses_icd', 'procedures_icd', 'prescriptions')
    diagnoses, procedures, prescriptions = (
        pd.read_csv(f"{hosp_dir}/{name}.csv", nrows=nrows) for name in table_names
    )
    return diagnoses, procedures, prescriptions
def load_mimic_demo_data(data_dir='data/mimic-iv-clinical-database-demo', nrows=5000):
    """Load and join MIMIC-IV demo admissions, patients, lab events and ICU stays.

    Admissions are joined to patients on ``subject_id``. Lab events and ICU
    stays are joined on ``subject_id`` *and* ``hadm_id``, so each row pairs
    measurements with the admission they were recorded in. All joins are
    inner joins.

    Args:
        data_dir: Root of the MIMIC-IV demo extract (``hosp/`` and ``icu/``
            subdirectories). Defaults to the previously hard-coded path.
        nrows: Maximum number of rows read from each CSV (``None`` reads all).

    Returns:
        The merged DataFrame.
    """
    hosp_dir = f"{data_dir}/hosp"
    admissions = pd.read_csv(f"{hosp_dir}/admissions.csv", nrows=nrows)
    patients = pd.read_csv(f"{hosp_dir}/patients.csv", nrows=nrows)
    labevents = pd.read_csv(f"{hosp_dir}/labevents.csv", nrows=nrows)
    icustays = pd.read_csv(f"{data_dir}/icu/icustays.csv", nrows=nrows)

    merged_data = pd.merge(admissions, patients, on='subject_id', how='inner')
    # Lab events and ICU stays belong to a specific hospital admission. Joining
    # on subject_id alone would pair every admission of a patient with the labs
    # and ICU stays of all their other admissions (a per-patient cross product)
    # and would also split hadm_id into hadm_id_x / hadm_id_y.
    # Lab rows without an hadm_id cannot be attributed to an admission and are
    # dropped by the inner join.
    admission_keys = ['subject_id', 'hadm_id']
    merged_data = pd.merge(merged_data, labevents, on=admission_keys, how='inner')
    merged_data = pd.merge(merged_data, icustays, on=admission_keys, how='inner')
    return merged_data
# ************* predictive model functions **************
def preprocess_data(data):
    """Return the numeric columns of ``data`` with missing values imputed.

    Non-numeric columns are discarded. Each remaining column's NaNs are filled
    with that column's median. Columns with no observed values at all are
    dropped: their median is NaN, so imputation would leave them entirely
    missing and break downstream estimators.

    Args:
        data: Input DataFrame.

    Returns:
        A new DataFrame containing only numeric, NaN-free columns.
    """
    numeric_data = data.select_dtypes(include=['number'])
    numeric_data = numeric_data.dropna(axis=1, how='all')
    return numeric_data.fillna(numeric_data.median())
def train_predictive_model():
    """Train a random-forest regressor that predicts length of stay (``los``).

    Loads the joined MIMIC-IV demo data and keeps the numeric columns (with
    missing values imputed). It then subsamples up to 2000 rows and holds out
    30% of them for evaluation.

    Returns:
        Tuple ``(model, mse, r2, X_test, y_test)``. ``X_test`` is the
        standardized hold-out feature matrix (the scaling the model was
        trained with), and ``y_test`` holds the matching ``los`` values.
    """
    data = preprocess_data(load_mimic_demo_data())
    data = data.sample(n=min(2000, len(data)), random_state=42)
    X, y = data.drop('los', axis=1), data['los']
    # Split before scaling so the scaler's mean/variance come from the training
    # rows only; fitting it on every row leaks hold-out statistics into training.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    model = RandomForestRegressor(n_estimators=50, random_state=42)
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    mse = mean_squared_error(y_test, predictions)
    r2 = r2_score(y_test, predictions)
    return model, mse, r2, X_test, y_test
def visualize_model_performance(model, X_test, y_test):
    """Plot a fitted regressor's predictions against the true length of stay.

    Args:
        model: Fitted estimator exposing ``predict``.
        X_test: Feature matrix in the form the model was trained on.
        y_test: True target values (must support ``.min()`` / ``.max()``).

    Returns:
        The matplotlib Figure, which is also pyplot's current figure. Its
        title reports MSE and R².
    """
    predicted = model.predict(X_test)
    error = mean_squared_error(y_test, predicted)
    score = r2_score(y_test, predicted)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(y_test, predicted, alpha=0.3)
    lo, hi = y_test.min(), y_test.max()
    # Dashed red identity line: points on it are perfect predictions.
    ax.plot([lo, hi], [lo, hi], 'r--', lw=2)
    ax.set_title(f"Model Predictions vs Actuals\nMSE: {error:.2f}, R²: {score:.2f}")
    ax.set_xlabel("Actual Length of Stay")
    ax.set_ylabel("Predicted Length of Stay")
    ax.grid(True)
    return fig
# ************ image segmentation functions ************
def load_mednist_image(path):
    """Load an image file and return it as a 256x256 NumPy array.

    The image is read with MONAI ``LoadImage``, moved to channel-first layout,
    and resized so its spatial dimensions are (256, 256). Size-1 dimensions
    are then squeezed out, so a single-channel image comes back 2-D.
    """
    image = LoadImage(image_only=True)(path)
    for step in (EnsureChannelFirst(), Resize(spatial_size=(256, 256))):
        image = step(image)
    return image.squeeze().numpy()
def create_mock_segmentation(image):
    """Create a mock segmentation map from edge strength.

    The image is Gaussian-smoothed (sigma=2). Sobel gradients along axes 0
    and 1 are then combined into a gradient magnitude, which is min-max
    normalized to [0, 1].

    Args:
        image: Array-like image; axes 0 and 1 are treated as spatial.

    Returns:
        Float array of the same shape. It is all zeros when the image has no
        gradient at all (e.g. a constant image), instead of the NaNs a 0/0
        normalization would produce.
    """
    image = np.asarray(image)
    # gaussian_filter/sobel return the input dtype by default; on integer
    # images that truncates the smoothing and wraps negative gradients.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    smooth = ndimage.gaussian_filter(image, sigma=2)
    grad_axis0 = ndimage.sobel(smooth, axis=0)
    grad_axis1 = ndimage.sobel(smooth, axis=1)
    magnitude = np.hypot(grad_axis0, grad_axis1)
    low = magnitude.min()
    span = magnitude.max() - low
    if span == 0:
        return np.zeros_like(magnitude)
    return (magnitude - low) / span
def apply_threshold(segmentation, threshold):
    """Binarize a segmentation map.

    Returns a float array holding 1.0 where ``segmentation`` is strictly
    greater than ``threshold`` and 0.0 elsewhere.
    """
    mask = segmentation > threshold
    return mask.astype(float)
# ********** clinical text analysis (nlp) functions ************
def merge_entities(entities, max_gap=0):
    """Merge consecutive NER entities of the same type into single spans.

    An entity is folded into the previous merged span when both conditions
    hold:

    * they share ``entity_group``;
    * it starts at most ``max_gap`` characters after that span ends.

    With the default ``max_gap=0``, only touching or overlapping pieces are
    merged (e.g. sub-word fragments of one word). These pieces are joined
    with no separator. Pieces separated by a positive gap are joined with a
    single space. Any ``##`` word-piece markers are stripped from appended
    pieces.

    Args:
        entities: Sequence of dicts with ``entity_group``, ``word``, ``start``
            and ``end`` keys, in text order.
        max_gap: Largest character gap between spans that still merges.

    Returns:
        A new list of entity dicts; the caller's dicts are not modified.
    """
    merged = []
    for entity in entities:
        if merged:
            last = merged[-1]
            gap = entity['start'] - last['end']
            if entity['entity_group'] == last['entity_group'] and gap <= max_gap:
                # Touching pieces come from one word; a positive gap separated words.
                separator = ' ' if gap > 0 else ''
                last['end'] = entity['end']
                last['word'] = last['word'] + separator + entity['word'].replace('##', '')
                continue
        # Copy so that extending this span later does not mutate the caller's dict.
        merged.append(dict(entity))
    return merged
# Model entity_group label -> broader clinical category.
# 'Disease_disorder' and 'Medication' are the label names listed for
# d4data/biomedical-ner-all (per its model card -- verify against
# model.config.id2label); previously they fell through to 'OTHER'.
# 'DISEASE' and 'DRUG' are kept for backward compatibility.
_CLINICAL_CATEGORY_MAPPING = {
    'DISEASE': 'DIAGNOSIS',
    'Disease_disorder': 'DIAGNOSIS',
    'Sign_symptom': 'SYMPTOM',
    'DRUG': 'MEDICATION',
    'Medication': 'MEDICATION',
    'Diagnostic_procedure': 'PROCEDURE',
    'Therapeutic_procedure': 'PROCEDURE',
    'Biological_structure': 'ANATOMY',
    'Severity': 'MODIFIER',
    'Detailed_description': 'DESCRIPTION',
    'Clinical_event': 'EVENT',
    'Lab_value': 'LAB_RESULT',
    'Date': 'TEMPORAL',
    'Age': 'DEMOGRAPHIC',
    'Sex': 'DEMOGRAPHIC',
}


def map_to_clinical_category(entity_group):
    """Map an NER ``entity_group`` label to a broader clinical category.

    Returns ``'OTHER'`` for labels with no mapping.
    """
    return _CLINICAL_CATEGORY_MAPPING.get(entity_group, 'OTHER')
def extract_entities(text):
    """Extract clinical entities from free text.

    Runs the module-level ``ner_pipeline`` on ``text`` and merges adjacent
    same-type entities with ``merge_entities``.

    Returns:
        List of ``(word, clinical_category, original_entity_group)`` tuples,
        in the order the merged entities appear.
    """
    entities = merge_entities(ner_pipeline(text))
    return [
        (ent['word'], map_to_clinical_category(ent['entity_group']), ent['entity_group'])
        for ent in entities
    ]
_CLINICAL_TEXT_EXAMPLES = (
    "Patient shows symptoms of COVID-19, including mild respiratory distress and fever. The X-ray indicates possible lung opacities.",
    "73-year-old male with a history of hypertension and type 2 diabetes presents with chest pain and shortness of breath. ECG shows ST-segment elevation.",
    "29-year-old female, 32 weeks pregnant, reports severe headache and blurred vision. Blood pressure reading: 160/100 mmHg.",
    "45-year-old patient diagnosed with stage 3 colorectal cancer. Started on FOLFOX chemotherapy regimen. Experiencing nausea and fatigue post-treatment.",
    "18-year-old male admitted after a motor vehicle accident. CT scan reveals internal bleeding and a fractured femur. Prepped for emergency surgery.",
)


def get_clinical_text_examples():
    """Return example clinical notes for demonstrating entity extraction.

    A fresh list is built on every call, so callers may modify it freely.
    """
    return list(_CLINICAL_TEXT_EXAMPLES)