"""FastAPI service that classifies mammography images into BI-RADS categories.

The service loads a fine-tuned ResNet18 from the Hugging Face Hub, rejects
obviously out-of-distribution uploads with cheap heuristics, estimates
predictive uncertainty with Monte-Carlo Dropout and returns a Grad-CAM++
overlay for every confident prediction.
"""
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import io
import numpy as np
import os
from typing import List, Dict, Any, Optional
import logging
import cv2
import base64
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

# --- Logging configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
MODEL_FILENAME = "best_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Size of the classifier head (BI-RADS 1..5).
NUM_CLASSES = 5

# Dropout flavours that must stay stochastic during MC Dropout sampling.
_DROPOUT_LAYER_TYPES = (
    nn.Dropout,
    nn.Dropout2d,
    nn.Dropout3d,
    nn.AlphaDropout,
    nn.FeatureAlphaDropout,
)

# Global handles for the model and the preprocessing pipeline; populated by
# initialize_model().
model_instance = None
transform_pipeline = None

# Human-readable interpretation for each BI-RADS category (1-based).
interpretations_dict = {
    1: "Wynik negatywny - brak zmian nowotworowych",
    2: "Zmiana łagodna",
    3: "Prawdopodobnie zmiana łagodna - zalecana kontrola",
    4: "Podejrzenie zmiany złośliwej - zalecana biopsja",
    5: "Wysoka podejrzliwość złośliwości - wymagana biopsja"
}


# --- Model initialisation ---
def initialize_model():
    """Download the checkpoint and build the model and transform globals.

    Does nothing when ``model_instance`` is already set. On success the
    module-level ``model_instance`` (in eval mode, on ``DEVICE``) and
    ``transform_pipeline`` are populated.

    Raises:
        RuntimeError: if the checkpoint cannot be downloaded from the Hub
            (the original exception is chained as ``__cause__``).
    """
    global model_instance, transform_pipeline
    if model_instance is not None:
        return

    logger.info("Rozpoczynanie inicjalizacji modelu...")
    try:
        hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
        model_pt_path = hf_hub_download(
            repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token
        )
        logger.info(f"Plik modelu pomyślnie pobrany do: {model_pt_path}")
    except Exception as e:
        logger.error(f"Błąd podczas pobierania modelu z Hugging Face Hub: {e}", exc_info=True)
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Nie można pobrać modelu: {e}") from e

    model_arch = models.resnet18(weights=None)
    num_feats = model_arch.fc.in_features
    # The head must match the checkpoint: Dropout followed by a linear layer.
    model_arch.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_feats, NUM_CLASSES))
    # NOTE(review): torch.load unpickles the downloaded file; consider
    # weights_only=True if the installed torch version supports it.
    model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
    model_arch.to(DEVICE)
    model_arch.eval()
    model_instance = model_arch

    transform_pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    logger.info(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")


# --- Prediction with uncertainty quantification (MC Dropout) ---
def predict_with_mc_dropout(current_model_instance, batch_tensor_on_device, mc_dropout_samples: int, uncertainty_threshold_std: float):
    """Run Monte-Carlo Dropout on a batch and summarise each image.

    Only dropout layers are switched to training mode; every other layer
    (notably BatchNorm) is kept in eval mode so that sampling neither uses
    per-batch statistics nor mutates the running mean/variance stored in the
    model. The training flags of all modules are restored afterwards, even if
    the forward pass raises.

    Args:
        current_model_instance: the ``nn.Module`` to sample from.
        batch_tensor_on_device: input batch, first dimension is the batch
            size; must already live on the model's device.
        mc_dropout_samples: number of stochastic forward passes (>= 1).
        uncertainty_threshold_std: an image is flagged as uncertain when the
            mean (over classes) of the per-class standard deviation across
            samples exceeds this value.

    Returns:
        A list with one dict per image. Confident results carry ``birads``,
        ``confidence`` and ``predicted_class_index``; uncertain results carry
        ``birads=None``, ``confidence=None`` and a non-empty ``error``.

    Raises:
        ValueError: if ``mc_dropout_samples`` is smaller than 1.
    """
    if mc_dropout_samples < 1:
        raise ValueError("mc_dropout_samples must be at least 1")

    batch_size = batch_tensor_on_device.shape[0]
    logger.info(f"Performing MC Dropout on a batch of size {batch_size} with {mc_dropout_samples} samples.")

    # Remember every module's mode so it can be restored exactly.
    saved_modes = [(module, module.training) for module in current_model_instance.modules()]
    current_model_instance.eval()
    for module in current_model_instance.modules():
        if isinstance(module, _DROPOUT_LAYER_TYPES):
            module.train()

    sampled_probs = []
    try:
        with torch.no_grad():
            for _ in range(mc_dropout_samples):
                output = current_model_instance(batch_tensor_on_device)
                probs_tensor = torch.nn.functional.softmax(output, dim=1)
                sampled_probs.append(probs_tensor.cpu().numpy())
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # Shape: (batch, samples, classes); float64 keeps the statistics stable.
    all_probs_batch = np.stack(sampled_probs, axis=1).astype(np.float64)
    mean_probabilities_batch = np.mean(all_probs_batch, axis=1)
    std_dev_probabilities_batch = np.std(all_probs_batch, axis=1)

    results = []
    for i in range(batch_size):
        mean_probabilities = mean_probabilities_batch[i]
        std_dev_probabilities = std_dev_probabilities_batch[i]

        predicted_class_index = int(np.argmax(mean_probabilities))
        # NOTE(review): confidence is the maximum probability of the predicted
        # class over the MC samples, not the mean - confirm this is intended.
        confidence_in_predicted_class = float(np.max(all_probs_batch[i, :, predicted_class_index]))
        uncertainty_metric = float(np.mean(std_dev_probabilities))
        is_uncertain = bool(uncertainty_metric > uncertainty_threshold_std)

        logger.info(
            f"MC Dropout Results for image {i}: Predicted Index: {int(predicted_class_index)}, "
            f"Confidence (MaxProb): {confidence_in_predicted_class:.4f}, "
            f"Uncertainty (avg_std): {uncertainty_metric:.4f}, Is Uncertain: {is_uncertain}"
        )

        # Class indices are 0-based, BI-RADS categories are 1-based.
        birads_category_if_confident = predicted_class_index + 1
        class_probabilities = {
            str(j + 1): float(prob) for j, prob in enumerate(mean_probabilities)
        }

        if is_uncertain:
            result = {
                "birads": None,
                "confidence": None,
                "interpretation": f"Model jest niepewny co do tego obrazu (niepewność: {uncertainty_metric:.4f}). Sprawdź jakość i typ obrazu.",
                "class_probabilities": class_probabilities,
                "grad_cam_image_base64": None,
                "error": "High prediction uncertainty",
                "details": f"Uncertainty metric ({uncertainty_metric:.4f}) przekroczyła próg ({uncertainty_threshold_std})."
            }
        else:
            result = {
                "birads": birads_category_if_confident,
                "confidence": confidence_in_predicted_class,
                "interpretation": interpretations_dict.get(birads_category_if_confident, "Nieznana klasyfikacja"),
                "class_probabilities": class_probabilities,
                "grad_cam_image_base64": None,
                "error": None,
                "details": f"Uncertainty metric ({uncertainty_metric:.4f}) jest w granicach progu ({uncertainty_threshold_std}).",
                # Consumed (and removed) by the endpoint when building Grad-CAM.
                "predicted_class_index": predicted_class_index
            }
        results.append(result)
    return results


# --- Building the image with the Grad-CAM map overlaid ---
def create_grad_cam_overlay_image(original_pil_image: Image.Image, grayscale_cam: np.ndarray, birads_category: int, transparency: float = 0.5) -> Optional[Image.Image]:
    """Blend a thresholded Grad-CAM map onto the original image.

    The CAM is resized to the image, min-max normalised, thresholded at 0.7,
    cleaned with a 5x5 morphological opening and then used as a per-pixel
    alpha for a solid colour chosen by BI-RADS category.

    Args:
        original_pil_image: image to draw on (converted to RGB internally).
        grayscale_cam: 2-D activation map for this image.
        birads_category: 1..5, selects the highlight colour; unknown values
            fall back to grey.
        transparency: maximum opacity of the highlight.

    Returns:
        The blended RGB image, or ``None`` if anything fails (the error is
        logged).
    """
    try:
        img_np = np.array(original_pil_image.convert('RGB')).astype(np.float32) / 255.0
        # cv2.resize takes (width, height).
        cam_resized = cv2.resize(grayscale_cam, (img_np.shape[1], img_np.shape[0]))
        cam_min = np.min(cam_resized)
        # The epsilon guards against a constant map (max == min).
        cam_normalized = (cam_resized - cam_min) / (np.max(cam_resized) - cam_min + 1e-8)

        # Keep only the strongest activations.
        threshold = 0.7
        cam_normalized[cam_normalized < threshold] = 0

        # Morphological opening removes small isolated specks.
        kernel = np.ones((5, 5), np.uint8)
        cam_cleaned = cv2.morphologyEx(cam_normalized, cv2.MORPH_OPEN, kernel)

        birads_colors_rgb = {
            1: (0.1, 0.7, 0.1),
            2: (0.53, 0.81, 0.92),
            3: (1.0, 0.9, 0.0),
            4: (1.0, 0.5, 0.0),
            5: (0.9, 0.1, 0.1)
        }
        chosen_color = np.array(birads_colors_rgb.get(birads_category, (0.5, 0.5, 0.5)))

        # Solid colour layer; the RGB triple broadcasts over all pixels.
        color_overlay_np = np.zeros_like(img_np)
        color_overlay_np[...] = chosen_color

        alpha = cam_cleaned * transparency
        alpha_expanded = alpha[..., np.newaxis]
        highlighted_image_np = img_np * (1 - alpha_expanded) + color_overlay_np * alpha_expanded
        highlighted_image_np = np.clip(highlighted_image_np, 0, 1)

        final_image_np = (highlighted_image_np * 255).astype(np.uint8)
        return Image.fromarray(final_image_np)
    except Exception as e:
        logger.error(f"Błąd podczas tworzenia obrazu Grad-CAM overlay: {e}", exc_info=True)
        return None


# --- Heuristic out-of-distribution (OOD) checks ---
def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str, colorfulness_threshold: float, uniformity_threshold: float, aspect_ratio_min: float, aspect_ratio_max: float) -> Optional[str]:
    """Run cheap OOD heuristics on an uploaded image.

    Checks, in order: colourfulness (mean per-pixel standard deviation across
    the RGB channels), aspect ratio (must lie strictly between the two
    bounds) and uniformity (standard deviation of the greyscale intensities).

    Returns:
        ``None`` when all checks pass; otherwise a message prefixed with
        ``"INVALID_IMAGE_TYPE: "`` (colour image) or ``"HEURISTIC_FAILED: "``
        (aspect ratio / uniformity) that the caller uses to pick the error
        type.
    """
    logger.info(f"[RequestID: {request_id}] Uruchamianie heurystycznych testów OOD...")
    width, height = pil_image.size

    # Colourfulness is checked first because it is the most common problem.
    img_rgb_for_color_check = pil_image.convert('RGB')
    img_np_rgb = np.array(img_rgb_for_color_check)
    mean_std_across_channels = np.mean(np.std(img_np_rgb, axis=2))
    logger.info(f"[RequestID: {request_id}] Heurystyka: Kolorowość = {mean_std_across_channels:.2f} (próg: {colorfulness_threshold})")
    if mean_std_across_channels > colorfulness_threshold:
        msg = f"Wykryto kolorowy obraz (wskaźnik: {mean_std_across_channels:.2f}). System oczekuje obrazu w skali szarości, typowego dla badań medycznych."
        logger.warning(f"[RequestID: {request_id}] Heurystyka OOD ODRZUCONA: {msg}")
        # Dedicated prefix so the endpoint can report a more specific error.
        return f"INVALID_IMAGE_TYPE: {msg}"

    aspect_ratio = width / height
    if not (aspect_ratio_min < aspect_ratio < aspect_ratio_max):
        msg = f"Nietypowe proporcje obrazu: {aspect_ratio:.2f}."
        return f"HEURISTIC_FAILED: {msg}"

    gray_image = pil_image.convert('L')
    std_dev_intensity = np.std(np.array(gray_image))
    if std_dev_intensity < uniformity_threshold:
        msg = f"Obraz wydaje się zbyt jednolity (np. cały czarny): {std_dev_intensity:.2f}."
        return f"HEURISTIC_FAILED: {msg}"

    logger.info(f"[RequestID: {request_id}] Heurystyczne testy OOD zakończone pomyślnie.")
    return None


# --- FastAPI application ---
class PredictionResult(BaseModel):
    """Response item for a single uploaded file."""

    # BI-RADS category (1..5); None when rejected or uncertain.
    birads: Optional[int] = None
    confidence: Optional[float] = None
    interpretation: str
    # Keys are the 1-based BI-RADS categories as strings.
    class_probabilities: Dict[str, float]
    # PNG overlay, base64-encoded; only set for confident predictions.
    grad_cam_image_base64: Optional[str] = None
    error: Optional[str] = None
    details: Optional[str] = None


app = FastAPI(title="BI-RADS Mammography Classification API")


@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts."""
    logger.info("Rozpoczynanie eventu startup aplikacji FastAPI.")
    initialize_model()


# --- /predict/ endpoint ---
@app.post("/predict/", response_model=List[PredictionResult])
async def predict_images(
    files: List[UploadFile] = File(...),
    colorfulness_threshold: float = 2.0,
    uniformity_threshold: float = 10.0,
    aspect_ratio_min: float = 0.4,
    aspect_ratio_max: float = 2.5,
    mc_dropout_samples: int = 25,
    uncertainty_threshold_std: float = 0.11
):
    """Classify every uploaded image and return one result per file.

    Files that cannot be decoded or that fail the OOD heuristics get an error
    result; the remaining images are processed as one batch with MC Dropout
    and, when the prediction is confident, annotated with a Grad-CAM++
    overlay. Results are returned in upload order.

    Raises:
        HTTPException: 503 when the model is not initialised, 422 when
            ``mc_dropout_samples`` is smaller than 1.
    """
    request_id = os.urandom(8).hex()
    logger.info(f"[RequestID: {request_id}] Otrzymano żądanie /predict/ dla {len(files)} plików.")

    if model_instance is None or transform_pipeline is None:
        raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany.")
    # Reject up front: zero or negative sample counts cannot be evaluated.
    if mc_dropout_samples < 1:
        raise HTTPException(
            status_code=422,
            detail="Parametr mc_dropout_samples musi być większy lub równy 1."
        )

    all_results = []        # (upload index, PredictionResult) pairs
    valid_images_pil = []
    valid_tensors = []
    original_indices = []   # upload index of each entry in valid_tensors

    for idx, file in enumerate(files):
        try:
            contents = await file.read()
            image_pil_original = Image.open(io.BytesIO(contents))

            ood_error_details = run_heuristic_ood_checks(
                image_pil_original.copy(), request_id,
                colorfulness_threshold, uniformity_threshold,
                aspect_ratio_min, aspect_ratio_max
            )
            if ood_error_details:
                # The message prefix tells us which kind of rejection this is.
                if ood_error_details.startswith("INVALID_IMAGE_TYPE"):
                    error_type = "Invalid Image Type"
                    # NOTE(review): the text asks for a "USG" image although
                    # this API targets mammograms - confirm the wording.
                    interpretation = "Przesłany plik nie wygląda na obraz mammograficzny. Proszę wgrać odpowiednie zdjęcie USG."
                    details = ood_error_details.replace("INVALID_IMAGE_TYPE: ", "")
                else:
                    # Remaining heuristic failures.
                    error_type = "Heuristic OOD check failed"
                    interpretation = "Obraz odrzucony przez wstępne testy. Może mieć nietypowe wymiary lub być zbyt jednolity."
                    details = ood_error_details.replace("HEURISTIC_FAILED: ", "")

                result = PredictionResult(
                    interpretation=interpretation,
                    class_probabilities={},
                    error=error_type,
                    details=details
                )
                all_results.append((idx, result))
                continue

            image_rgb = image_pil_original.convert("RGB")
            input_tensor = transform_pipeline(image_rgb).unsqueeze(0).to(DEVICE)

            valid_images_pil.append(image_rgb)
            valid_tensors.append(input_tensor)
            original_indices.append(idx)
        except Exception as e:
            logger.error(f"[RequestID: {request_id}] Błąd podczas odczytu pliku {file.filename}: {e}", exc_info=True)
            result = PredictionResult(
                interpretation="Błąd podczas przetwarzania pliku.",
                class_probabilities={},
                error="File processing error.",
                details=str(e)
            )
            all_results.append((idx, result))

    if valid_tensors:
        batch_tensor = torch.cat(valid_tensors, dim=0)
        logger.info(f"[RequestID: {request_id}] Przetwarzanie wsadu {batch_tensor.shape[0]} poprawnych obrazów.")

        mc_results = predict_with_mc_dropout(model_instance, batch_tensor, mc_dropout_samples, uncertainty_threshold_std)

        # Grad-CAM must run deterministically, i.e. with dropout disabled.
        model_instance.eval()
        target_layers = [model_instance.layer4[-1]]
        # The context manager removes the hooks registered on the target
        # layer as soon as the block exits, instead of relying on __del__.
        with GradCAMPlusPlus(model=model_instance, target_layers=target_layers) as cam_algorithm:
            for i, result_dict in enumerate(mc_results):
                if not result_dict.get("error"):
                    birads_cat = result_dict["birads"]
                    pred_idx = result_dict["predicted_class_index"]

                    input_tensor_for_cam = batch_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
                    targets_for_cam = [ClassifierOutputTarget(pred_idx)]
                    grayscale_cam = cam_algorithm(input_tensor=input_tensor_for_cam, targets=targets_for_cam)

                    if grayscale_cam is not None:
                        overlay_image_pil = create_grad_cam_overlay_image(
                            original_pil_image=valid_images_pil[i],
                            grayscale_cam=grayscale_cam[0, :],
                            birads_category=birads_cat
                        )
                        if overlay_image_pil:
                            buffered = io.BytesIO()
                            overlay_image_pil.save(buffered, format="PNG")
                            result_dict["grad_cam_image_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')

                # Internal helper key, not part of the response model.
                result_dict.pop("predicted_class_index", None)
                all_results.append((original_indices[i], PredictionResult(**result_dict)))

    # Restore upload order: rejected and accepted files were collected apart.
    all_results.sort(key=lambda x: x[0])
    final_results = [res for _, res in all_results]
    return final_results


@app.get("/")
async def root():
    """Return a short welcome message pointing at the /predict/ endpoint."""
    logger.info("Otrzymano żądanie GET na /")
    return {"message": "Witaj w BI-RADS Classification API! Użyj endpointu /predict/ do wysyłania obrazów."}