Spaces:
Sleeping
Sleeping
import logging
import os
import sys
import tempfile

import gradio as gr
import numpy as np
from PIL import Image
# Configure logging: timestamped INFO-level messages for this app's logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Suppress TensorFlow logging. '3' = errors only; the env var must be set
# before TensorFlow is first imported (presumably pulled in transitively by
# the `model` import below — TODO confirm).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    # absl is TensorFlow's logging backend; quiet it too when available.
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    # absl not installed locally — nothing to silence.
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)
# Project-local model wrappers.
from model import RawImageModel, PrecomputedModel
# Global Model Instances — all populated lazily by load_models(); None until then.
raw_model = None          # wraps the image encoder (reads raw image files)
precomputed_model = None  # wraps the text/zero-shot side
pos_emb = None            # cached text embedding for the positive prompt
neg_emb = None            # cached text embedding for the negative prompt
# Optimal Threshold from Kaggle validation; scores >= THRESHOLD are classified
# as pneumothorax in predict().
THRESHOLD = -0.1173
def load_models():
    """Lazily initialise the global models and cache the prompt embeddings.

    Idempotent: once ``raw_model`` is set, subsequent calls return immediately.

    Side effects:
        Assigns the module globals ``raw_model``, ``precomputed_model``,
        ``pos_emb`` and ``neg_emb``.

    Raises:
        Re-raises whatever the model constructors or embedding lookup raise.
    """
    global raw_model, precomputed_model, pos_emb, neg_emb
    if raw_model is not None:
        return  # already loaded
    logger.info("Loading models...")
    try:
        precomputed_model = PrecomputedModel()
        raw_model = RawImageModel()
        # Pre-fetch text embeddings once so predict() never recomputes them.
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
        logger.info("Models loaded.")
    except Exception:
        # logger.exception records the full traceback; bare `raise` (not
        # `raise e`) preserves the original traceback for the caller.
        logger.exception("Failed to load models")
        raise
# ZeroGPU compatibility for Hugging Face Spaces.
try:
    import spaces
except ImportError:
    # Running locally without the `spaces` package: provide a stand-in
    # namespace whose GPU decorator simply returns the function unchanged.
    class spaces:
        @staticmethod
        def GPU(func):
            return func
def predict(image):
    """Classify a chest X-ray image as pneumothorax vs normal (zero-shot).

    Args:
        image: PIL image from the Gradio upload widget, or None.

    Returns:
        Tuple of (HTML label, float score, details string) matching the
        three Gradio output components.
    """
    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."
    temp_path = None
    try:
        # Unique temp file per request so concurrent predictions don't
        # clobber each other (the old fixed "temp_gradio_upload.png" did).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            temp_path = tmp.name
        image.save(temp_path)
        # Run inference: embed the image, then score it against the cached
        # positive/negative prompt embeddings.
        # NOTE(review): on a ZeroGPU Space this function would normally carry
        # @spaces.GPU — confirm whether GPU allocation is required here.
        img_emb = raw_model.compute_embeddings(temp_path)
        score = float(PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb))
        # Binary classification against the calibrated threshold.
        if score >= THRESHOLD:
            prediction = '<p style="color: red; font-size: 24px; font-weight: bold;">PNEUMOTHORAX ⚠️</p>'
        else:
            prediction = '<p style="color: green; font-size: 24px; font-weight: bold;">NORMAL ✅</p>'
        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"
    except Exception as e:
        logger.exception("Prediction failed")
        return "<p style='color:red'>Error</p>", 0.0, str(e)
    finally:
        # Always delete the temp file, success or failure (the original
        # leaked it on every call).
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
# Load models at startup so the first request doesn't pay the loading cost.
load_models()

# UI Layout. NOTE: in Gradio Blocks, component creation order inside the
# context managers *is* the page layout — do not reorder these statements.
with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect Pneumothorax from raw X-ray images using a pre-trained foundation model.")
    with gr.Row():
        with gr.Column():
            # Left column: image upload + trigger button.
            gr.Markdown("### 1. Upload X-Ray")
            input_image = gr.Image(type="pil", label="Upload Image (PNG/JPG/DICOM converted)")
            predict_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            # Right column: the three outputs filled by predict().
            gr.Markdown("### 2. Results")
            output_label = gr.HTML(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")
    gr.Markdown("---")
    gr.Markdown("### Performance Context")
    gr.Markdown("This model uses a **zero-shot** approach. The threshold was calibrated using a local Kaggle dataset.")
    with gr.Tabs():
        # Static benchmark images; paths are relative to the app working dir.
        with gr.TabItem("Local Kaggle Benchmark"):
            gr.Image("results/kaggle_roc_curve.png", label="local ROC Curve")
            gr.Markdown("**AUC: 0.88** on 250 local samples.")
        with gr.TabItem("Google Benchmark"):
            gr.Image("results/roc_PNEUMOTHORAX.png", label="Reference ROC")
    # Wire the button: one image in, (label, score, details) out.
    predict_btn.click(predict, inputs=input_image, outputs=[output_label, output_score, output_msg])
# Script entry point: launch the Gradio server when run directly.
# NOTE(review): share=True requests a public gradio.live tunnel — only
# relevant for local runs; confirm it is wanted when deployed on Spaces.
if __name__ == "__main__":
    demo.launch(share=True)