# pulse_ox / app.py
# Author: zonova — commit c6aa51a ("Update app.py").
import functools
import json
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import xgboost
# Ensemble file read by predicter(); relative to the working directory, as in
# the original open() call.
_MODEL_PATH = "HH_ensemble_classifier_online.json"

# Column names fed to the classifiers, continuous features first.
_CONT_FEATURES = ["SpO2", "anchor_age", "weight", "height", "temperature"]
_CAT_FEATURES = ["gender", "race_group"]

# UI gender choice -> single-letter code used in the model's feature frame.
_GENDER_CODES = {"Male": "M", "Female": "F"}


@functools.lru_cache(maxsize=1)
def _load_classifiers(path=_MODEL_PATH):
    """Unpickle every ensemble member stored in *path*, reading the file once.

    The file is a JSON list of objects whose ``"model"`` entry is a pickle
    serialized as a latin-1 string. The result is cached so each Gradio
    request does not re-read and re-unpickle the whole ensemble.
    """
    with open(path, "r", encoding="utf-8") as file:
        model_data = json.load(file)
    # SECURITY: pickle.loads executes code embedded in the payload. Only point
    # this at a trusted model file shipped with the app, never user data.
    return tuple(pickle.loads(item["model"].encode("latin1")) for item in model_data)


def _build_features(SpO2, Age, Weight, Height, Temperature, Gender, Race):
    """Validate the raw UI inputs and return the scaled one-row feature frame.

    Raises:
        ValueError: if a numeric input is missing (None) or Gender is not
            "Male" / "Female".
    """
    numeric = {
        "SpO2": SpO2,
        "Age": Age,
        "Weight": Weight,
        "Height": Height,
        "Temperature": Temperature,
    }
    missing = [name for name, value in numeric.items() if value is None]
    if missing:
        raise ValueError(f"Missing input(s): {', '.join(missing)}")
    if Gender not in _GENDER_CODES:
        raise ValueError(f"Gender must be 'Male' or 'Female', got {Gender!r}")

    # Each continuous input is divided by a fixed constant; presumably the
    # normalisation used at training time (TODO confirm these divisors).
    row = [
        SpO2 / 100,
        Age / 91,
        Weight / 309,
        Height / 213,
        Temperature / 42.06,
        _GENDER_CODES[Gender],
        Race,
    ]
    features = pd.DataFrame([row], columns=_CONT_FEATURES + _CAT_FEATURES)
    # NOTE(review): astype("category") on a single row yields a one-value
    # category set per column, so the integer codes are always 0. If the
    # ensemble consumes pandas category codes (e.g. XGBoost with
    # enable_categorical), build these with the training categories — confirm.
    features[_CAT_FEATURES] = features[_CAT_FEATURES].astype("category")
    return features


def predicter(SpO2, Age, Weight, Height, Temperature, Gender, Race):
    """Predict hidden-hypoxemia probabilities by averaging the ensemble members.

    Args:
        SpO2: pulse-oximeter reading in percent (scaled by 1/100).
        Age: patient age (scaled by 1/91).
        Weight: weight in kg (scaled by 1/309).
        Height: height in cm (scaled by 1/213).
        Temperature: temperature in Celsius (scaled by 1/42.06).
        Gender: "Male" or "Female".
        Race: race-group label, passed to the model unchanged.

    Returns:
        dict with keys "No Hidden Hypoxemia" and "Hidden Hypoxemia", holding
        the mean over all ensemble members of ``predict_proba`` columns 0 and 1.

    Raises:
        ValueError: for missing/invalid inputs or a model file with no members.
    """
    # Validate first so bad requests fail fast without touching the model file.
    user_input = _build_features(SpO2, Age, Weight, Height, Temperature, Gender, Race)

    classifier_list = _load_classifiers()
    if not classifier_list:
        raise ValueError(f"No classifiers found in {_MODEL_PATH}")

    # Each member presumably returns shape (1, 2) for the single row; flatten
    # to one probability vector per member, then average across members.
    predictions = np.array(
        [np.ravel(clf.predict_proba(user_input)) for clf in classifier_list]
    )
    averaged_prediction = predictions.mean(axis=0)
    return {
        "No Hidden Hypoxemia": float(averaged_prediction[0]),
        "Hidden Hypoxemia": float(averaged_prediction[1]),
    }
# Gradio UI: one input component per predicter() parameter, in the same order
# (SpO2, Age, Weight, Height, Temperature, Gender, Race). Uses the top-level
# component classes; the legacy gr.inputs.* namespace is deprecated and was
# removed in Gradio 4.
demo = gr.Interface(
    fn=predicter,
    inputs=[
        gr.Slider(minimum=88.1, maximum=100, label="SpO2"),
        gr.Number(label="Age"),
        gr.Number(label="Weight in kg"),
        gr.Number(label="Height in cm"),
        gr.Number(label="Temperature in Celsius"),
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Radio(["White", "Black", "Asian", "Hispanic", "Other"], label="Race"),
    ],
    outputs=[gr.Label(label="Probabilities")],
    title="Model Predictions",
)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    demo.launch()