Spaces:
Sleeping
Sleeping
File size: 5,814 Bytes
14aebdf 1c87170 14aebdf 1c87170 14aebdf 1c87170 14aebdf ea52793 14aebdf 1c87170 14aebdf 1c87170 14aebdf 1c87170 14aebdf 1c87170 14aebdf ea52793 14aebdf ea52793 1c87170 ea52793 14aebdf 1c87170 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import numpy as np
import joblib
import warnings
import matplotlib.pyplot as plt
import matplotlib
import shap
import os
import tempfile
from config import MODEL_PATH, FEATURE_NAMES
# Suppress every warning for the whole process.
warnings.filterwarnings('ignore')

# Use the non-interactive Agg backend so figures can be rendered to files
# without a display.
matplotlib.use('Agg')

# Plot text settings: DejaVu Sans font, and draw minus signs with the ASCII
# hyphen instead of the Unicode minus glyph.
plt.rcParams.update({
    'font.family': ['DejaVu Sans'],
    'axes.unicode_minus': False,
})
def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
    """Compute BMI, NLR and PLR from raw patient measurements.

    Args:
        age: Not used by the calculation; kept so existing call sites work.
        weight: Body weight used as the BMI numerator.
        height: Height in centimetres (converted to metres by dividing by 100).
        neutrophil: Neutrophil count, numerator of NLR.
        lymphocyte: Lymphocyte count, denominator of NLR and PLR.
        platelet: Platelet count, numerator of PLR.

    Returns:
        Tuple ``(bmi, nlr, plr)``. NLR and PLR are ``0`` when ``lymphocyte``
        is not positive, so a missing/zero count never divides by zero.

    Raises:
        ValueError: If ``height`` is not positive (zero would divide by zero,
            a negative value would silently give a positive BMI once squared).
    """
    if height <= 0:
        raise ValueError(f"height must be positive, got {height}")

    height_m = height / 100
    bmi = weight / (height_m ** 2)

    if lymphocyte > 0:
        nlr = neutrophil / lymphocyte
        plr = platelet / lymphocyte
    else:
        nlr = plr = 0
    return bmi, nlr, plr
def _class1_shap_vector(shap_values):
    """Return the class-1 SHAP values of the first explained row as a 1-D array.

    Depending on the installed shap version, ``KernelExplainer.shap_values``
    returns either a list holding one ``(n_samples, n_features)`` array per
    class, or a single ``(n_samples, n_features, n_classes)`` array. Both
    layouts are accepted.

    Raises:
        ValueError: If ``shap_values`` matches neither layout.
    """
    if isinstance(shap_values, list):
        # Per-class list: select class 1 first, then the first sample.
        return np.asarray(shap_values[1])[0]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # Stacked array: (samples, features, classes).
        return values[0, :, 1]
    raise ValueError(
        f"Unsupported SHAP values shape {values.shape}; expected a per-class "
        "list or a (samples, features, classes) array."
    )


def create_shap_plot(shap_values, feature_values, feature_names, prediction_proba):
    """Render per-feature SHAP values as a horizontal bar chart saved to a PNG.

    Bars are sorted by absolute SHAP value (largest at the top), coloured red
    for positive values and blue otherwise, and each y tick shows the feature
    name and its input value.

    Args:
        shap_values: Output of ``KernelExplainer.shap_values`` for the input
            row; the class-1 values of the first row are plotted.
        feature_values: The explained row's feature values, in the same order
            as ``feature_names``.
        feature_names: Display names of the features.
        prediction_proba: Class probabilities; ``prediction_proba[1]`` is shown
            in the title as the complication risk.

    Returns:
        Path of a new PNG file in the system temp directory. The file is not
        deleted automatically; the caller owns it.

    Raises:
        ValueError: If ``shap_values`` has an unsupported layout.
    """
    shap_vals = _class1_shap_vector(shap_values)
    feature_values = np.asarray(feature_values)

    # Ascending |SHAP|: barh draws position 0 at the bottom, so the most
    # influential feature ends up at the top of the chart.
    order = np.argsort(np.abs(shap_vals))
    sorted_shap_vals = shap_vals[order]
    sorted_feature_names = [feature_names[i] for i in order]
    sorted_feature_values = feature_values[order]
    colors = ['red' if val > 0 else 'blue' for val in sorted_shap_vals]

    fig, ax = plt.subplots(figsize=(10, 12))
    temp_filename = None
    try:
        positions = range(len(sorted_shap_vals))
        ax.barh(positions, sorted_shap_vals, color=colors, alpha=0.7)
        ax.set_yticks(positions)
        ax.set_yticklabels([f"{name} = {val:.2f}" for name, val in zip(sorted_feature_names, sorted_feature_values)])
        ax.set_xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
        ax.set_title(f'Feature Impact Analysis\nComplication Risk: {prediction_proba[1]*100:.1f}%',
                     fontsize=14, pad=20)
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)

        # Print each non-zero value just past the end of its bar.
        for i, val in enumerate(sorted_shap_vals):
            if val != 0:
                ax.text(val + (0.001 if val > 0 else -0.001), i, f'{val:.3f}',
                        va='center', ha='left' if val > 0 else 'right', fontsize=9)

        ax.text(0.02, 0.98, 'Red: Increases risk\nBlue: Decreases risk',
                transform=ax.transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        fig.tight_layout()

        # delete=False keeps the file after close() so it can be served later.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_filename = temp_file.name
        fig.savefig(temp_filename, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
    except Exception:
        # Don't leave an empty placeholder file behind when rendering fails.
        if temp_filename is not None:
            try:
                os.remove(temp_filename)
            except OSError:
                pass
        raise
    finally:
        # Close this figure explicitly; plt.close() would close whichever
        # figure is "current", and figures otherwise accumulate in memory.
        plt.close(fig)
    return temp_filename
# Single reference ("background") patient for KernelExplainer. Columns follow
# the same order as the input row built in predict_outcome:
# age, weight, height, BMI, gravidity, parity, h_abortion, living_child,
# gestational_age, hemoglobin, hematocrit, platelet, MPV, PDW, neutrophil,
# lymphocyte, NLR, PLR. The derived columns agree with the raw ones
# (65 / 1.62**2 ~= 24.7, 6.0 / 1.8 ~= 3.33, 250 / 1.8 ~= 139).
_DEFAULT_SHAP_BACKGROUND = np.array([[
    28, 65, 162, 24.7, 2, 1, 0, 1, 28, 11.5, 34.0,
    250, 8.5, 12.0, 6.0, 1.8, 3.33, 139
]])


def get_shap_explainer_and_values(model, input_data, background_data=None, nsamples=100):
    """Compute KernelExplainer SHAP values of ``model.predict_proba`` for ``input_data``.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        input_data: 2-D array of rows to explain.
        background_data: Reference rows that SHAP values are measured
            against. Defaults to the single built-in reference patient.
        nsamples: Number of model re-evaluations per explained row, forwarded
            to ``KernelExplainer.shap_values``.

    Returns:
        The raw result of ``KernelExplainer.shap_values``; its layout (list
        per class vs. stacked array) depends on the installed shap version.

    NOTE(review): with 18 features, 100 samples cannot cover all 2**18
    feature coalitions, so KernelExplainer likely samples them randomly and
    values may vary slightly between calls.
    """
    if background_data is None:
        background_data = _DEFAULT_SHAP_BACKGROUND
    explainer = shap.KernelExplainer(model.predict_proba, background_data)
    return explainer.shap_values(input_data, nsamples=nsamples)
def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
                    living_child, gestational_age, hemoglobin, hematocrit,
                    platelet, mpv, pdw, neutrophil, lymphocyte):
    """Predict complication risk for one patient and explain it with SHAP.

    Builds an 18-column feature row (the 15 raw inputs plus derived BMI, NLR
    and PLR), runs the classifier, formats a Persian-language summary and
    Markdown report, and renders a SHAP bar chart.

    Returns:
        Tuple ``(result, detailed_report, shap_plot_path)``.
        - If the model cannot be loaded or prediction fails, ``result`` is an
          error message, the report is ``""`` and the plot path is ``None``.
        - If only the SHAP explanation fails, the prediction and report are
          still returned and the plot path is ``None``.
    """
    model = get_model()
    if model is None:
        return "خطا: مدل بارگذاری نشده است", "", None

    try:
        bmi, nlr, plr = calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet)
        # NOTE(review): this column order must match the model's training
        # columns and FEATURE_NAMES (used for the plot labels) — confirm in config.
        input_data = np.array([[
            age, weight, height, bmi, gravidity, parity, h_abortion,
            living_child, gestational_age, hemoglobin, hematocrit, platelet,
            mpv, pdw, neutrophil, lymphocyte, nlr, plr
        ]])
        prediction_proba = model.predict_proba(input_data)[0]
        prediction = model.predict(input_data)[0]

        if prediction == 0:
            result = f"🟢 پیشبینی: سالم (احتمال سالم بودن: {prediction_proba[0]*100:.1f}%)"
            risk_level = "کم"
        else:
            result = f"🔴 پیشبینی: پرخطر (احتمال عوارض: {prediction_proba[1]*100:.1f}%)"
            risk_level = "بالا"

        detailed_report = f"""
📊 **گزارش تفصیلی پیشبینی**
**نتیجه کلی:** {result}
**سطح ریسک:** {risk_level}
**ویژگیهای محاسبه شده:**
- BMI: {bmi:.2f}
- NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
- PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}
**احتمالات تفصیلی:**
- احتمال سالم بودن: {prediction_proba[0]*100:.1f}%
- احتمال بروز عوارض: {prediction_proba[1]*100:.1f}%
⚠️ **توجه:** این پیشبینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
"""
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing.
        return f"خطا در پردازش: {str(e)}", "", None

    # The explanation is optional: a SHAP/plotting failure must not hide an
    # otherwise valid prediction, so it is handled separately.
    shap_plot_path = None
    try:
        shap_values = get_shap_explainer_and_values(model, input_data)
        shap_plot_path = create_shap_plot(
            shap_values,
            input_data[0],
            FEATURE_NAMES,
            prediction_proba
        )
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
    return result, detailed_report, shap_plot_path
# Process-wide cache for the loaded classifier; populated lazily by get_model().
model = None


def get_model():
    """Return the classifier from MODEL_PATH, loading it on first use.

    A successfully loaded model is cached in the module-level ``model`` and
    reused on later calls. If loading fails, the error is printed and ``None``
    is returned; nothing is cached in that case, so the next call tries again.

    NOTE: joblib.load unpickles the file, so MODEL_PATH must point at a
    trusted artifact.
    """
    global model
    if model is not None:
        return model
    try:
        model = joblib.load(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    return model
def cleanup_temp_files(temp_dir=None):
    """Best-effort removal of leftover temporary PNG files.

    Removes regular files whose name starts with the standard temporary-file
    prefix (``tempfile.gettempprefix()``, the default prefix used by
    ``NamedTemporaryFile``) and ends in ``.png``. Filesystem errors are
    ignored so cleanup never breaks the caller.

    Args:
        temp_dir: Directory to clean. Defaults to ``tempfile.gettempdir()``.
    """
    directory = tempfile.gettempdir() if temp_dir is None else temp_dir
    prefix = tempfile.gettempprefix()

    try:
        with os.scandir(directory) as it:
            entries = list(it)
    except OSError:
        return

    for entry in entries:
        # Match the prefix rather than a substring, so unrelated files in a
        # shared temp dir (e.g. "report_tmp.png") are left alone.
        if not (entry.name.startswith(prefix) and entry.name.endswith('.png')):
            continue
        try:
            if entry.is_file(follow_symlinks=False):
                os.remove(entry.path)
        except OSError:
            # Already gone, in use, or not permitted: skip it.
            pass
|