Spaces:
Sleeping
Sleeping
import pandas as pd | |
import matplotlib.pyplot as plt | |
from monai.transforms import LoadImage, EnsureChannelFirst, Resize | |
from sklearn.model_selection import train_test_split | |
from sklearn.ensemble import RandomForestRegressor | |
from sklearn.metrics import mean_squared_error, r2_score | |
from sklearn.preprocessing import StandardScaler | |
import torch | |
from monai.networks.nets import UNet | |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline | |
from scipy import ndimage | |
import numpy as np | |
import warnings | |
# Install a catch-all "ignore" filter: with no message/category/module
# arguments this suppresses every warning raised anywhere in the process
# from here on, including during the model loading below.
# NOTE(review): this also hides warnings from pandas/sklearn/MONAI used later
# in this module — consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")
# ner model initialization
# Hugging Face model id for the biomedical token-classification checkpoint.
MODEL_NAME = "d4data/biomedical-ner-all"
# Loaded once at import time. from_pretrained may need to download the
# checkpoint on first use — TODO confirm caching/offline behaviour in deployment.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# aggregation_strategy="simple" makes the pipeline return grouped entities as
# dicts; merge_entities/extract_entities read their 'entity_group', 'word',
# 'start' and 'end' keys.
ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
# ************** mimic data loading functions ************ | |
def load_mimic_data(data_dir='data/mimic-iv-clinical-database-demo', nrows=1000):
    """Load sample MIMIC-IV diagnoses, procedures and prescriptions tables.

    Parameters
    ----------
    data_dir : str or os.PathLike, optional
        Root of the MIMIC-IV demo extract. The CSVs are read from its
        ``hosp`` sub-directory (``diagnoses_icd.csv``, ``procedures_icd.csv``,
        ``prescriptions.csv``).
    nrows : int or None, optional
        Maximum number of rows read from each CSV. ``None`` reads every row.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(diagnoses, procedures, prescriptions)``.
    """
    import os

    hosp_dir = os.path.join(data_dir, 'hosp')
    diagnoses = pd.read_csv(os.path.join(hosp_dir, 'diagnoses_icd.csv'), nrows=nrows)
    procedures = pd.read_csv(os.path.join(hosp_dir, 'procedures_icd.csv'), nrows=nrows)
    prescriptions = pd.read_csv(os.path.join(hosp_dir, 'prescriptions.csv'), nrows=nrows)
    return diagnoses, procedures, prescriptions
def load_mimic_demo_data(data_dir='data/mimic-iv-clinical-database-demo', nrows=5000):
    """Load MIMIC-IV admissions, patients, lab events and ICU stays, joined.

    Parameters
    ----------
    data_dir : str or os.PathLike, optional
        Root of the MIMIC-IV demo extract. It must contain
        ``hosp/admissions.csv``, ``hosp/patients.csv``, ``hosp/labevents.csv``
        and ``icu/icustays.csv``.
    nrows : int or None, optional
        Maximum number of rows read from each CSV. ``None`` reads every row.

    Returns
    -------
    pandas.DataFrame
        Inner join of the four tables. Patients are joined on ``subject_id``.
        Lab events and ICU stays are joined on ``subject_id`` *and*
        ``hadm_id``, so each row pairs an admission only with lab events and
        ICU stays recorded for that same admission. Lab events without a
        ``hadm_id`` are dropped by the inner join.
    """
    import os

    hosp_dir = os.path.join(data_dir, 'hosp')
    icu_dir = os.path.join(data_dir, 'icu')
    admissions = pd.read_csv(os.path.join(hosp_dir, 'admissions.csv'), nrows=nrows)
    patients = pd.read_csv(os.path.join(hosp_dir, 'patients.csv'), nrows=nrows)
    labevents = pd.read_csv(os.path.join(hosp_dir, 'labevents.csv'), nrows=nrows)
    icustays = pd.read_csv(os.path.join(icu_dir, 'icustays.csv'), nrows=nrows)

    # The patients table is joined on subject_id only.
    merged_data = pd.merge(admissions, patients, on='subject_id', how='inner')
    # Lab events and ICU stays carry their own hadm_id. Joining on subject_id
    # alone would cross every admission of a patient with every lab/stay of
    # that patient, including ones from other admissions. It would also split
    # hadm_id into hadm_id_x / hadm_id_y.
    admission_keys = ['subject_id', 'hadm_id']
    merged_data = pd.merge(merged_data, labevents, on=admission_keys, how='inner')
    merged_data = pd.merge(merged_data, icustays, on=admission_keys, how='inner')
    return merged_data
# ************* predictive model functions ************** | |
def preprocess_data(data):
    """Return the numeric columns of ``data`` with missing values imputed.

    Non-numeric columns are discarded. Each remaining NaN is replaced by the
    median of its own column. A column that is entirely NaN has a NaN median
    and therefore stays NaN.
    """
    numeric_frame = data.select_dtypes(include=['number'])
    column_medians = numeric_frame.median()
    return numeric_frame.fillna(value=column_medians)
def train_predictive_model():
    """Train a random-forest regressor that predicts length of stay (``los``).

    Loads the merged MIMIC-IV demo table via ``load_mimic_demo_data`` and keeps
    its numeric columns with median imputation (``preprocess_data``). It then
    samples at most 2000 rows, holds out 30% of them for evaluation, and
    standardises the features with statistics fitted on the training split only.

    Returns
    -------
    tuple
        ``(model, mse, r2, X_test, y_test)``:
        - the fitted ``RandomForestRegressor``;
        - its mean squared error and R² on the held-out rows;
        - the standardised held-out feature matrix (NumPy array);
        - the matching ``los`` targets.

    Raises
    ------
    KeyError
        If the preprocessed table has no numeric ``los`` column.
    """
    data = preprocess_data(load_mimic_demo_data())
    data = data.sample(n=min(2000, len(data)), random_state=42)
    X, y = data.drop('los', axis=1), data['los']

    # Split *before* scaling. Fitting the scaler on all rows would let the
    # held-out rows' mean/variance leak into training-time preprocessing.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # NOTE(review): the merged table can hold many rows per ICU stay (one per
    # matching lab event). Rows of the same stay, with the same ``los``, may
    # therefore land on both sides of this random split and inflate the test
    # score. Consider a grouped split (e.g. by stay_id) if that column exists.
    # Named ``regressor`` so it does not shadow the module-level NER ``model``.
    regressor = RandomForestRegressor(n_estimators=50, random_state=42)
    regressor.fit(X_train, y_train)
    predictions = regressor.predict(X_test)
    mse = mean_squared_error(y_test, predictions)
    r2 = r2_score(y_test, predictions)
    return regressor, mse, r2, X_test, y_test
def visualize_model_performance(model, X_test, y_test):
    """Plot a regressor's predicted vs. actual length of stay.

    Calls ``model.predict(X_test)`` and scatters the predictions against
    ``y_test``. It also draws a dashed red identity line spanning the range of
    ``y_test`` and puts the MSE and R² of those predictions in the title.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure, which is also pyplot's current figure.
        NOTE(review): pyplot keeps figures open until ``plt.close`` is called,
        so callers invoking this repeatedly may want to close the result.
    """
    y_pred = model.predict(X_test)
    err = mean_squared_error(y_test, y_pred)
    fit_r2 = r2_score(y_test, y_pred)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(y_test, y_pred, alpha=0.3)
    lo, hi = y_test.min(), y_test.max()
    ax.plot([lo, hi], [lo, hi], 'r--', lw=2)
    ax.set_title(f"Model Predictions vs Actuals\nMSE: {err:.2f}, R²: {fit_r2:.2f}")
    ax.set_xlabel("Actual Length of Stay")
    ax.set_ylabel("Predicted Length of Stay")
    ax.grid(True)
    return fig
# ************ image segmentation functions ************ | |
def load_mednist_image(path):
    """Read an image file with MONAI and return it resized as a NumPy array.

    Processing steps:
    - read the image with ``LoadImage(image_only=True)``;
    - move/add the channel axis to the front with ``EnsureChannelFirst``;
    - resize the spatial dimensions to 256x256;
    - ``squeeze`` away singleton axes.

    For a single-channel 2D image the result is therefore 256x256. With
    several channels the channel axis survives the squeeze.
    NOTE(review): MONAI's 2D image readers may order axes differently from
    PIL/matplotlib — verify orientation when displaying the result.
    """
    reader = LoadImage(image_only=True)
    add_channel = EnsureChannelFirst()
    resize_256 = Resize(spatial_size=(256, 256))
    resized = resize_256(add_channel(reader(path)))
    return resized.squeeze().numpy()
def create_mock_segmentation(image):
    """Build a mock "segmentation" map from the edges of an image.

    The image is Gaussian-smoothed (sigma=2), Sobel gradients are taken along
    axes 0 and 1, and the gradient magnitude is min-max normalised to [0, 1].

    Parameters
    ----------
    image : array_like
        Image to process. It is assumed to be 2D, since only axes 0 and 1 are
        differentiated. Non-floating-point images are converted to float64
        first; floating-point images keep their dtype.

    Returns
    -------
    numpy.ndarray
        Edge-magnitude map with the same shape as ``image``, scaled to [0, 1].
        If the magnitude is constant (e.g. a flat image), an all-zero map is
        returned instead of dividing 0 by 0.
    """
    image = np.asarray(image)
    # scipy.ndimage writes results in the input dtype by default. For integer
    # images the blur would truncate and the signed Sobel response / squaring
    # would wrap around, so work in floating point.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    smooth = ndimage.gaussian_filter(image, sigma=2)
    sobel_h = ndimage.sobel(smooth, axis=0)
    sobel_v = ndimage.sobel(smooth, axis=1)
    magnitude = np.sqrt(sobel_h**2 + sobel_v**2)
    lowest = magnitude.min()
    value_range = magnitude.max() - lowest
    if value_range == 0:
        # Constant magnitude: min-max scaling would yield NaN everywhere.
        return np.zeros_like(magnitude)
    return (magnitude - lowest) / value_range
def apply_threshold(segmentation, threshold):
    """Binarise a segmentation map.

    Returns a float array holding 1.0 where ``segmentation`` is strictly
    greater than ``threshold`` and 0.0 elsewhere. Values equal to the
    threshold become 0.0.
    """
    above_threshold = segmentation > threshold
    return above_threshold.astype(float)
# ********** clinical text analysis (nlp) functions ************ | |
def merge_entities(entities):
    """Merge consecutive NER entities of the same type whose spans touch.

    Parameters
    ----------
    entities : list of dict
        Entity dicts with 'entity_group', 'word', 'start' and 'end' keys,
        assumed to be in text order (as produced by the aggregated NER
        pipeline).

    Returns
    -------
    list of dict
        A new list of shallow-copied entity dicts; the input dicts are not
        modified. An entity is folded into the previous one when both share
        'entity_group' and it starts no later than the previous one ends. The
        merged 'end' comes from the later entity, and the later word is
        appended with any '##' word-piece markers removed.
    """
    merged = []
    for entity in entities:
        previous = merged[-1] if merged else None
        if (
            previous is None
            or entity['entity_group'] != previous['entity_group']
            or entity['start'] > previous['end']
        ):
            # Copy so extending the merged entity below never mutates the
            # caller's dicts.
            merged.append(dict(entity))
        else:
            # start <= previous end: there is no gap between the two spans in
            # the source text (e.g. 'hyper' + '##tension'), so join the pieces
            # without inserting a space.
            previous['end'] = entity['end']
            previous['word'] = previous['word'] + entity['word'].replace('##', '')
    return merged
# Broader clinical category for each NER ``entity_group`` label.
# NOTE(review): the d4data/biomedical-ner-all checkpoint appears to use
# MACCROBAT-style labels, naming diseases 'Disease_disorder' and drugs
# 'Medication'. The upper-case 'DISEASE' / 'DRUG' keys alone would then never
# match. Both spellings are kept — confirm against model.config.id2label.
_CLINICAL_CATEGORY_MAPPING = {
    'DISEASE': 'DIAGNOSIS',
    'Disease_disorder': 'DIAGNOSIS',
    'Sign_symptom': 'SYMPTOM',
    'DRUG': 'MEDICATION',
    'Medication': 'MEDICATION',
    'Diagnostic_procedure': 'PROCEDURE',
    'Therapeutic_procedure': 'PROCEDURE',
    'Biological_structure': 'ANATOMY',
    'Severity': 'MODIFIER',
    'Detailed_description': 'DESCRIPTION',
    'Clinical_event': 'EVENT',
    'Lab_value': 'LAB_RESULT',
    'Date': 'TEMPORAL',
    'Age': 'DEMOGRAPHIC',
    'Sex': 'DEMOGRAPHIC',
}


def map_to_clinical_category(entity_group):
    """Map a model ``entity_group`` label to a broader clinical category.

    Labels are matched case-sensitively against a fixed table. Any label
    not in the table maps to ``'OTHER'``.
    """
    return _CLINICAL_CATEGORY_MAPPING.get(entity_group, 'OTHER')
def extract_entities(text):
    """Run the NER pipeline on ``text`` and attach a clinical category to each entity.

    Raw pipeline output is first passed through ``merge_entities``. Each
    resulting entity becomes a ``(word, clinical_category, entity_group)``
    tuple:
    - ``clinical_category`` comes from ``map_to_clinical_category``;
    - ``entity_group`` is the model's own label.
    """
    entities = merge_entities(ner_pipeline(text))
    return [
        (ent['word'], map_to_clinical_category(ent['entity_group']), ent['entity_group'])
        for ent in entities
    ]
def get_clinical_text_examples():
    """Return the example clinical texts used for demonstration.

    A fresh list is built on every call, so callers may modify it freely.
    """
    examples = (
        "Patient shows symptoms of COVID-19, including mild respiratory distress and fever. The X-ray indicates possible lung opacities.",
        "73-year-old male with a history of hypertension and type 2 diabetes presents with chest pain and shortness of breath. ECG shows ST-segment elevation.",
        "29-year-old female, 32 weeks pregnant, reports severe headache and blurred vision. Blood pressure reading: 160/100 mmHg.",
        "45-year-old patient diagnosed with stage 3 colorectal cancer. Started on FOLFOX chemotherapy regimen. Experiencing nausea and fatigue post-treatment.",
        "18-year-old male admitted after a motor vehicle accident. CT scan reveals internal bleeding and a fractured femur. Prepped for emergency surgery.",
    )
    return list(examples)