Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import matplotlib.pyplot as plt | |
| import sys | |
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import io | |
| import base64 | |
| from lime.lime_tabular import LimeTabularExplainer | |
| from pycaret.classification import * | |
| import warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage") | |
| from util import load_data_and_prepare | |
| import view | |
| def parse_args(args): | |
| parser = argparse.ArgumentParser(description="CBD Classification") | |
| parser.add_argument('--data_dir', type=str, default="./data/") | |
| parser.add_argument('--excel_file', type=str, default="DUMC_final.csv") | |
| parser.add_argument('--mode', type=str, default="train") | |
| parser.add_argument('--scale', type=bool, default=True) | |
| parser.add_argument('--smote', type=bool, default=True) | |
| parser.add_argument('--model_name_or_path', type=str, default="./data/model", choices=[]) | |
| return parser.parse_args(args) | |
| # Inference function | |
| def classify(tabular_data): | |
| try: | |
| # Ensure tabular_data is a 2D list and extract the first row | |
| if isinstance(tabular_data, list) and isinstance(tabular_data[0], list): | |
| tabular_data = tabular_data[0] # Extract the first row | |
| else: | |
| raise ValueError("Input data is not in the expected 2D list format.") | |
| # Convert input data to a pandas DataFrame | |
| input_data = pd.DataFrame([tabular_data], columns= tabular_header) | |
| print(f"Original Input DataFrame:\n{input_data}") | |
| # Use PyCaret's predict_model to make predictions | |
| prediction = predict_model(model, data=input_data) | |
| # Extract predicted class and probability | |
| predicted_class = prediction.loc[0, "prediction_label"] | |
| class_probability = prediction.loc[0, "prediction_score"] | |
| # Generate appropriate output based on the prediction and probability | |
| if class_probability < 0.34: | |
| result = ( | |
| f"This analysis estimates a low probability ({class_probability:.2f}) of a common bile duct stone. " | |
| "Please consult a medical professional for final diagnosis." | |
| ) | |
| elif 0.34 <= class_probability < 0.67: | |
| result = ( | |
| f"Based on the provided data, this tool estimates an intermediate probability ({class_probability:.2f}) " | |
| "of a common bile duct stone. Further medical review is recommended." | |
| ) | |
| else: # class_probability >= 0.67 | |
| result = ( | |
| f"Based on the provided data, this tool estimates a high probability ({class_probability:.2f}) " | |
| "of a common bile duct stone. Further medical review is necessary." | |
| ) | |
| return result | |
| except Exception as e: | |
| return f"An error occurred during classification: {str(e)}" | |
| # Inference function | |
| def predict_proba_fn(instance): | |
| """ | |
| PyCaret의 predict_model을 활용한 확률 예측 함수. | |
| """ | |
| # 2D 형태로 변환 | |
| if instance.ndim == 1: | |
| instance = instance.reshape(1, -1) | |
| # DataFrame으로 변환 | |
| instance_df = pd.DataFrame(instance, columns=train.columns) | |
| # predict_model을 통해 예측 수행 | |
| predictions = predict_model(model, data=instance_df) | |
| # prediction_label이 1이면 prediction_score, 0이면 1-prediction_score | |
| predictions['class_1_prob'] = np.where(predictions['prediction_label'] == 1, | |
| predictions['prediction_score'], | |
| 0) | |
| predictions['class_0_prob'] = np.where(predictions['prediction_label'] == 0, | |
| predictions['prediction_score'], | |
| 0) | |
| # class_0_prob와 class_1_prob 반환 | |
| return predictions[['class_0_prob', 'class_1_prob']].values | |
| # def explain_with_lime(tabular_data): | |
| # instance = np.array(tabular_data[0],dtype='float') | |
| # # Create an explainer instance for classification | |
| # explainer = LimeTabularExplainer( | |
| # training_data=train.values, # Use your training data | |
| # feature_names=tabular_header, | |
| # class_names=['intermediate', 'High'], # Replace with actual class names | |
| # mode='classification' | |
| # ) | |
| # # LIME expects a 2D numpy array or DataFrame for input, and we need to provide the correct number of features | |
| # explanation = explainer.explain_instance( | |
| # instance, # Single instance (first row of the tabular data) | |
| # predict_proba_fn, # The prediction function | |
| # num_features=len(tabular_header) # Number of features to display in the explanation | |
| # ) | |
| # # Plot LIME explanation | |
| # fig = explanation.as_pyplot_figure() | |
| # fig.set_size_inches(25, 8) | |
| # buf = io.BytesIO() | |
| # fig.savefig(buf, format='png') | |
| # buf.seek(0) | |
| # encoded_image = base64.b64encode(buf.read()).decode('utf-8') | |
| # buf.close() | |
| # plt.close(fig) | |
| # return f"<img src='data:image/png;base64,{encoded_image}'/>" | |
| if __name__ == '__main__': | |
| args = parse_args(sys.argv[1:]) | |
| train = load_data_and_prepare(args.data_dir, args.excel_file, args.mode, args.scale, args.smote) | |
| model = load_model(args.model_name_or_path) | |
| examples = view.examples | |
| description = view.description | |
| title_markdown = view.title_markdown | |
| tabular_header = view.tabular_header | |
| tabular_dtype = ['number'] * len(tabular_header) | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(title_markdown) | |
| gr.Markdown(description) | |
| with gr.Row(): | |
| with gr.Column(): | |
| tabular_input = gr.Dataframe(headers= tabular_header, datatype= tabular_dtype, label="Tabular Input", type="array", interactive=True, row_count=1, col_count=11) | |
| info = gr.Textbox(lines=1, label="Patient info", visible = False) | |
| with gr.Row(): | |
| # btn_c = gr.ClearButton([tabular_input]) | |
| btn_c = gr.Button("Clear") | |
| btn = gr.Button("Run") | |
| result_output = gr.Textbox(lines=2, label="Classification Result") | |
| lime_output = gr.HTML(label="LIME Explanation") | |
| gr.Examples(examples=examples, inputs=[tabular_input, info]) | |
| btn.click(fn=classify, inputs=tabular_input, outputs=result_output) | |
| # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output) # Add LIME button | |
| # Clear functionality: resets inputs and outputs | |
| def clear_fields(): | |
| return None, None, [[None] * len(tabular_header)] | |
| btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input]) | |
| demo.queue() | |
| demo.launch(share=True) | |