| import os |
| import io |
| import cv2 |
| import json |
| import time |
| import math |
| import base64 |
| import queue |
| import shutil |
| import numpy as np |
| import requests |
| import onnxruntime as ort |
| from PIL import Image |
| import gradio as gr |
|
|
| |
# Pretrained YOLOv7-P6 bone-fracture detector (ONNX). The same filename is
# used for the release asset and for the local cache copy.
_MODEL_FILENAME = "yolov7-p6-bonefracture.onnx"
MODEL_URL = (
    "https://github.com/mdciri/YOLOv7-Bone-Fracture-Detection/releases/download/trained-models/"
    + _MODEL_FILENAME
)
# Local cache: a "models" directory next to this source file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
MODEL_PATH = os.path.join(MODEL_DIR, _MODEL_FILENAME)
# Square side length (pixels) of the model input tensor.
INPUT_SIZE = 640
# Default values for the UI threshold sliders.
CONF_THRES_DEFAULT = 0.25
IOU_THRES_DEFAULT = 0.45
|
|
| |
# Class names in model label order: infer_yolov7 maps a predicted label
# index c to CLASSES[c], so the order of these names must not change.
CLASSES = """
    boneanomaly bonelesion foreignbody fracture metal
    periostealreaction pronatorsign softtissue text
""".split()
|
|
# Lazily-initialised inference state; load_session() fills these in on its
# first call and reuses them afterwards.
_session = _input_name = _output_name = None
|
|
|
|
def ensure_model_available():
    """Download the ONNX model to MODEL_PATH unless it is already there.

    The file is streamed to a temporary ``MODEL_PATH + ".downloading"`` path
    and only moved into place once fully written, so an interrupted download
    never leaves a truncated model at MODEL_PATH. On failure the temporary
    file is removed so the next attempt starts from scratch.

    Raises:
        RuntimeError: if the download or the final move fails for any reason
            (network error, HTTP error status, disk error). The original
            exception is chained as ``__cause__``.
    """
    import contextlib

    os.makedirs(MODEL_DIR, exist_ok=True)
    if os.path.exists(MODEL_PATH):
        return

    tmp_path = MODEL_PATH + ".downloading"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                # 1 MiB chunks keep memory bounded for a large model file.
                for chunk in r.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears complete.
        os.replace(tmp_path, MODEL_PATH)
    except Exception as e:
        # Best-effort cleanup of the partial file; a missing file is fine.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise RuntimeError(
            "Téléchargement du modèle échoué. Activez Internet dans les paramètres du Space ou réessayez plus tard. Détails: "
            + str(e)
        ) from e
|
|
|
|
def load_session():
    """Return the cached ONNX Runtime session, creating it on first use.

    On the first call this makes sure the model file exists (downloading it
    if needed), opens it with the CPU execution provider, and stores the
    session plus the names of its first input and first output in module
    globals. Later calls return the cached session immediately.
    """
    global _session, _input_name, _output_name
    if _session is not None:
        return _session

    ensure_model_available()
    _session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    _input_name = _session.get_inputs()[0].name
    _output_name = _session.get_outputs()[0].name
    return _session
|
|
|
|
def ensure_rgb(image: np.ndarray) -> np.ndarray:
    """Return *image* as a 3-channel RGB array.

    - ``None`` is passed through unchanged (the caller checks for it).
    - Grayscale ``(H, W)`` and single-channel ``(H, W, 1)`` arrays are
      replicated into three identical channels.
    - RGBA ``(H, W, 4)`` arrays have the alpha channel dropped.
    - Anything else is returned unchanged.

    Converted results are C-contiguous, so they can be handed straight to
    OpenCV. Any numeric dtype is accepted.
    """
    if image is None:
        return image
    # (H, W, 1) previously slipped through untouched and then broke the
    # RGB->BGR conversion downstream; treat it like plain grayscale.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        # np.repeat allocates a new contiguous array.
        return np.repeat(image[:, :, np.newaxis], 3, axis=2)
    if image.ndim == 3 and image.shape[2] == 4:
        # Slicing gives a strided view; make it contiguous for OpenCV.
        return np.ascontiguousarray(image[:, :, :3])
    return image
|
|
|
|
def letterbox(im, new_shape=(INPUT_SIZE, INPUT_SIZE), color=(114, 114, 114)):
    """Resize *im* keeping its aspect ratio, then pad it to *new_shape*.

    Args:
        im: image array whose first two dimensions are (height, width).
        new_shape: target (height, width).
        color: constant fill value used for the padding border.

    Returns:
        (padded_image, ratio, (pad_left, pad_top)): the padded image, the
        scale factor applied to both axes, and the left/top padding in pixels.
        Any odd leftover padding goes to the right/bottom side.
    """
    src_h, src_w = im.shape[:2]
    dst_h, dst_w = new_shape[0], new_shape[1]

    ratio = min(dst_h / src_h, dst_w / src_w)
    scaled_w = int(round(src_w * ratio))
    scaled_h = int(round(src_h * ratio))
    resized = cv2.resize(im, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)

    pad_w = dst_w - scaled_w
    pad_h = dst_h - scaled_h
    pad_left = pad_w // 2
    pad_top = pad_h // 2
    padded = cv2.copyMakeBorder(
        resized,
        pad_top,
        pad_h - pad_top,
        pad_left,
        pad_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=color,
    )
    return padded, ratio, (pad_left, pad_top)
|
|
|
|
def xywh2xyxy(x):
    """Convert box rows from (cx, cy, w, h) to corner form (x1, y1, x2, y2).

    Only the first four columns are rewritten; any extra columns are copied
    through. The input array is not modified.
    """
    out = x.copy()
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2
    out[:, 0] = x[:, 0] - half_w
    out[:, 2] = x[:, 0] + half_w
    out[:, 1] = x[:, 1] - half_h
    out[:, 3] = x[:, 1] + half_h
    return out
|
|
|
|
def nms(boxes, scores, iou_thres=0.45):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it is >= *iou_thres*.

    Args:
        boxes: array of (x1, y1, x2, y2) rows.
        scores: 1-D array of scores, one per box.
        iou_thres: overlap at or above which a box is suppressed.

    Returns:
        List of indices into *boxes* that survive, highest score first.
    """
    order = scores.argsort()[::-1]
    kept = []
    while order.size:
        best = order[0]
        kept.append(best)
        rest = order[1:]
        if not rest.size:
            break
        overlaps = iou(boxes[best], boxes[rest])
        order = rest[overlaps < iou_thres]
    return kept
|
|
|
|
def iou(box, boxes):
    """Intersection-over-union of one box against each row of *boxes*.

    All boxes are (x1, y1, x2, y2). A tiny epsilon in the denominator avoids
    division by zero for degenerate boxes.

    Returns:
        1-D array with one IoU value per row of *boxes*.
    """
    inter_w = np.maximum(0, np.minimum(box[2], boxes[:, 2]) - np.maximum(box[0], boxes[:, 0]))
    inter_h = np.maximum(0, np.minimum(box[3], boxes[:, 3]) - np.maximum(box[1], boxes[:, 1]))
    intersection = inter_w * inter_h

    area_box = (box[2] - box[0]) * (box[3] - box[1])
    area_boxes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return intersection / (area_box + area_boxes - intersection + 1e-16)
|
|
|
|
def scale_boxes(boxes, gain, pad):
    """Map (x1, y1, x2, y2) boxes from letterboxed coordinates back to the source image.

    Subtracts the (left, top) padding *pad* and divides by the letterbox
    scale *gain*. The array is modified in place (it must have a float
    dtype) and is also returned for convenience.
    """
    pad_x = pad[0]
    pad_y = pad[1]
    boxes[:, 0] -= pad_x
    boxes[:, 2] -= pad_x
    boxes[:, 1] -= pad_y
    boxes[:, 3] -= pad_y
    boxes[:, 0:4] /= gain
    return boxes
|
|
|
|
def _preprocess(image_rgb):
    """Turn an RGB image into the model input: float32 (1, 3, INPUT_SIZE, INPUT_SIZE) in [0, 1]."""
    # NOTE(review): channels are fed in BGR order, as this app always did.
    # Upstream YOLOv7 inference appears to feed RGB; for grayscale
    # radiographs (identical channels) the order makes no difference —
    # confirm before relying on colour inputs.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    # Plain resize (no letterbox): aspect ratio is not preserved, and
    # infer_yolov7 undoes it with independent x / y scale factors.
    resized = cv2.resize(image_bgr, (INPUT_SIZE, INPUT_SIZE), interpolation=cv2.INTER_LINEAR)
    tensor = resized.astype(np.float32) / 255.0
    tensor = np.transpose(tensor, (2, 0, 1))[np.newaxis]
    return np.ascontiguousarray(tensor)


def _decode_predictions(pred):
    """Split model output into (boxes_xyxy float32, scores float32, labels int32) in model-input pixels.

    Recognised row layouts:
      * 7 columns: YOLOv7 end-to-end ONNX export with built-in NMS,
        (batch_id, x1, y1, x2, y2, class_id, score).
      * 5 + len(CLASSES) columns: raw YOLOv7 head, (cx, cy, w, h, objectness,
        per-class scores...); score = objectness * best class score.
      * anything else: (x1, y1, x2, y2, score, class_id), the layout this app
        originally assumed.
    """
    pred = np.asarray(pred)
    if pred.ndim == 3:
        pred = pred[0]  # drop the batch dimension (batch size is always 1 here)
    if pred.size == 0:
        return (
            np.zeros((0, 4), dtype=np.float32),
            np.zeros(0, dtype=np.float32),
            np.zeros(0, dtype=np.int32),
        )

    n_cols = pred.shape[1]
    if n_cols == 7:
        boxes = pred[:, 1:5].astype(np.float32)
        labels = pred[:, 5].astype(np.int32)
        scores = pred[:, 6].astype(np.float32)
    elif n_cols == 5 + len(CLASSES):
        boxes = xywh2xyxy(pred[:, 0:4].astype(np.float32))
        class_scores = pred[:, 5:]
        labels = class_scores.argmax(axis=1).astype(np.int32)
        scores = (pred[:, 4] * class_scores.max(axis=1)).astype(np.float32)
    else:
        boxes = pred[:, 0:4].astype(np.float32)
        scores = pred[:, 4].astype(np.float32)
        labels = pred[:, 5].astype(np.int32)
    return boxes, scores, labels


def _per_class_nms(boxes, scores, labels, iou_thres):
    """Run nms() independently within each label; return kept indices, highest score first."""
    keep = []
    for cls_id in np.unique(labels):
        members = np.flatnonzero(labels == cls_id)
        survivors = nms(boxes[members], scores[members], iou_thres)
        keep.extend(members[survivors].tolist())
    keep.sort(key=lambda i: scores[i], reverse=True)
    return keep


def infer_yolov7(image_rgb, conf_thres=0.25, iou_thres=0.45, only_fracture=True):
    """Run the ONNX detector on one RGB image and return its detections.

    Args:
        image_rgb: H x W x 3 RGB array.
        conf_thres: minimum score for a detection to be kept.
        iou_thres: IoU at or above which a lower-scoring box of the same class
            is suppressed by non-maximum suppression.
        only_fracture: if True, return only detections whose class name is
            "fracture".

    Returns:
        List of dicts with keys "box" ([x1, y1, x2, y2] floats in original
        image pixels, clipped to the image), "score", "class_id" and
        "class_name", highest score first. Empty list if nothing survives.

    Raises:
        Whatever load_session() or onnxruntime raise, e.g. RuntimeError when
        the model cannot be downloaded.
    """
    h0, w0 = image_rgb.shape[:2]
    tensor = _preprocess(image_rgb)

    session = load_session()
    pred = session.run([_output_name], {_input_name: tensor})[0]

    boxes_xyxy, scores, labels = _decode_predictions(pred)
    mask = scores >= conf_thres
    boxes_xyxy, scores, labels = boxes_xyxy[mask], scores[mask], labels[mask]
    if boxes_xyxy.shape[0] == 0:
        return []

    # Undo the plain resize; x and y were scaled independently.
    boxes_xyxy[:, [0, 2]] *= w0 / float(INPUT_SIZE)
    boxes_xyxy[:, [1, 3]] *= h0 / float(INPUT_SIZE)

    # Apply the caller's IoU threshold (per class, as in YOLOv7's default
    # non-agnostic NMS) in original-image space.
    keep = _per_class_nms(boxes_xyxy, scores, labels, iou_thres)

    dets = []
    for i in keep:
        cid = int(labels[i])
        name = CLASSES[cid] if 0 <= cid < len(CLASSES) else str(cid)
        if only_fracture and name != "fracture":
            continue
        x1, y1, x2, y2 = boxes_xyxy[i].tolist()
        dets.append({
            "box": [
                float(max(0, min(w0 - 1, x1))),
                float(max(0, min(h0 - 1, y1))),
                float(max(0, min(w0 - 1, x2))),
                float(max(0, min(h0 - 1, y2))),
            ],
            "score": float(scores[i]),
            "class_id": cid,
            "class_name": name,
        })
    return dets
|
|
|
|
def draw_detections(image_rgb, dets):
    """Return a copy of *image_rgb* with each detection drawn on it.

    Each detection gets a 3-px box outline and a filled "name:score" label.
    Fracture boxes use (255, 0, 0), which is red in an RGB image; other
    classes use (0, 150, 255). The label normally sits just above the box.
    When the box is too close to the top edge for the label to fit above
    it, the label is drawn just inside the box instead so it stays visible.
    The input image is not modified.
    """
    img = image_rgb.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.8, 2
    for d in dets:
        x1, y1, x2, y2 = map(int, d["box"])
        name = d["class_name"]
        color = (255, 0, 0) if name == "fracture" else (0, 150, 255)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)

        label = f"{name}:{d['score']:.2f}"
        (tw, th), _ = cv2.getTextSize(label, font, font_scale, thickness)
        baseline_y = y1 - 8
        if baseline_y - th - 6 < 0:
            # No room above the box: the label would be drawn off-image
            # (text rises above its baseline), so move it inside the box.
            baseline_y = y1 + th + 8
        cv2.rectangle(img, (x1, baseline_y - th - 6), (x1 + tw + 6, baseline_y + 2), color, -1)
        cv2.putText(img, label, (x1 + 3, baseline_y), font, font_scale, (255, 255, 255), thickness)
    return img
|
|
|
|
def predict(image, region, conf_thres, iou_thres, show_non_fracture):
    """Gradio callback: run detection and return (annotated image, JSON text).

    Args:
        image: uploaded image as a numpy array, or None if nothing was uploaded.
        region: value of the "Région anatomique" dropdown; only echoed in the JSON.
        conf_thres: confidence threshold forwarded to infer_yolov7.
        iou_thres: NMS IoU threshold forwarded to infer_yolov7.
        show_non_fracture: when True, keep every class, not only "fracture".

    Returns:
        (annotated_image_or_None, json_string). When no image is given or
        inference fails, the image is None and the JSON holds a single
        "error" key.
    """
    if image is None:
        return None, json.dumps({"error": "Aucune image fournie."}, ensure_ascii=False, indent=2)

    image = ensure_rgb(image)
    only_fracture = not show_non_fracture

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the wall clock is adjusted, skewing the reported duration.
    start = time.perf_counter()
    try:
        dets = infer_yolov7(image, conf_thres=conf_thres, iou_thres=iou_thres, only_fracture=only_fracture)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the
        # callback. Some exceptions carry an empty message; fall back to the
        # exception type name so the error is never blank.
        msg = str(e) or type(e).__name__
        return None, json.dumps({"error": msg}, ensure_ascii=False, indent=2)
    elapsed = time.perf_counter() - start

    annotated = draw_detections(image, dets)
    resp = {
        "region": region,
        "detections": dets,
        "count": len(dets),
        "time_s": round(elapsed, 3),
        "note": "Modèle entraîné sur le poignet (GRAZPEDWRI-DX). Les autres régions sont exploratoires.",
        "medical_warning": "Cet outil n’est pas un dispositif médical. Il ne remplace pas l’avis d’un(e) radiologue/médecin.",
    }
    return annotated, json.dumps(resp, ensure_ascii=False, indent=2)
|
|
|
|
def build_ui():
    """Build and return the (French-language) Gradio Blocks interface.

    Layout, top to bottom:
      1. an introductory Markdown panel;
      2. a row with the image upload (wide column) and the controls
         (region dropdown, confidence and IoU sliders, "other classes"
         checkbox, "Analyser" button);
      3. a row with the annotated result image and the JSON details;
      4. a closing disclaimer.
    Clicking "Analyser" calls predict() with the five control values and
    routes its two return values to the two output components.
    """
    default_region = "Poignet (modèle entraîné)"
    region_choices = [
        default_region,
        "Autre (exploratoire)",
    ]

    with gr.Blocks(title="Détection de fracture (Radiographie)") as ui:
        gr.Markdown("""
# Détection de fracture (Radiographie) — Prototype
- Interface en français, fonctionnement 100% en ligne.
- Téléversez une radiographie, puis lancez l’analyse.
- Modèle détection (boîtes) entraîné sur le poignet; autres régions = usage exploratoire.
- N’est pas un dispositif médical.
""")

        with gr.Row():
            with gr.Column(scale=2):
                image_input = gr.Image(type="numpy", label="Téléverser une radiographie")
            with gr.Column(scale=1):
                region_dropdown = gr.Dropdown(
                    choices=region_choices,
                    value=default_region,
                    label="Région anatomique",
                )
                conf_slider = gr.Slider(
                    0.05, 0.9, value=CONF_THRES_DEFAULT, step=0.01, label="Seuil de confiance"
                )
                # Named iou_slider so it does not shadow the module-level iou() helper.
                iou_slider = gr.Slider(
                    0.1, 0.9, value=IOU_THRES_DEFAULT, step=0.01, label="Seuil NMS (IoU)"
                )
                other_classes_box = gr.Checkbox(
                    False, label="Afficher aussi les autres classes (non-fracture)"
                )
                analyse_button = gr.Button("Analyser", variant="primary")

        with gr.Row():
            annotated_output = gr.Image(type="numpy", label="Résultat annoté")
            details_output = gr.Code(language="json", label="Détails des détections")

        analyse_button.click(
            predict,
            inputs=[image_input, region_dropdown, conf_slider, iou_slider, other_classes_box],
            outputs=[annotated_output, details_output],
        )

        gr.Markdown("""
### Avertissement
Cet outil sert d’aide et ne remplace pas un avis médical professionnel.
""")

    return ui
|
|
|
|
# Build the interface at import time so `demo` exists as a module attribute
# (presumably so a hosting runner can import it — the model-download error
# text refers to a "Space"; confirm).
demo = build_ui()


if __name__ == "__main__":
    # Direct execution: start the Gradio server with its default settings.
    demo.launch()
|
|