Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| import torch | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import torch.nn as nn | |
| import io | |
| import numpy as np | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| import logging | |
| import cv2 | |
| import base64 | |
| from pytorch_grad_cam import GradCAMPlusPlus | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from huggingface_hub import hf_hub_download | |
| from pydantic import BaseModel | |
# --- Logging configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
# Hugging Face Hub repository and file that hold the trained classifier weights.
HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
MODEL_FILENAME = "best_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ImageNet normalization statistics used by the preprocessing pipeline.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Global model and transform; both stay None until initialize_model() fills them.
model_instance = None
transform_pipeline = None

# User-facing (Polish) interpretation of each BI-RADS category 1-5.
interpretations_dict = {
    1: "Wynik negatywny - brak zmian nowotworowych",
    2: "Zmiana łagodna",
    3: "Prawdopodobnie zmiana łagodna - zalecana kontrola",
    4: "Podejrzenie zmiany złośliwej - zalecana biopsja",
    5: "Wysoka podejrzliwość złośliwości - wymagana biopsja",
}
# --- Model initialization ---
def initialize_model():
    """Download the classifier weights and build the global model and transform.

    Does nothing if ``model_instance`` is already set. Otherwise it downloads
    ``MODEL_FILENAME`` from ``HF_MODEL_REPO_ID`` and builds a ResNet-18 whose
    head is replaced by Dropout(0.5) + Linear(->5). It loads the weights, moves
    the model to ``DEVICE`` in eval mode and publishes it, together with the
    preprocessing pipeline (resize to 224x224, to-tensor, ImageNet
    normalization), into the module-level ``model_instance`` and
    ``transform_pipeline``.

    Raises:
        RuntimeError: if the weights cannot be downloaded from the Hub.
    """
    global model_instance, transform_pipeline
    if model_instance is not None:
        return
    logger.info("Rozpoczynanie inicjalizacji modelu...")
    try:
        # Optional read token, e.g. for a private model repository.
        hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
        model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token)
        logger.info(f"Plik modelu pomyślnie pobrany do: {model_pt_path}")
    except Exception as e:
        logger.error(f"Błąd podczas pobierania modelu z Hugging Face Hub: {e}", exc_info=True)
        raise RuntimeError(f"Nie można pobrać modelu: {e}") from e

    model_arch = models.resnet18(weights=None)
    num_feats = model_arch.fc.in_features
    # Same head as used in training: the state_dict below must match it exactly.
    model_arch.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_feats, 5))
    # The file is a plain state_dict fetched over the network; weights_only=True
    # refuses to unpickle arbitrary Python objects from it.
    state_dict = torch.load(model_pt_path, map_location=DEVICE, weights_only=True)
    model_arch.load_state_dict(state_dict)
    model_arch.to(DEVICE)
    model_arch.eval()

    # Build the transform first and publish the model last. The early-return guard
    # above keys on model_instance, so it must never be set while
    # transform_pipeline is still missing.
    transform_pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    model_instance = model_arch
    logger.info(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
# --- Prediction with uncertainty quantification (MC Dropout) ---
def predict_with_mc_dropout(current_model_instance, batch_tensor_on_device, mc_dropout_samples: int, uncertainty_threshold_std: float):
    """Classify a batch with Monte Carlo Dropout and flag uncertain predictions.

    Runs ``mc_dropout_samples`` stochastic forward passes with only the dropout
    layers active. It averages the softmax outputs per image and uses the mean
    over classes of the per-class standard deviation as the uncertainty metric.

    Args:
        current_model_instance: classifier module. Its train/eval mode is
            restored on exit, even if a forward pass raises.
        batch_tensor_on_device: input batch on the model's device, batch
            dimension first.
        mc_dropout_samples: number of stochastic forward passes (>= 1).
        uncertainty_threshold_std: images whose uncertainty metric exceeds this
            value are reported as uncertain (``birads``/``confidence`` = None,
            ``error`` set).

    Returns:
        One result dict per image, in batch order. Confident results also carry
        ``predicted_class_index`` (0-based int), which the caller is expected to
        consume and remove.

    Raises:
        ValueError: if ``mc_dropout_samples`` is smaller than 1.
    """
    if mc_dropout_samples < 1:
        raise ValueError(f"mc_dropout_samples must be >= 1, got {mc_dropout_samples}")
    batch_size = batch_tensor_on_device.shape[0]
    logger.info(f"Performing MC Dropout on a batch of size {batch_size} with {mc_dropout_samples} samples.")

    original_mode_is_training = current_model_instance.training
    # Enable ONLY the dropout layers. Calling .train() on the whole model would also
    # put BatchNorm into training mode. It would then normalize with per-batch
    # statistics and overwrite its running mean/var on every pass (even under
    # no_grad), permanently degrading the shared model.
    current_model_instance.eval()
    for module in current_model_instance.modules():
        if isinstance(module, nn.modules.dropout._DropoutNd):
            module.train()
    try:
        sampled_probs = []
        with torch.no_grad():
            for _ in range(mc_dropout_samples):
                output = current_model_instance(batch_tensor_on_device)
                sampled_probs.append(torch.nn.functional.softmax(output, dim=1).cpu().numpy())
    finally:
        # train(mode) resets every submodule, which also undoes the dropout-only override.
        current_model_instance.train(original_mode_is_training)

    # Shape (batch_size, mc_dropout_samples, num_classes), kept in float64.
    all_probs_batch = np.stack(sampled_probs, axis=1).astype(np.float64)
    mean_probabilities_batch = all_probs_batch.mean(axis=1)
    std_dev_probabilities_batch = all_probs_batch.std(axis=1)

    results = []
    for i, (mean_probabilities, std_dev_probabilities) in enumerate(
        zip(mean_probabilities_batch, std_dev_probabilities_batch)
    ):
        predicted_class_index = int(np.argmax(mean_probabilities))
        # Report the averaged probability of the predicted class: the same number
        # shown in class_probabilities, not the single most optimistic MC sample.
        confidence_in_predicted_class = float(mean_probabilities[predicted_class_index])
        uncertainty_metric = float(np.mean(std_dev_probabilities))
        is_uncertain = uncertainty_metric > uncertainty_threshold_std
        logger.info(f"MC Dropout Results for image {i}: Predicted Index: {predicted_class_index}, Confidence (MaxProb): {confidence_in_predicted_class:.4f}, Uncertainty (avg_std): {uncertainty_metric:.4f}, Is Uncertain: {is_uncertain}")
        # Keys "1".."N" are BI-RADS categories (class index + 1).
        class_probabilities = {str(j + 1): float(p) for j, p in enumerate(mean_probabilities)}
        birads_category_if_confident = predicted_class_index + 1
        if is_uncertain:
            result = {
                "birads": None, "confidence": None,
                "interpretation": f"Model jest niepewny co do tego obrazu (niepewność: {uncertainty_metric:.4f}). Sprawdź jakość i typ obrazu.",
                "class_probabilities": class_probabilities,
                "grad_cam_image_base64": None, "error": "High prediction uncertainty",
                "details": f"Uncertainty metric ({uncertainty_metric:.4f}) przekroczyła próg ({uncertainty_threshold_std}).",
            }
        else:
            result = {
                "birads": birads_category_if_confident,
                "confidence": confidence_in_predicted_class,
                "interpretation": interpretations_dict.get(birads_category_if_confident, "Nieznana klasyfikacja"),
                "class_probabilities": class_probabilities,
                "grad_cam_image_base64": None, "error": None,
                "details": f"Uncertainty metric ({uncertainty_metric:.4f}) jest w granicach progu ({uncertainty_threshold_std}).",
                "predicted_class_index": predicted_class_index,
            }
        results.append(result)
    return results
# --- Grad-CAM overlay rendering ---
# Tint colour per BI-RADS category, as RGB floats in [0, 1].
_BIRADS_OVERLAY_COLORS_RGB = {
    1: (0.1, 0.7, 0.1),
    2: (0.53, 0.81, 0.92),
    3: (1.0, 0.9, 0.0),
    4: (1.0, 0.5, 0.0),
    5: (0.9, 0.1, 0.1),
}
# Used for any category missing from the table above.
_DEFAULT_OVERLAY_COLOR_RGB = (0.5, 0.5, 0.5)


def create_grad_cam_overlay_image(
    original_pil_image: Image.Image,
    grayscale_cam: np.ndarray,
    birads_category: int,
    transparency: float = 0.5,
    threshold: float = 0.7,
    kernel_size: int = 5,
) -> Optional[Image.Image]:
    """Tint the most strongly activated regions of an image with a BI-RADS colour.

    Args:
        original_pil_image: image to draw on; converted to RGB.
        grayscale_cam: 2-D activation map. It is resized to the image size and
            min-max normalized, so its own resolution and scale do not matter.
        birads_category: selects the tint colour (grey for unknown categories).
        transparency: maximum opacity of the tint, reached where the normalized
            activation is 1.
        threshold: normalized activations below this value are not tinted.
        kernel_size: side of the square kernel used by the morphological
            opening that removes small isolated blobs from the mask.

    Returns:
        The tinted RGB image, or None if rendering failed (the error is logged).
    """
    try:
        img_np = np.asarray(original_pil_image.convert('RGB'), dtype=np.float32) / 255.0
        height, width = img_np.shape[:2]
        # cv2.resize takes the target size as (width, height).
        cam_resized = cv2.resize(grayscale_cam, (width, height))
        cam_min = np.min(cam_resized)
        cam_normalized = (cam_resized - cam_min) / (np.max(cam_resized) - cam_min + 1e-8)
        # Keep only the strongest activations.
        cam_normalized[cam_normalized < threshold] = 0
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        cam_cleaned = cv2.morphologyEx(cam_normalized, cv2.MORPH_OPEN, kernel)

        chosen_color = np.asarray(
            _BIRADS_OVERLAY_COLORS_RGB.get(birads_category, _DEFAULT_OVERLAY_COLOR_RGB),
            dtype=np.float32,
        )
        # Per-pixel opacity with a trailing channel axis so it broadcasts over RGB;
        # the (3,) colour vector broadcasts over the whole image.
        alpha = (cam_cleaned * transparency)[..., np.newaxis]
        highlighted_image_np = np.clip(img_np * (1 - alpha) + chosen_color * alpha, 0, 1)
        return Image.fromarray((highlighted_image_np * 255).astype(np.uint8))
    except Exception as e:
        # Best effort: the overlay is illustrative, so callers continue without it.
        logger.error(f"Błąd podczas tworzenia obrazu Grad-CAM overlay: {e}", exc_info=True)
        return None
# --- Heuristic out-of-distribution (OOD) checks ---
def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str, colorfulness_threshold: float, uniformity_threshold: float, aspect_ratio_min: float, aspect_ratio_max: float) -> Optional[str]:
    """Run cheap heuristics that reject images unlikely to be valid inputs.

    The checks run in this order, and the first failure is returned:

    1. colourfulness: mean over pixels of the standard deviation across the
       R, G, B channels; above ``colorfulness_threshold`` means a colour image;
    2. aspect ratio: ``width / height`` must lie strictly between
       ``aspect_ratio_min`` and ``aspect_ratio_max``;
    3. uniformity: the standard deviation of grayscale intensity must be at
       least ``uniformity_threshold``.

    Returns:
        None when every check passes. Otherwise a (Polish) message prefixed with
        ``"INVALID_IMAGE_TYPE: "`` for a colour image, or ``"HEURISTIC_FAILED: "``
        for the other checks; the caller dispatches on these prefixes.
    """
    logger.info(f"[RequestID: {request_id}] Uruchamianie heurystycznych testów OOD...")
    width, height = pil_image.size

    # Colourfulness is checked first because it is the most common problem.
    img_np_rgb = np.array(pil_image.convert('RGB'))
    mean_std_across_channels = np.mean(np.std(img_np_rgb, axis=2))
    logger.info(f"[RequestID: {request_id}] Heurystyka: Kolorowość = {mean_std_across_channels:.2f} (próg: {colorfulness_threshold})")
    if mean_std_across_channels > colorfulness_threshold:
        msg = f"Wykryto kolorowy obraz (wskaźnik: {mean_std_across_channels:.2f}). System oczekuje obrazu w skali szarości, typowego dla badań medycznych."
        logger.warning(f"[RequestID: {request_id}] Heurystyka OOD ODRZUCONA: {msg}")
        # Dedicated prefix so the caller can report a wrong image type specifically.
        return f"INVALID_IMAGE_TYPE: {msg}"

    aspect_ratio = width / height
    if not (aspect_ratio_min < aspect_ratio < aspect_ratio_max):
        msg = f"Nietypowe proporcje obrazu: {aspect_ratio:.2f}."
        logger.warning(f"[RequestID: {request_id}] Heurystyka OOD ODRZUCONA: {msg}")
        return f"HEURISTIC_FAILED: {msg}"

    std_dev_intensity = np.std(np.array(pil_image.convert('L')))
    if std_dev_intensity < uniformity_threshold:
        msg = f"Obraz wydaje się zbyt jednolity (np. cały czarny): {std_dev_intensity:.2f}."
        logger.warning(f"[RequestID: {request_id}] Heurystyka OOD ODRZUCONA: {msg}")
        return f"HEURISTIC_FAILED: {msg}"

    logger.info(f"[RequestID: {request_id}] Heurystyczne testy OOD zakończone pomyślnie.")
    return None
# --- FastAPI application ---
# One response item per uploaded file from the /predict/ endpoint. Successful
# predictions fill `birads` and `confidence`; rejected, unreadable or uncertain
# files leave them as None and explain why in `error` / `details`.
# NOTE: field order is kept as-is because it determines the JSON key order.
class PredictionResult(BaseModel):
    # BI-RADS category, or None when no confident prediction was made.
    birads: Optional[int] = None
    # Confidence reported for `birads`, or None.
    confidence: Optional[float] = None
    # Human-readable (Polish) explanation for the user.
    interpretation: str
    # Probability per BI-RADS category keyed "1", "2", ...; empty for rejected files.
    class_probabilities: Dict[str, float]
    # Base64-encoded PNG with the Grad-CAM overlay, when one was produced.
    grad_cam_image_base64: Optional[str] = None
    # Short error label, or None on success.
    error: Optional[str] = None
    # Free-form diagnostic detail.
    details: Optional[str] = None
app = FastAPI(title="BI-RADS Mammography Classification API")


@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts.

    Must be registered as a startup handler: nothing else in this module calls
    initialize_model(), and /predict/ answers 503 while the model is missing.
    """
    logger.info("Rozpoczynanie eventu startup aplikacji FastAPI.")
    initialize_model()
# --- /predict/ endpoint ---
def _ood_rejection_result(ood_error_details: str) -> PredictionResult:
    """Build the response item for an image rejected by run_heuristic_ood_checks."""
    # The heuristics encode the failure kind as a message prefix.
    if ood_error_details.startswith("INVALID_IMAGE_TYPE"):
        error_type = "Invalid Image Type"
        interpretation = "Przesłany plik nie wygląda na obraz mammograficzny. Proszę wgrać odpowiednie zdjęcie mammograficzne."
        details = ood_error_details.replace("INVALID_IMAGE_TYPE: ", "")
    else:  # any other heuristic failure
        error_type = "Heuristic OOD check failed"
        interpretation = "Obraz odrzucony przez wstępne testy. Może mieć nietypowe wymiary lub być zbyt jednolity."
        details = ood_error_details.replace("HEURISTIC_FAILED: ", "")
    return PredictionResult(
        interpretation=interpretation,
        class_probabilities={}, error=error_type,
        details=details,
    )


def _pil_image_to_base64_png(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


@app.post("/predict/", response_model=List[PredictionResult])
async def predict_images(
    files: List[UploadFile] = File(...),
    colorfulness_threshold: float = 2.0,
    uniformity_threshold: float = 10.0,
    aspect_ratio_min: float = 0.4,
    aspect_ratio_max: float = 2.5,
    mc_dropout_samples: int = 25,
    uncertainty_threshold_std: float = 0.11,
):
    """Classify uploaded images into BI-RADS categories.

    Each file is decoded and screened by run_heuristic_ood_checks. The images
    that pass are classified together in one batch with MC Dropout, and
    confident predictions get a Grad-CAM++ overlay. Read errors, heuristic
    rejections and Grad-CAM failures are reported per file instead of failing
    the request.

    Returns:
        One PredictionResult per uploaded file, in upload order.

    Raises:
        HTTPException: 503 if the model is not initialized, 422 if
            ``mc_dropout_samples`` is smaller than 1.
    """
    request_id = os.urandom(8).hex()
    logger.info(f"[RequestID: {request_id}] Otrzymano żądanie /predict/ dla {len(files)} plików.")
    if model_instance is None or transform_pipeline is None:
        raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany.")
    if mc_dropout_samples < 1:
        raise HTTPException(status_code=422, detail="Parametr mc_dropout_samples musi być liczbą całkowitą >= 1.")

    all_results = []  # (upload index, PredictionResult) pairs, sorted at the end
    valid_images_pil = []
    valid_tensors = []
    original_indices = []
    for idx, file in enumerate(files):
        try:
            contents = await file.read()
            image_pil_original = Image.open(io.BytesIO(contents))
            ood_error_details = run_heuristic_ood_checks(
                image_pil_original.copy(), request_id,
                colorfulness_threshold, uniformity_threshold, aspect_ratio_min, aspect_ratio_max,
            )
            if ood_error_details:
                all_results.append((idx, _ood_rejection_result(ood_error_details)))
                continue
            image_rgb = image_pil_original.convert("RGB")
            input_tensor = transform_pipeline(image_rgb).unsqueeze(0).to(DEVICE)
            valid_images_pil.append(image_rgb)
            valid_tensors.append(input_tensor)
            original_indices.append(idx)
        except Exception as e:
            logger.error(f"[RequestID: {request_id}] Błąd podczas odczytu pliku {file.filename}: {e}", exc_info=True)
            result = PredictionResult(
                interpretation="Błąd podczas przetwarzania pliku.", class_probabilities={},
                error="File processing error.", details=str(e),
            )
            all_results.append((idx, result))

    if valid_tensors:
        batch_tensor = torch.cat(valid_tensors, dim=0)
        logger.info(f"[RequestID: {request_id}] Przetwarzanie wsadu {batch_tensor.shape[0]} poprawnych obrazów.")
        # NOTE: inference below runs synchronously on the event-loop thread (no await).
        # That blocks other requests meanwhile, but it also keeps the train/eval mode
        # switches on the shared global model from interleaving between requests.
        mc_results = predict_with_mc_dropout(model_instance, batch_tensor, mc_dropout_samples, uncertainty_threshold_std)
        model_instance.eval()
        target_layers = [model_instance.layer4[-1]]
        # The context-manager form releases the hooks pytorch_grad_cam attaches to
        # target_layers as soon as we are done, not whenever the object is collected.
        with GradCAMPlusPlus(model=model_instance, target_layers=target_layers) as cam_algorithm:
            for i, result_dict in enumerate(mc_results):
                if result_dict.get("error"):
                    continue  # uncertain prediction: no heat map
                try:
                    input_tensor_for_cam = batch_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
                    targets_for_cam = [ClassifierOutputTarget(result_dict["predicted_class_index"])]
                    grayscale_cam = cam_algorithm(input_tensor=input_tensor_for_cam, targets=targets_for_cam)
                    if grayscale_cam is None:
                        continue
                    overlay_image_pil = create_grad_cam_overlay_image(
                        original_pil_image=valid_images_pil[i],
                        grayscale_cam=grayscale_cam[0, :],
                        birads_category=result_dict["birads"],
                    )
                    if overlay_image_pil is not None:
                        result_dict["grad_cam_image_base64"] = _pil_image_to_base64_png(overlay_image_pil)
                except Exception as e:
                    # Grad-CAM is illustrative only: keep the prediction without the overlay.
                    logger.error(f"[RequestID: {request_id}] Błąd Grad-CAM dla pliku o indeksie {original_indices[i]}: {e}", exc_info=True)
        for i, result_dict in enumerate(mc_results):
            # Internal field used only for Grad-CAM targeting; not part of the response.
            result_dict.pop("predicted_class_index", None)
            all_results.append((original_indices[i], PredictionResult(**result_dict)))

    all_results.sort(key=lambda x: x[0])
    return [res for _, res in all_results]
@app.get("/")
async def root():
    """Welcome endpoint that points clients to /predict/."""
    logger.info("Otrzymano żądanie GET na /")
    return {"message": "Witaj w BI-RADS Classification API! Użyj endpointu /predict/ do wysyłania obrazów."}