Spaces:
Sleeping
| import pandas as pd | |
| import seaborn as sn | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import confusion_matrix | |
| from matplotlib.colors import ListedColormap | |
| import numpy as np | |
| import gradio as gr | |
# --- Gradio I/O components -------------------------------------------------
# Input grid: one row per sample; headers match the columns produced by
# get_example() and consumed by visualize_ROC().
set_input = gr.Dataframe(type="numpy", row_count=10, col_count=3, headers=['Sample Index', 'Predicted Prob', 'Label (Y)'], datatype=["number", "number", "number"])
# Probability cut-off slider.
# NOTE(review): the label says "Default = 0.5" but value=0.4 -- confirm which
# default is intended.
set_input2 = gr.Slider(0, 1, step = 0.1, value=0.4, label="Set Probability Threshold (Default = 0.5)")
#set_output = gr.Textbox(label ='test')
# NOTE(review): set_output1 is never wired into the Interface below
# (visualize_ROC returns only the three image paths).
set_output1 = gr.Dataframe(type="pandas", label = 'Predicted Labels')
set_output2 = gr.Image(label="Confusion Matrix")        # filled from 'tmp.png'
set_output3 = gr.Image(label="ROC curve")               # filled from 'tmp2.png'
set_output4 = gr.Image(label="Threshold Tuning curve")  # filled from 'tmp3.png'
def perf_measure(y_actual, y_hat):
    """Count confusion-matrix cells for binary (0/1) labels.

    Parameters
    ----------
    y_actual : sequence of 0/1 values (list, numpy array, or pandas Series).
    y_hat : sequence of 0/1 predicted labels, same length as ``y_actual``.

    Returns
    -------
    tuple of ints
        ``(TP, FP, TN, FN)``.
    """
    TP = FP = TN = FN = 0
    # zip pairs elements positionally; the original indexed with
    # range(len(y_hat)), which only worked for default-indexed inputs.
    for actual, pred in zip(y_actual, y_hat):
        if pred == 1 and actual == 1:
            TP += 1
        elif pred == 1:          # pred == 1 and actual != 1
            FP += 1
        elif pred == 0 and actual == 0:
            TN += 1
        elif pred == 0:          # pred == 0 and actual != 0
            FN += 1
    return (TP, FP, TN, FN)
def visualize_ROC(set_threshold,set_input):
    """Render three diagnostic plots for a set of scored binary predictions.

    Parameters
    ----------
    set_threshold : float
        Probability cut-off used to turn predicted probabilities into 0/1
        labels.
    set_input : 2-D numpy array
        Columns are [Sample Index, Predicted Prob, Label (Y)], matching the
        gr.Dataframe headers declared at module level.

    Returns
    -------
    tuple of str
        ('tmp.png', 'tmp2.png', 'tmp3.png') -- confusion-matrix heatmap,
        ROC curve, and threshold-tuning curve image paths respectively.

    Side effects: writes the three PNG files into the working directory and
    mutates global matplotlib/seaborn rcParams.
    """
    # NOTE(review): the in-function imports below duplicate the module-level
    # ones; kept byte-for-byte to preserve behavior.
    import numpy as np
    prob = set_input[:,1]                       # predicted probabilities (col 1)
    pred_label = (prob >= set_threshold).astype(int)  # probabilities -> 0/1 labels
    actual_label = set_input[:,2]               # ground-truth labels (col 2)
    import pandas as pd
    data = {
        'Predicted Prob': prob,
        'Predicted Label': pred_label,
        'Actual Label': actual_label
    }
    import pandas as pd
    import seaborn as sn
    import matplotlib.pyplot as plt
    df = pd.DataFrame(data)
    # --- Plot 1: confusion-matrix heatmap -> tmp.png -----------------------
    confusion_matrix_results = confusion_matrix(df['Actual Label'], df['Predicted Label'])
    fig, ax = plt.subplots(figsize=(12,4))
    sn.heatmap(confusion_matrix_results, annot=True,annot_kws={"size": 20},cbar=False,
               square=False,
               fmt='g',
               cmap=ListedColormap(['white']), linecolor='black',
               linewidths=1.5)
    sn.set(font_scale=2)
    plt.xlabel("Predicted Label")
    plt.ylabel("Actual Label")
    # Label the four cells; sklearn's layout is row=actual, col=predicted,
    # so [0,0]=TN, [0,1]=FP, [1,0]=FN, [1,1]=TP.
    plt.text(0.6,0.55,'(TN)')
    plt.text(1.6,0.55,'(FP)')
    plt.text(0.6,1.55,'(FN)')
    plt.text(1.6,1.55,'(TP)')
    ax.xaxis.tick_top()
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
    plt.tight_layout()
    plt.savefig('tmp.png', dpi=100)
    ## get ROC curve
    from sklearn.metrics import roc_curve
    fpr_mod, tpr_mod, thrsholds_mod = roc_curve(df['Actual Label'], df['Predicted Prob'])
    TP, FP, TN, FN = perf_measure(df['Actual Label'], df['Predicted Label'])
    # Sensitivity, hit rate, recall, or true positive rate
    # The bare excepts below guard division by zero when a confusion-matrix
    # row/column is empty; the metric then defaults to 0.
    try:
        recall = TP/(TP+FN)
    except:
        recall = 0
    try:
        precision = TP/(TP+FP)
    except:
        precision = 0
    try:
        specificity = TN/(TN+FP)
    except:
        specificity = 0
    try:
        TPR = TP/(TP+FN)
    except:
        TPR = 0
    # Fall out or false positive rate
    try:
        FPR = FP/(FP+TN)
    except:
        FPR = 0
    try:
        f1_score_cur = 2*recall*precision/(precision+recall)
    except:
        f1_score_cur = 0
    try:
        g_mean_cur = np.sqrt(recall*specificity)
    except:
        g_mean_cur = 0
    # --- Plot 2: ROC curve with the current operating point -> tmp2.png ----
    fig, ax = plt.subplots(figsize=(12,8))
    import matplotlib.pyplot as plt
    import numpy as np
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'
    m1, c1 = 1, 0                      # y = x chance diagonal
    x = np.linspace(0, 1, 500)
    plt.plot(fpr_mod, tpr_mod, label = 'ROC', c='blue', linestyle='-')
    plt.plot(x, x * m1 + c1, 'black', linestyle='--')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    #xi = (c1 - c2) / (m2 - m1)
    #yi = m1 * xi + c1
    # Cross-hairs and marker at the (FPR, TPR) produced by set_threshold.
    plt.axvline(x=FPR, color='gray', linestyle='--')
    plt.axhline(y=TPR, color='gray', linestyle='--')
    plt.scatter(FPR, TPR, color='red', s=300)
    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    ax.spines['left'].set_color('black')
    ax.spines['bottom'].set_color('black')
    ax.spines['top'].set_color('black')
    ax.spines['right'].set_color('black')
    plt.xlabel('False Positive Rate (1 - specificity)')
    plt.ylabel('True Positive Rate (Recall)')
    plt.text(FPR, TPR, 'FPR:%s, TPR:%s' % (round(FPR,2),round(TPR,2)))
    plt.title("ROC curve", fontsize=20)
    plt.tight_layout()
    plt.savefig('tmp2.png', dpi=100)
    ### plot threshold versus f1-score
    # Sweep thresholds 0.00..0.99 and record F1 / G-mean at each step.
    thres_list = []
    f1_score_list = []
    g_mean_list = []
    for thres in np.arange(0,1,0.01):
        prob = set_input[:,1]
        pred_label = (prob >= thres).astype(int)
        actual_label = set_input[:,2]
        import pandas as pd
        data = {
            'Predicted Prob': prob,
            'Predicted Label': pred_label,
            'Actual Label': actual_label
        }
        df = pd.DataFrame(data)
        confusion_matrix_results = confusion_matrix(df['Actual Label'], df['Predicted Label'])
        TP, FP, TN, FN = perf_measure(df['Actual Label'], df['Predicted Label'])
        # Sensitivity, hit rate, recall, or true positive rate
        try:
            recall = TP/(TP+FN)
        except:
            recall = 0
        try:
            precision = TP/(TP+FP)
        except:
            precision = 0
        try:
            specificity = TN/(TN+FP)
        except:
            specificity = 0
        try:
            TPR = TP/(TP+FN)
        except:
            TPR = 0
        # Fall out or false positive rate
        try:
            FPR = FP/(FP+TN)
        except:
            FPR = 0
        try:
            f1_score = 2*recall*precision/(precision+recall)
        except:
            f1_score = 0
        try:
            g_mean = np.sqrt(recall*specificity)
        except:
            g_mean = 0
        thres_list.append(thres)
        f1_score_list.append(f1_score)
        g_mean_list.append(g_mean)
    # Find best thresholds
    best_f1_idx = np.argmax(f1_score_list)
    best_gmean_idx = np.argmax(g_mean_list)
    best_f1_threshold = thres_list[best_f1_idx]
    best_gmean_threshold = thres_list[best_gmean_idx]
    best_f1_value = f1_score_list[best_f1_idx]
    best_gmean_value = g_mean_list[best_gmean_idx]
    # --- Plot 3: threshold-tuning curves -> tmp3.png -----------------------
    fig, ax = plt.subplots(figsize=(12,8))
    import matplotlib.pyplot as plt
    import numpy as np
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'
    m1, c1 = 1, 0
    x = np.linspace(0, 1, 500)
    # Plot curves
    plt.plot(thres_list, f1_score_list, label = 'F1-score', c='black', linestyle='-')
    plt.plot(thres_list, g_mean_list, label = 'G-mean', c='red', linestyle='-')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Mark current threshold (user selected)
    plt.axvline(x=set_threshold, color='blue', linestyle=':', linewidth=2, alpha=0.5, label='Current threshold')
    plt.scatter(set_threshold, f1_score_cur, color='blue', s=200, alpha=0.5, marker='o')
    plt.scatter(set_threshold, g_mean_cur, color='blue', s=200, alpha=0.5, marker='o')
    # Mark BEST thresholds (optimal)
    plt.scatter(best_f1_threshold, best_f1_value, color='black', s=400, marker='*',
                edgecolors='gold', linewidths=2, zorder=5, label=f'Best F1 (threshold={best_f1_threshold:.2f})')
    plt.scatter(best_gmean_threshold, best_gmean_value, color='red', s=400, marker='*',
                edgecolors='gold', linewidths=2, zorder=5, label=f'Best G-mean (threshold={best_gmean_threshold:.2f})')
    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    ax.spines['left'].set_color('black')
    ax.spines['bottom'].set_color('black')
    ax.spines['top'].set_color('black')
    ax.spines['right'].set_color('black')
    plt.xlabel('Threshold cut-off')
    plt.ylabel('F1-score & G-mean')
    plt.legend(loc='upper right', fontsize=10)
    # Add text annotations for best values
    plt.text(best_f1_threshold, best_f1_value + 0.03, f'Best F1: {best_f1_value:.2f}',
             ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    plt.text(best_gmean_threshold, best_gmean_value + 0.03, f'Best G-mean: {best_gmean_value:.2f}',
             ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
    # Add text annotations for current values
    plt.text(set_threshold, f1_score_cur - 0.05, f'Current F1: {f1_score_cur:.2f}',
             ha='center', fontsize=9, color='blue', alpha=0.7)
    plt.text(set_threshold, g_mean_cur - 0.05, f'Current G-mean: {g_mean_cur:.2f}',
             ha='center', fontsize=9, color='blue', alpha=0.7)
    plt.title("Threshold tuning curves (F1-score & G-mean)\nGold stars mark optimal thresholds", fontsize=20)
    plt.tight_layout()
    plt.savefig('tmp3.png', dpi=100)
    #return df,'tmp.png','tmp2.png'
    return 'tmp.png','tmp2.png','tmp3.png'
def get_example():
    """Build a deterministic 100x3 example array for the demo.

    Columns: Sample Index (1..100), Predicted Prob, Label (Y).
    25 positive samples draw probabilities from U[0.4, 0.8) and 75 negative
    samples from U[0, 0.7); rows are shuffled and the index column is then
    rewritten to 1..100 in display order.
    """
    import numpy as np
    import pandas as pd
    np.random.seed(seed = 42)  # fixed seed -> reproducible example data
    N = 100
    n_pos = int(N / 4)         # 25 positive samples
    n_neg = int(3 * N / 4)     # 75 negative samples
    # Draw order matters for reproducibility: positives first, then negatives.
    positives = pd.DataFrame({
        'Sample Index': list(range(1, n_pos + 1)),
        'Predicted Prob': np.random.uniform(0.4, 0.8, n_pos),
        'Label (Y)': np.repeat(1, n_pos),
    })
    negatives = pd.DataFrame({
        'Sample Index': list(range(n_pos + 1, N + 1)),
        'Predicted Prob': np.random.uniform(0, 0.7, n_neg),
        'Label (Y)': np.repeat(0, n_neg),
    })
    combined = pd.concat([positives, negatives]).reset_index(drop=True)
    combined = combined.sample(frac=1).reset_index(drop=True)  # shuffle rows
    combined['Sample Index'] = list(range(1, N + 1))  # renumber after shuffle
    return combined.to_numpy()
### configure Gradio
# Wiring: inputs order matches visualize_ROC(set_threshold, set_input);
# outputs order matches its returned ('tmp.png', 'tmp2.png', 'tmp3.png').
interface = gr.Interface(fn=visualize_ROC,
                         inputs=[set_input2, set_input],
                         outputs=[set_output2,set_output3,set_output4],
                         examples_per_page = 2,
                         # Two clickable examples with thresholds 0.5 and 0.7;
                         # get_example() is seeded, so both use the same data.
                         examples=[
                             [0.5,get_example()],
                             [0.7,get_example()],
                         ],
                         title="ML Demo for Receiver Operating Characteristic (ROC) curve",
                         description= "Click examples below for a quick demo. Gold stars show optimal F1 and G-mean thresholds.",
                         theme = 'huggingface',
                         #layout = 'horizontal',
                         )
# debug=True keeps the process attached and streams errors to the console;
# height/width size the embedded app frame.
interface.launch(debug=True, height=1400, width=2800)