File size: 2,871 Bytes
5ec611f
fc87468
48badb6
 
fc87468
5ec611f
48badb6
5ec611f
48badb6
fc87468
48badb6
 
 
89f3ba0
48badb6
 
 
 
 
fc87468
 
48badb6
89f3ba0
 
 
 
 
 
48badb6
 
89f3ba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d02841f
48badb6
fc87468
48badb6
d02841f
89f3ba0
48badb6
 
 
89f3ba0
48badb6
89f3ba0
fc87468
48badb6
5ec611f
48badb6
5ec611f
fc87468
 
 
 
 
 
 
48badb6
89f3ba0
 
48badb6
fc87468
48badb6
fc87468
48badb6
fc87468
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import torch
import os
import urllib.request
from torchvision import transforms
from PIL import Image
import torch.nn as nn

# إعدادات النموذج
REPO_ID = "Alhdrawi/x_alhdrawi"
MODEL_FILE = "best_128_0.0002_original_15000_0.859.pt"
MODEL_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILE}"
MODEL_LOCAL_PATH = f"/tmp/{MODEL_FILE}"

# قائمة الأمراض التي يتوقعها النموذج
diseases = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
    "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
    "Pleural_Thickening", "Pneumonia", "Pneumothorax"
]

# تحويل الصورة مثل ما دربت النموذج
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229])
])

# تعريف بنية النموذج (نفس اللي استخدمته وقت التدريب)
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=14):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# تحميل النموذج
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(diseases)).to(device)

def download_and_load_model():
    if not os.path.exists(MODEL_LOCAL_PATH):
        print(f"Downloading model from {MODEL_URL}")
        urllib.request.urlretrieve(MODEL_URL, MODEL_LOCAL_PATH)
    
    state_dict = torch.load(MODEL_LOCAL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"✅ Model loaded from {MODEL_FILE}")

# دالة التنبؤ
def predict(image):
    img = transform(image.convert("L")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img)
        probs = torch.sigmoid(outputs).cpu().squeeze().numpy()
    results = {d: round(float(p), 3) for d, p in zip(diseases, probs)}
    return results

# تحميل النموذج عند بدء التشغيل
download_and_load_model()

# واجهة Gradio
with gr.Blocks() as demo:
    gr.Markdown(f"## 🧠 CheXzero | النموذج المستخدم: `{MODEL_FILE}`")
    with gr.Row():
        image_input = gr.Image(type="pil", label="صورة أشعة X-Ray")
        output = gr.Label(num_top_classes=5)

    image_input.change(fn=predict, inputs=image_input, outputs=output)

demo.launch()