Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pickle | |
# Form widgets for the warfarin-dose predictor.
#
# Each ``*_categories`` list doubles as an encoding table:
# ``predict_warfarin_dose`` turns a selected choice into its position in the
# list, so reordering entries changes the numeric feature sent to the models
# (presumably the order mirrors the encoding used at training time -- confirm
# before editing).
#
# Components are built with the top-level ``gr.Dropdown`` / ``gr.Textbox``
# classes. The legacy ``gr.inputs`` / ``gr.outputs`` namespaces used previously
# were deprecated in Gradio 3 and no longer exist in Gradio 4+, where accessing
# them fails at import time.

# --- Demographics ---------------------------------------------------------
gender_categories = ["Male", "Female"]
gender = gr.Dropdown(choices=gender_categories, label="Gender")

race_categories = ["White", "Black", "Other"]
race = gr.Dropdown(choices=race_categories, label="Race")

# The age bands are deliberately kept in their original (non-sorted) order,
# because the list index is what gets fed to the models.
age_categories = ["80 - 89", "70 - 79", "50 - 59", "60 - 69", "0 - 9", "40 - 49", "30 - 39", "20 - 29", "10 - 19"]
age = gr.Dropdown(choices=age_categories, label="Age")

# Free-text numeric fields; the values arrive in the callback as strings.
height = gr.Textbox(label="Height (cm)")
weight = gr.Textbox(label="Weight (kg)")

# --- Comorbidities and co-medication --------------------------------------
diabetes_categories = ["Yes", "No"]
diabetes = gr.Dropdown(choices=diabetes_categories, label="Diabetes")

simvastatin_categories = ["Yes", "No"]
simvastatin = gr.Dropdown(choices=simvastatin_categories, label="Simvastatin (Zocor)")

amiodarone_categories = ["Yes", "No"]
amiodarone = gr.Dropdown(choices=amiodarone_categories, label="Amiodarone (Cordarone)")

# --- INR values -----------------------------------------------------------
# NOTE(review): INR is normally a unitless ratio; the "(mg/week)" unit in this
# label looks like it may belong to the dose output instead -- confirm.
target_inr = gr.Textbox(label="Target INR (mg/week)")
inr = gr.Textbox(label="INR on Reported Therapeutic Dose of Warfarin")

# --- Genotypes ------------------------------------------------------------
cyp2c9_genotypes_categories = ["*1/*1", "*1/*3", "*1/*2", "*2/*2", "*1/*5", "*2/*3", "*3/*3"]
cyp2c9_genotypes = gr.Dropdown(choices=cyp2c9_genotypes_categories, label="Cyp2C9 genotypes")

vk0rc1_genotype_categories = ["A/G", "G/G", "A/A"]
vk0rc1_genotype = gr.Dropdown(choices=vk0rc1_genotype_categories, label="VKORC1 genotype: -1639 G>A (3673); chr16:31015190; rs9923231; C/T")

# --- Model choice and output ----------------------------------------------
# These strings are compared verbatim inside ``predict_warfarin_dose``.
model_selection = gr.Dropdown(
    choices=[
        "Linear Regression (Regression)",
        "Ridge Regression (Regression)",
        "Random Forest (Regression)",
        "Neural Network (Regression)",
        "K-Nearest Neighbors (Classification)",
        "Logistic Regression (Classification)",
        "Random Forest (Classification)",
    ],
    label="Machine Learning Model",
)
warfarin_dose = gr.Textbox(label="Warfarin Dose")
def _load_model(path):
    """Unpickle and return the object stored in the file at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes (or fails), instead of being left open.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file;
    only load model files that ship with this app or come from a trusted
    source.
    """
    with open(path, 'rb') as model_file:
        return pickle.load(model_file)


# Pre-trained models, loaded once at start-up. Paths are relative to the
# current working directory.
linear_model = _load_model('linear_model_regression.pkl')
ridge_model = _load_model('ridge_model_regression.pkl')
# The name ``forest_mode`` (sic) is kept because the prediction callback
# refers to the regression forest under exactly this name.
forest_mode = _load_model('forest_model_regression.pkl')
ann_model = _load_model('ann_model_regression.pkl')
knn_model_clf = _load_model('knn_model_classification.pkl')
logistic_model_clf = _load_model('logistic_model_classification.pkl')
random_forest_model_clf = _load_model('random_forest_model_classification.pkl')
def _parse_number(raw_value, field_name):
    """Convert a free-text form value to ``float``, naming the field on failure."""
    try:
        return float(raw_value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"{field_name} must be a number, got {raw_value!r}") from err


def predict_warfarin_dose(gender, race, age, height, weight, diabetes, simvastatin, amiodarone, target_inr, inr, cyp2c9_genotypes, vk0rc1_genotyp, model_selection):
    """Encode one patient record and return the selected model's prediction.

    Categorical arguments are the strings chosen in the form and are encoded
    as their index in the matching module-level ``*_categories`` list.
    ``height``, ``weight``, ``target_inr`` and ``inr`` come from text boxes
    and are parsed as floats.

    Args:
        gender, race, age, diabetes, simvastatin, amiodarone,
        cyp2c9_genotypes, vk0rc1_genotyp: selected category strings.
        height, weight, target_inr, inr: numeric values (text or number).
        model_selection: label of the model to use; any label that is not one
            of the six listed below falls back to the random-forest
            classifier.

    Returns:
        Whatever the chosen model's ``predict`` returns for the single row.

    Raises:
        ValueError: if a category string is not in its list, or a numeric
            field cannot be parsed.
    """
    # A single row wrapped in an outer list: the models are assumed to follow
    # the scikit-learn convention of a 2-D (n_samples, n_features) input --
    # TODO confirm for ``ann_model``.
    new_instance = [[
        gender_categories.index(gender),
        race_categories.index(race),
        age_categories.index(age),
        _parse_number(height, "Height"),
        _parse_number(weight, "Weight"),
        diabetes_categories.index(diabetes),
        simvastatin_categories.index(simvastatin),
        amiodarone_categories.index(amiodarone),
        _parse_number(target_inr, "Target INR"),
        _parse_number(inr, "INR"),
        cyp2c9_genotypes_categories.index(cyp2c9_genotypes),
        # Use the *parameter* (spelled ``vk0rc1_genotyp`` in the signature);
        # the similarly named module-level ``vk0rc1_genotype`` is the form
        # widget, not the selected value.
        vk0rc1_genotype_categories.index(vk0rc1_genotyp),
    ]]

    models_by_label = {
        "Linear Regression (Regression)": linear_model,
        "Ridge Regression (Regression)": ridge_model,
        "Random Forest (Regression)": forest_mode,
        "Neural Network (Regression)": ann_model,
        "K-Nearest Neighbors (Classification)": knn_model_clf,
        "Logistic Regression (Classification)": logistic_model_clf,
    }
    # Unknown labels use the random-forest classifier, as the original
    # ``else`` branch did.
    model = models_by_label.get(model_selection, random_forest_model_clf)
    prediction = model.predict(new_instance)
    return prediction
# Build the web form. The input widgets are listed in the same order as the
# positional parameters of ``predict_warfarin_dose``; its return value is
# shown in the ``warfarin_dose`` output box.
interface = gr.Interface(
    fn=predict_warfarin_dose,
    inputs=[
        gender,
        race,
        age,
        height,
        weight,
        diabetes,
        simvastatin,
        amiodarone,
        target_inr,
        inr,
        cyp2c9_genotypes,
        vk0rc1_genotype,
        model_selection,
    ],
    outputs=warfarin_dose,
)

# Start serving the app.
interface.launch()