Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
# 1. Human-readable class names for every model, keyed by internal model name.
#    predict() indexes these lists by model output position, so each list's
#    order must match the output order the model was trained with.
CLASS_LABELS = {
    "Pneumonia": [
        "Normal",
        "Pneumonia",
    ],
    "Tuberculosis": [
        "Normal",
        "Tuberculosis",
    ],
    "Brain Tumor": [
        "glioma",
        "meningioma",
        "notumor",
        "pituitary",
    ],
    "Skin Cancer": [
        "Actinic keratoses",
        "Basal cell carcinoma",
        "Benign keratosis-like lesions",
        "Dermatofibroma",
        "Melanoma",
        "Melanocytic nevi",
        "Vascular lesions",
    ],
    "Kvasir": [
        "dyed-lifted-polyps",
        "dyed-resection-margins",
        "esophagitis",
        "normal-cecum",
        "normal-pylorus",
        "normal-z-line",
        "polyps",
        "ulcerative-colitis",
    ],
}
# 2. Translation table from the analysis names shown in the UI to the
#    internal model names used as keys in CLASS_LABELS / MODEL_PATHS.
#    Only the Kvasir entry differs between UI name and internal name.
ANALYSIS_MAP = {
    "Brain Tumor": "Brain Tumor",
    "Endoscopic Ulcer (Kvasir)": "Kvasir",
    "Pneumonia": "Pneumonia",
    "Skin Cancer": "Skin Cancer",
    "Tuberculosis": "Tuberculosis",
}
# 3. Location of each model's saved Keras .h5 weights, keyed by internal model
#    name. The paths are relative, so they resolve against the process's
#    current working directory.
MODEL_PATHS = {
    "Brain Tumor": "model/best_braintumor_model.h5",
    "Kvasir": "model/best_kvasir_model.h5",
    "Pneumonia": "model/best_pneumonia_model.h5",
    "Skin Cancer": "model/best_skincancer_model.h5",
    "Tuberculosis": "model/best_tuberculosis_model.h5",
}
# 4. Image preprocessing function
def preprocess_image(image_file, target_size=(224, 224)):
    """Load an uploaded image and turn it into a single-image model batch.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object), or an already-loaded ``PIL.Image.Image``.
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        A NumPy array with a leading batch axis of size 1, produced by
        ``img_to_array`` on the resized RGB image and divided by 255.0.
    """
    if isinstance(image_file, Image.Image):
        # Caller owns this image; convert() returns a copy, so it is not mutated.
        img = image_file.convert('RGB')
    else:
        # Close whatever file Pillow opened. convert() forces the pixel data to
        # load into a new in-memory image, so `img` stays valid after exit.
        with Image.open(image_file) as source:
            img = source.convert('RGB')
    img = img.resize(target_size)
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    # Add the batch dimension and scale 0-255 pixel values down to 0-1.
    img_array = np.expand_dims(img_array, axis=0) / 255.0
    return img_array
# 5. Prediction function
# Loaded Keras models keyed by file path, so repeated predictions reuse the
# in-memory model instead of re-reading the .h5 file on every request.
# Bounded by the number of entries in MODEL_PATHS.
# NOTE(review): models are now shared between calls; if the serving framework
# runs predictions concurrently, confirm Keras predict() is safe to share.
_MODEL_CACHE = {}


def _load_model(model_path):
    """Return the Keras model stored at ``model_path``, loading it at most once."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = tf.keras.models.load_model(model_path, compile=False)
        _MODEL_CACHE[model_path] = model
    return model


def _interpret_prediction(prediction, class_labels):
    """Turn a batch-of-one model output into ``(label, confidence_percent)``.

    Raises:
        ValueError: if the number of model outputs does not fit ``class_labels``.
    """
    scores = np.ravel(np.asarray(prediction)[0])
    n_outputs = scores.shape[-1]
    if n_outputs == 1 and len(class_labels) == 2:
        # Single output unit: treated as the score for class_labels[1].
        score = float(scores[0])
        if score > 0.5:
            return class_labels[1], score * 100
        return class_labels[0], (1 - score) * 100
    if n_outputs != len(class_labels):
        raise ValueError(
            f"Model produced {n_outputs} outputs but "
            f"{len(class_labels)} class labels are defined."
        )
    # One score per class (this also covers two-unit binary heads, which the
    # single-unit rule above would misread): the highest-scoring unit wins.
    label_index = int(np.argmax(scores))
    return class_labels[label_index], float(scores[label_index]) * 100


def predict(image_file, target_disease):
    """Run the model for ``target_disease`` on an image.

    Args:
        image_file: Image input forwarded to ``preprocess_image``.
        target_disease: Internal model name (a key of ``MODEL_PATHS``) or a UI
            analysis name from ``ANALYSIS_MAP``.

    Returns:
        ``(label, confidence)`` where ``confidence`` is a percentage as a float,
        or ``("Error", message)`` if anything goes wrong.
    """
    try:
        # Accept UI keys too; internal names not in ANALYSIS_MAP pass through.
        target_disease = ANALYSIS_MAP.get(target_disease, target_disease)
        model_path = MODEL_PATHS.get(target_disease)
        if model_path is None:
            return "Error", f"Unknown analysis type: {target_disease}."
        if not os.path.exists(model_path):
            return "Error", f"Model file not found at {model_path}. Make sure it's in the 'model' directory."
        model = _load_model(model_path)
        processed_image = preprocess_image(image_file)
        # verbose=0 keeps Keras from printing a progress bar per request.
        prediction = model.predict(processed_image, verbose=0)
        return _interpret_prediction(prediction, CLASS_LABELS[target_disease])
    except Exception as e:
        # UI boundary: report any failure as an ("Error", message) pair.
        return "Error", f"An error occurred: {str(e)}"
# 6. Explanation generation function (placeholder)
# Hand-written blurbs for a subset of labels; any other label falls back to a
# generic sentence in generate_explanation().
_EXPLANATIONS = {
    "Pneumonia": "The model detected patterns consistent with pneumonia, an infection that inflames the air sacs in one or both lungs.",
    "Tuberculosis": "The model identified signs often associated with tuberculosis, a serious infectious disease that mainly affects the lungs.",
    "Melanoma": "This finding is consistent with melanoma, the most serious type of skin cancer. Key indicators can include asymmetry, irregular borders, and varied color.",
    "glioma": "The prediction points towards a glioma, a type of tumor that occurs in the brain and spinal cord, starting in the glial cells.",
    "ulcerative-colitis": "The model detected inflammation and ulcers in the colon's lining, which are characteristic of ulcerative colitis.",
}


def generate_explanation(label, confidence):
    """Build a short plain-language explanation for a prediction.

    Args:
        label: Predicted class label, or ``"Error"`` when prediction failed.
        confidence: Confidence in percent (any value supporting the ``.2f``
            format), or the error message paired with an ``"Error"`` label.

    Returns:
        The explanation text. For an ``"Error"`` label the message is returned
        as-is; for a non-numeric confidence the confidence sentence is omitted
        instead of raising.
    """
    if label == "Error":
        # predict() reports failures as ("Error", message); formatting that
        # message with ".2f" below would raise ValueError.
        return str(confidence)
    base_explanation = _EXPLANATIONS.get(label, f"The model identified features consistent with {label}.")
    try:
        confidence_text = f"{confidence:.2f}"
    except (TypeError, ValueError):
        return base_explanation
    return f"{base_explanation} The model's confidence in this finding is {confidence_text}%."