File size: 5,814 Bytes
14aebdf
 
 
1c87170
 
 
 
 
 
14aebdf
 
 
1c87170
 
 
 
 
14aebdf
 
 
 
 
 
 
1c87170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14aebdf
 
 
ea52793
14aebdf
1c87170
 
 
14aebdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c87170
 
 
 
14aebdf
 
 
1c87170
 
 
 
 
 
 
 
 
 
14aebdf
 
1c87170
14aebdf
ea52793
 
 
14aebdf
 
 
ea52793
 
1c87170
ea52793
 
 
 
14aebdf
1c87170
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import numpy as np
import joblib
import warnings
import matplotlib.pyplot as plt
import matplotlib
import shap
import os
import tempfile
from config import MODEL_PATH, FEATURE_NAMES

# Suppress every warning category for the whole process.
warnings.filterwarnings('ignore')

# Use the non-interactive Agg backend so figures are rendered straight to
# files and no display is required.
matplotlib.use('Agg')

# Global figure defaults: fixed font family, and an ASCII hyphen instead of
# the Unicode minus sign for negative tick labels.
plt.rcParams.update({
    'font.family': ['DejaVu Sans'],
    'axes.unicode_minus': False,
})

def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
    """Compute BMI and the two lymphocyte-based ratios from raw measurements.

    Args:
        age: Accepted for signature compatibility; not used in any formula.
        weight: Body weight (numerator of the BMI formula).
        height: Body height; divided by 100 before squaring, so presumably
            given in centimetres -- TODO confirm with callers.
        neutrophil: Neutrophil count (numerator of NLR).
        lymphocyte: Lymphocyte count (denominator of NLR and PLR).
        platelet: Platelet count (numerator of PLR).

    Returns:
        A ``(bmi, nlr, plr)`` tuple. ``nlr`` and ``plr`` are both ``0`` when
        ``lymphocyte`` is not strictly positive.

    Raises:
        ZeroDivisionError: If ``height`` is a plain Python zero.
    """
    metres = height / 100
    bmi = weight / (metres ** 2)

    if lymphocyte > 0:
        nlr = neutrophil / lymphocyte
        plr = platelet / lymphocyte
    else:
        # Avoid dividing by zero (or a negative count): fall back to 0.
        nlr = plr = 0

    return bmi, nlr, plr

def create_shap_plot(shap_values, feature_values, feature_names, prediction_proba):
    """Draw a horizontal bar chart of class-1 SHAP values and save it as a PNG.

    Args:
        shap_values: SHAP output for the explained sample. Two layouts are
            accepted: a single array indexed as ``[sample][feature, class]``,
            or a ``list`` with one ``(n_samples, n_features)`` array per
            class (the layout older SHAP releases return for multi-output
            models). In both cases the values of the first sample for
            class index 1 are plotted.
        feature_values: Raw feature values of the explained sample, in the
            same order as ``feature_names``.
        feature_names: Display names, one per feature.
        prediction_proba: Class probabilities; index 1 is shown in the title
            as the complication risk.

    Returns:
        Path of the PNG file that was written. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Exception: Any error raised while plotting or saving is propagated
            after the figure has been closed and the temp file removed.
    """
    if isinstance(shap_values, list):
        # List-per-class layout: pick class 1, then the first sample.
        shap_vals = np.asarray(shap_values[1])[0]
    else:
        # Single-array layout: first sample, then the class-1 column.
        shap_vals = np.asarray(shap_values)[0][:, 1]

    # Allow plain sequences as well as arrays for the fancy indexing below.
    feature_values = np.asarray(feature_values)

    # Ascending by magnitude so the most influential feature ends up on top.
    sorted_indices = np.argsort(np.abs(shap_vals))
    sorted_shap_vals = shap_vals[sorted_indices]
    sorted_feature_names = [feature_names[i] for i in sorted_indices]
    sorted_feature_values = feature_values[sorted_indices]

    fig, ax = plt.subplots(figsize=(10, 12))
    try:
        colors = ['red' if val > 0 else 'blue' for val in sorted_shap_vals]
        ax.barh(range(len(sorted_shap_vals)), sorted_shap_vals, color=colors, alpha=0.7)

        ax.set_yticks(range(len(sorted_feature_names)))
        ax.set_yticklabels([f"{name} = {val:.2f}" for name, val in zip(sorted_feature_names, sorted_feature_values)])
        ax.set_xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
        ax.set_title(f'Feature Impact Analysis\nComplication Risk: {prediction_proba[1]*100:.1f}%',
                     fontsize=14, pad=20)

        ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)

        # Numeric label just beyond the tip of every non-zero bar.
        for i, val in enumerate(sorted_shap_vals):
            if val != 0:
                ax.text(val + (0.001 if val > 0 else -0.001), i, f'{val:.3f}',
                        va='center', ha='left' if val > 0 else 'right', fontsize=9)

        ax.text(0.02, 0.98, 'Red: Increases risk\nBlue: Decreases risk',
                transform=ax.transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        fig.tight_layout()

        # Reserve a unique path; the handle is closed before matplotlib
        # writes to it so the file can be reopened on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_filename = temp_file.name

        try:
            fig.savefig(temp_filename, dpi=300, bbox_inches='tight',
                        facecolor='white', edgecolor='none')
        except Exception:
            # Do not leave an empty/partial image behind when saving fails.
            try:
                os.remove(temp_filename)
            except OSError:
                pass
            raise
    finally:
        # Always release this specific figure, even when plotting failed.
        plt.close(fig)

    return temp_filename

def get_shap_explainer_and_values(model, input_data):
    """Explain ``input_data`` with a SHAP ``KernelExplainer`` built on ``model``.

    Args:
        model: Object exposing ``predict_proba``; that method is the function
            being explained.
        input_data: Samples to explain, passed unchanged to
            ``explainer.shap_values``.

    Returns:
        Whatever ``KernelExplainer.shap_values`` returns for ``input_data``
        with ``nsamples=100``.
    """
    # One hard-coded 18-value reference row used as the explainer background.
    # NOTE(review): the values appear to follow the feature order assembled
    # in predict_outcome (age, weight, height, BMI, ...) -- confirm.
    reference_row = [
        28, 65, 162, 24.7, 2, 1, 0, 1, 28, 11.5, 34.0,
        250, 8.5, 12.0, 6.0, 1.8, 3.33, 139,
    ]
    background = np.array([reference_row])

    kernel_explainer = shap.KernelExplainer(model.predict_proba, background)
    return kernel_explainer.shap_values(input_data, nsamples=100)

        
def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
                    living_child, gestational_age, hemoglobin, hematocrit,
                    platelet, mpv, pdw, neutrophil, lymphocyte):
    """Run the model on one patient record and build the user-facing outputs.

    The 15 raw inputs are extended with BMI, NLR and PLR (from
    ``calculate_derived_features``) into a single 18-feature row, which is
    passed to the model's ``predict_proba`` and ``predict``. A SHAP bar chart
    for the same row is then rendered to a temporary PNG.

    Returns:
        A ``(result, detailed_report, shap_plot_path)`` tuple. When the model
        cannot be loaded, or when any step raises, the tuple is
        ``(error_message, "", None)`` instead.
    """
    clf = get_model()
    if clf is None:
        return "خطا: مدل بارگذاری نشده است", "", None

    try:
        bmi, nlr, plr = calculate_derived_features(
            age, weight, height, neutrophil, lymphocyte, platelet
        )

        # Single-row feature matrix; derived BMI sits after height and the
        # two ratios come last.
        features = np.array([[
            age, weight, height, bmi, gravidity, parity, h_abortion,
            living_child, gestational_age, hemoglobin, hematocrit, platelet,
            mpv, pdw, neutrophil, lymphocyte, nlr, plr
        ]])

        class_probs = clf.predict_proba(features)[0]
        label = clf.predict(features)[0]

        # Label 0 is reported as healthy; anything else as high risk.
        if label == 0:
            result = f"🟢 پیش‌بینی: سالم (احتمال سالم بودن: {class_probs[0]*100:.1f}%)"
            risk_level = "کم"
        else:
            result = f"🔴 پیش‌بینی: پرخطر (احتمال عوارض: {class_probs[1]*100:.1f}%)"
            risk_level = "بالا"

        detailed_report = f"""
📊 **گزارش تفصیلی پیش‌بینی**

**نتیجه کلی:** {result}

**سطح ریسک:** {risk_level}

**ویژگی‌های محاسبه شده:**
- BMI: {bmi:.2f}
- NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
- PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}

**احتمالات تفصیلی:**
- احتمال سالم بودن: {class_probs[0]*100:.1f}%
- احتمال بروز عوارض: {class_probs[1]*100:.1f}%

⚠️ **توجه:** این پیش‌بینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
        """

        explanation = get_shap_explainer_and_values(clf, features)
        shap_plot_path = create_shap_plot(
            explanation, features[0], FEATURE_NAMES, class_probs
        )

        return result, detailed_report, shap_plot_path

    except Exception as e:
        # Any failure along the way is reported to the caller as text.
        return f"خطا در پردازش: {str(e)}", "", None


# Module-level cache for the loaded model; filled on first successful load.
model = None


def get_model():
    """Return the cached model, loading it from ``MODEL_PATH`` on first use.

    Returns:
        The loaded model object, or ``None`` when loading fails. A failed
        load leaves the cache empty, so the next call tries again.
    """
    global model

    if model is not None:
        return model

    try:
        # NOTE: joblib.load unpickles the file, so MODEL_PATH must point to
        # a trusted artifact.
        model = joblib.load(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    return model

def cleanup_temp_files():
    """Best-effort removal of temporary PNG files from the system temp dir.

    Every entry of ``tempfile.gettempdir()`` whose name ends with ``.png``
    and contains ``tmp`` is passed to ``os.remove``. Note that this matches
    by name only, so it is not limited to files created by this module.

    Filesystem errors (entry vanished, permission denied, entry is a
    directory, temp dir unreadable) are ignored. Only ``OSError`` is
    swallowed, so ``KeyboardInterrupt`` and ``SystemExit`` still propagate.
    """
    try:
        temp_dir = tempfile.gettempdir()
        filenames = os.listdir(temp_dir)
    except OSError:
        # No usable/readable temp directory: nothing to clean up.
        return

    for filename in filenames:
        if filename.endswith('.png') and 'tmp' in filename:
            try:
                os.remove(os.path.join(temp_dir, filename))
            except OSError:
                # Skip entries that cannot be removed and keep going.
                pass