Spaces:
Sleeping
Sleeping
import gradio as gr | |
from joblib import load | |
import pandas as pd | |
import numpy as np | |
from sklearn.preprocessing import StandardScaler | |
# Read the imputed dataset (feature columns plus the target column) from disk.
df_imputed_features = pd.read_csv('imputed_data.csv')

# Split the table: the warfarin dose column is the regression target (Y);
# every remaining column is a feature (X).
Y = df_imputed_features['Therapeutic Dose of Warfarin']
X = df_imputed_features.drop(columns=['Therapeutic Dose of Warfarin'])

# Fit a fresh StandardScaler on the features (nothing is loaded from disk
# here) and keep the standardized feature matrix.
# NOTE(review): neither `scaler` nor `X_scaled` is referenced by the
# prediction function in this file — confirm whether inputs should be scaled.
scaler = StandardScaler()
scaler.fit(X)
X_scaled = scaler.transform(X)
def _index_labels(*labels):
    """Return a dict mapping each label to its 0-based position in ``labels``."""
    return {label: position for position, label in enumerate(labels)}


# Encodings from each UI label to the integer code the models were fed.
# Codes follow the order the labels are listed in, starting at 0; that same
# order is the order the choices appear in the dropdowns.
age_dict = _index_labels(
    "10 - 19", "20 - 29", "30 - 39", "40 - 49", "50 - 59",
    "60 - 69", "70 - 79", "80 - 89", "90+",
)
race_dict = _index_labels(
    "Caucasian", "Chinese", "Indian", "Japanese", "Korean", "Han Chinese",
    "Hispanic", "African-American", "Asian", "Black", "Malay", "White",
    "Other", "Other Mixed Race",
)
# Binary flags arrive from the radio buttons as the strings '0.0' / '1.0'.
diabetes_dict = _index_labels('0.0', '1.0')
simvastatin_dict = _index_labels('0.0', '1.0')
amiodarone_dict = _index_labels('0.0', '1.0')
cyp2C9_genotypes_dict = _index_labels('*1/*1', '*1/*2', '*1/*3', '*2/*2', '*2/*3', '*3/*3')
VKORC1_genotype_dict = _index_labels('AA', 'AG', 'GG')
# Input widgets, one per model feature, followed by the model selector.
# Dropdown and radio choices are taken from the encoding dictionaries' keys
# so the labels offered in the UI are exactly the ones the encoder knows.
age = gr.Dropdown(label="Age", choices=list(age_dict))
gender = gr.Radio(["male", "female"], label="Gender")
race = gr.Dropdown(label="Race", choices=list(race_dict))

weight = gr.Number(label="Weight")
height = gr.Number(label="Height")

diabetes = gr.Radio(list(diabetes_dict), label="Diabetes")
simvastatin = gr.Radio(list(simvastatin_dict), label="Simvastatin")
amiodarone = gr.Radio(list(amiodarone_dict), label="Amiodarone")

INR_reported = gr.Number(label="INR on Reported Therapeutic Dose of Warfarin")

cyp2C9_genotypes = gr.Dropdown(label="Cyp2C9 genotypes", choices=list(cyp2C9_genotypes_dict))
VKORC1_genotype = gr.Dropdown(label="VKORC1 genotype", choices=list(VKORC1_genotype_dict))

model = gr.Dropdown(
    label="Model",
    choices=[
        "Linear Regression",
        "Ridge Regression",
        "Decision Tree",
        "KNN",
        "ANN",
        "Random Forest",
    ],
)
# "Model" dropdown label -> serialized estimator file. Any other value
# (including no selection) falls back to the random forest file.
_MODEL_FILES = {
    'Linear Regression': 'Linear_regressor_model.pkl',
    'Ridge Regression': 'best_ridge_regression_model.pkl',
    'Decision Tree': 'best_decision_tree_model.pkl',
    'KNN': 'best_knn_regressor_model.pkl',
    'ANN': 'Best_ann_regressor.pkl',
}
_DEFAULT_MODEL_FILE = "best_random_forest_model.pkl"

# Estimators already deserialized, keyed by file name, so each file is read
# from disk once per process instead of on every prediction request.
_loaded_models = {}


def _load_model(model_name):
    """Return the estimator selected by ``model_name``, loading it on first use."""
    path = _MODEL_FILES.get(model_name, _DEFAULT_MODEL_FILE)
    if path not in _loaded_models:
        # NOTE(review): joblib.load unpickles the file; only ship trusted artifacts.
        _loaded_models[path] = load(path)
    return _loaded_models[path]


def _encode_inputs(age, gender, race, weight, height, diabetes, simvastatin,
                   amiodarone, INR_reported, cyp2C9_genotypes, VKORC1_genotype):
    """Encode raw form values into the 11-element feature row, in model order.

    Raises:
        gr.Error: If any value is missing or is not one of the offered choices,
            listing every offending field so the user can fix them at once.
    """
    labelled_features = [
        ("Age", age_dict.get(age)),
        # Unknown/missing gender is reported instead of silently becoming 1.
        ("Gender", {"male": 0, "female": 1}.get(gender)),
        ("Race", race_dict.get(race)),
        ("Weight", weight),
        ("Height", height),
        ("Diabetes", diabetes_dict.get(diabetes)),
        ("Simvastatin", simvastatin_dict.get(simvastatin)),
        ("Amiodarone", amiodarone_dict.get(amiodarone)),
        ("INR on Reported Therapeutic Dose of Warfarin", INR_reported),
        ("Cyp2C9 genotypes", cyp2C9_genotypes_dict.get(cyp2C9_genotypes)),
        ("VKORC1 genotype", VKORC1_genotype_dict.get(VKORC1_genotype)),
    ]
    missing = [label for label, value in labelled_features if value is None]
    if missing:
        # A None in the row would otherwise fail deep inside model.predict.
        raise gr.Error("Please provide a value for: " + ", ".join(missing))
    return [value for _, value in labelled_features]


def multi_inputs(age, gender, race, weight, height, diabetes, simvastatin,
                 amiodarone, INR_reported, cyp2C9_genotypes, VKORC1_genotype,
                 model, dose=None):
    """Predict the therapeutic warfarin dose for a single patient.

    Args:
        age, race, cyp2C9_genotypes, VKORC1_genotype: Dropdown labels, encoded
            via the module-level ``*_dict`` mappings.
        gender: ``"male"`` or ``"female"``.
        weight, height, INR_reported: Numeric values passed through unchanged.
        diabetes, simvastatin, amiodarone: ``'0.0'`` / ``'1.0'`` radio values.
        model: Label chosen in the "Model" dropdown; unrecognised values fall
            back to the random forest model.
        dose: Unused. Optional because the Gradio interface supplies only the
            12 inputs above; kept for callers that pass it explicitly.

    Returns:
        The array returned by the selected estimator's ``predict`` for one row.

    Raises:
        gr.Error: If any patient input is missing or invalid.
    """
    features = _encode_inputs(age, gender, race, weight, height, diabetes,
                              simvastatin, amiodarone, INR_reported,
                              cyp2C9_genotypes, VKORC1_genotype)
    estimator = _load_model(model)
    # NOTE(review): the module-level StandardScaler is fitted on the training
    # features but is not applied here, so raw values reach the model. If the
    # saved models were trained on scaled features this row should go through
    # `scaler.transform(...)` first — confirm against the training code.
    input_data = np.array([features])
    return estimator.predict(input_data)
# Build the Gradio UI. The order of `inputs` must match the positional
# parameters of `multi_inputs` (age ... VKORC1_genotype, model).
inputs = [age, gender, race, weight, height, diabetes, simvastatin,
          amiodarone, INR_reported, cyp2C9_genotypes, VKORC1_genotype, model]
outputs = gr.Textbox(label="Predicted Therapeutic Dose of Warfarin")

# Bound to `demo` so tools that import this module (e.g. the `gradio` reload
# CLI, which looks for a `demo` variable) can find the interface.
demo = gr.Interface(fn=multi_inputs, inputs=inputs, outputs=outputs)

# Start the server only when run as a script, so importing the module does
# not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)