# CalculatorCBD / app.py
# ready2drop's picture
# test
# 52ea5bb verified
import argparse
import os
import matplotlib.pyplot as plt
import sys
import gradio as gr
import torch
import pandas as pd
import numpy as np
import io
import base64
from lime.lime_tabular import LimeTabularExplainer
from pycaret.classification import *
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
from util import load_data_and_prepare
import view
def _str2bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the option could never be
    turned off with a readable value.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(args):
    """Parse the command-line options of the CBD classification app.

    Args:
        args: list of argument strings, typically ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data_dir``, ``excel_file``, ``mode``,
        ``scale``, ``smote`` and ``model_name_or_path``.
    """
    parser = argparse.ArgumentParser(description="CBD Classification")
    parser.add_argument('--data_dir', type=str, default="./data/")
    parser.add_argument('--excel_file', type=str, default="DUMC_final.csv")
    parser.add_argument('--mode', type=str, default="train")
    # Parsed with _str2bool so that "--scale False" really yields False.
    parser.add_argument('--scale', type=_str2bool, default=True)
    parser.add_argument('--smote', type=_str2bool, default=True)
    # No ``choices`` here: an empty choices list rejects every value passed
    # explicitly, leaving the default as the only usable setting.
    parser.add_argument('--model_name_or_path', type=str, default="./data/model")
    return parser.parse_args(args)
# Inference function
def _describe_probability(class_probability):
    """Return the user-facing message for a stone probability.

    Bands: below 0.34 is low, 0.34 up to (not including) 0.67 is
    intermediate, 0.67 and above is high.
    """
    if class_probability < 0.34:
        return (
            f"This analysis estimates a low probability ({class_probability:.2f}) of a common bile duct stone. "
            "Please consult a medical professional for final diagnosis."
        )
    if class_probability < 0.67:
        return (
            f"Based on the provided data, this tool estimates an intermediate probability ({class_probability:.2f}) "
            "of a common bile duct stone. Further medical review is recommended."
        )
    return (
        f"Based on the provided data, this tool estimates a high probability ({class_probability:.2f}) "
        "of a common bile duct stone. Further medical review is necessary."
    )


def classify(tabular_data):
    """Classify the first row of a Gradio dataframe and describe the result.

    Reads the module-level ``model`` and ``tabular_header`` and calls
    ``predict_model`` on a one-row DataFrame built from the first row.

    Args:
        tabular_data: 2D list (list of rows); only the first row is used.

    Returns:
        A message describing the estimated probability band, or a message
        starting with "An error occurred during classification:" if anything
        fails. Exceptions are never propagated to the caller.
    """
    try:
        # An empty list is rejected here as well; indexing it would otherwise
        # surface as an unhelpful "list index out of range" message.
        is_2d_list = (
            isinstance(tabular_data, list)
            and len(tabular_data) > 0
            and isinstance(tabular_data[0], list)
        )
        if not is_2d_list:
            raise ValueError("Input data is not in the expected 2D list format.")
        first_row = tabular_data[0]

        # Convert input data to a pandas DataFrame
        input_data = pd.DataFrame([first_row], columns=tabular_header)
        print(f"Original Input DataFrame:\n{input_data}")

        # Use PyCaret's predict_model to make predictions
        prediction = predict_model(model, data=input_data)
        predicted_class = prediction.loc[0, "prediction_label"]
        score = float(prediction.loc[0, "prediction_score"])

        # prediction_score is the confidence in the *predicted* label, not the
        # probability of a stone. Convert it to the positive-class probability
        # so a confident negative is not reported as a high stone probability.
        # NOTE(review): assumes label 1 is the positive (stone) class, as
        # predict_proba_fn in this file does - confirm against training data.
        class_probability = score if predicted_class == 1 else 1.0 - score

        return _describe_probability(class_probability)
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"An error occurred during classification: {str(e)}"
# Class-probability wrapper (passed to the commented-out LIME explainer below)
def predict_proba_fn(instance):
    """Return per-class probabilities using PyCaret's predict_model.

    Reads the module-level ``train`` (for column names) and ``model``.

    Args:
        instance: numpy array holding one feature row (1D) or several (2D).

    Returns:
        2D numpy array with one row per input row and two columns:
        ``[class_0_prob, class_1_prob]``.
    """
    # Promote a single row to 2D.
    if instance.ndim == 1:
        instance = instance.reshape(1, -1)

    # Build a DataFrame with the training column names.
    instance_df = pd.DataFrame(instance, columns=train.columns)

    # Run the prediction through predict_model.
    predictions = predict_model(model, data=instance_df)

    # prediction_score belongs to the predicted label: when the label is 1 it
    # is the class-1 probability, otherwise class 1 gets the complement.
    # Filling the other class with 0 (as before) produced rows that do not
    # sum to 1.
    predictions['class_1_prob'] = np.where(
        predictions['prediction_label'] == 1,
        predictions['prediction_score'],
        1 - predictions['prediction_score'],
    )
    predictions['class_0_prob'] = 1 - predictions['class_1_prob']

    # Column order is class 0 first, then class 1.
    return predictions[['class_0_prob', 'class_1_prob']].values
# def explain_with_lime(tabular_data):
# instance = np.array(tabular_data[0],dtype='float')
# # Create an explainer instance for classification
# explainer = LimeTabularExplainer(
# training_data=train.values, # Use your training data
# feature_names=tabular_header,
# class_names=['intermediate', 'High'], # Replace with actual class names
# mode='classification'
# )
# # LIME expects a 2D numpy array or DataFrame for input, and we need to provide the correct number of features
# explanation = explainer.explain_instance(
# instance, # Single instance (first row of the tabular data)
# predict_proba_fn, # The prediction function
# num_features=len(tabular_header) # Number of features to display in the explanation
# )
# # Plot LIME explanation
# fig = explanation.as_pyplot_figure()
# fig.set_size_inches(25, 8)
# buf = io.BytesIO()
# fig.savefig(buf, format='png')
# buf.seek(0)
# encoded_image = base64.b64encode(buf.read()).decode('utf-8')
# buf.close()
# plt.close(fig)
# return f"<img src='data:image/png;base64,{encoded_image}'/>"
if __name__ == '__main__':
    # Script entry point: load the data and model, build the Gradio UI, launch it.
    # NOTE(review): the nesting of the layout containers below was reconstructed
    # from a source whose indentation was flattened - confirm against the
    # deployed UI.
    args = parse_args(sys.argv[1:])

    # train, model and tabular_header are read as module globals by
    # classify() and predict_proba_fn().
    train = load_data_and_prepare(args.data_dir, args.excel_file, args.mode, args.scale, args.smote)
    # load_model is not imported by name in this file; presumably it comes
    # from the pycaret.classification star import.
    model = load_model(args.model_name_or_path)

    examples = view.examples
    description = view.description
    title_markdown = view.title_markdown
    tabular_header = view.tabular_header
    tabular_dtype = ['number'] * len(tabular_header)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(title_markdown)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # The column count follows the header instead of a hard-coded
                # 11, so the grid cannot drift from the columns classify()
                # uses to build its DataFrame.
                tabular_input = gr.Dataframe(
                    headers=tabular_header,
                    datatype=tabular_dtype,
                    label="Tabular Input",
                    type="array",
                    interactive=True,
                    row_count=1,
                    col_count=len(tabular_header),
                )
                info = gr.Textbox(lines=1, label="Patient info", visible=False)
                with gr.Row():
                    # btn_c = gr.ClearButton([tabular_input])
                    btn_c = gr.Button("Clear")
                    btn = gr.Button("Run")
                result_output = gr.Textbox(lines=2, label="Classification Result")
                lime_output = gr.HTML(label="LIME Explanation")
        gr.Examples(examples=examples, inputs=[tabular_input, info])
        btn.click(fn=classify, inputs=tabular_input, outputs=result_output)
        # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output)  # Add LIME button

        # Clear functionality: resets inputs and outputs
        def clear_fields():
            """Return empty values for the result box, the LIME panel and the input grid."""
            return None, None, [[None] * len(tabular_header)]

        btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])

    demo.queue()
    demo.launch(share=True)