Spaces:
Runtime error
Runtime error
File size: 2,871 Bytes
5ec611f fc87468 48badb6 fc87468 5ec611f 48badb6 5ec611f 48badb6 fc87468 48badb6 89f3ba0 48badb6 fc87468 48badb6 89f3ba0 48badb6 89f3ba0 d02841f 48badb6 fc87468 48badb6 d02841f 89f3ba0 48badb6 89f3ba0 48badb6 89f3ba0 fc87468 48badb6 5ec611f 48badb6 5ec611f fc87468 48badb6 89f3ba0 48badb6 fc87468 48badb6 fc87468 48badb6 fc87468 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import gradio as gr
import torch
import os
import urllib.request
from torchvision import transforms
from PIL import Image
import torch.nn as nn
# إعدادات النموذج
REPO_ID = "Alhdrawi/x_alhdrawi"
MODEL_FILE = "best_128_0.0002_original_15000_0.859.pt"
MODEL_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILE}"
MODEL_LOCAL_PATH = f"/tmp/{MODEL_FILE}"
# قائمة الأمراض التي يتوقعها النموذج
diseases = [
"Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
"Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
"Pleural_Thickening", "Pneumonia", "Pneumothorax"
]
# تحويل الصورة مثل ما دربت النموذج
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485], std=[0.229])
])
# تعريف بنية النموذج (نفس اللي استخدمته وقت التدريب)
class SimpleCNN(nn.Module):
def __init__(self, num_classes=14):
super(SimpleCNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1))
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# تحميل النموذج
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(diseases)).to(device)
def download_and_load_model():
if not os.path.exists(MODEL_LOCAL_PATH):
print(f"Downloading model from {MODEL_URL}")
urllib.request.urlretrieve(MODEL_URL, MODEL_LOCAL_PATH)
state_dict = torch.load(MODEL_LOCAL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print(f"✅ Model loaded from {MODEL_FILE}")
# دالة التنبؤ
def predict(image):
img = transform(image.convert("L")).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(img)
probs = torch.sigmoid(outputs).cpu().squeeze().numpy()
results = {d: round(float(p), 3) for d, p in zip(diseases, probs)}
return results
# تحميل النموذج عند بدء التشغيل
download_and_load_model()
# واجهة Gradio
with gr.Blocks() as demo:
gr.Markdown(f"## 🧠 CheXzero | النموذج المستخدم: `{MODEL_FILE}`")
with gr.Row():
image_input = gr.Image(type="pil", label="صورة أشعة X-Ray")
output = gr.Label(num_top_classes=5)
image_input.change(fn=predict, inputs=image_input, outputs=output)
demo.launch()
|