from fastapi import FastAPI, UploadFile, File, HTTPException, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import io
import uuid
from datetime import datetime, timedelta
import base64
import pydicom
import os

# Configuration
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MB
PORT = 7860

app = FastAPI(
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
    version="5.0.0"
)
# Rate limiter setup
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
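# The listing attaches the limiter to app.state but never registers slowapi's exception
# handler, so a RateLimitExceeded raised by any future @limiter.limit(...) route would not
# become a 429 response. A minimal sketch of the usual slowapi wiring (assumed, not shown
# in this listing; per-route limits such as "10/minute" would still need to be chosen):
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)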
# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model configuration
layer_name = 'conv5_block16_concat'
# Note: the classifier head below outputs 14 scores, so only the first 14 labels are
# ever reported; 'No Finding' (index 14) is listed but never predicted.
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
    'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
]
def build_model():
    base_model = DenseNet121(
        weights=None,
        include_top=False,
        input_shape=(None, None, 3)
    )
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    predictions = Dense(14, activation='sigmoid')(x)
    return Model(inputs=base_model.input, outputs=predictions)
def load_model_with_fallback():
    try:
        model = build_model()
        model.load_weights('pretrained_model.h5')
        return model
    except Exception as e:
        print(f"Primary loading failed: {e}")
        try:
            model = load_model('Densenet.h5', compile=False)
            return model
        except Exception as e:
            print(f"Fallback loading failed: {e}")
            raise RuntimeError("All model loading strategies failed")
# Load model
try:
    model = load_model_with_fallback()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Model loading failed: {e}")
    raise
def generate_gradcam(img):
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate the score of the top predicted class only (the original code took
        # the gradient of the whole prediction vector and left class_idx unused).
        class_score = tf.gather(predictions[0], class_idx)
    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)
    heatmap = np.maximum(cam, 0)
    heatmap /= (np.max(heatmap) + 1e-8)  # avoid division by zero on an all-zero map
    heatmap_img = plt.cm.jet(heatmap)[..., :3]
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_img * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)
def process_predictions(predictions):
    decoded = []
    for pred in predictions:
        top_indices = np.argsort(pred)[::-1][:len(class_names)]
        decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
    return decoded
def dump_file_sample(file_bytes, filename="debug_file_sample.bin"):
    """Save a sample of file bytes for debugging."""
    try:
        sample_size = min(512, len(file_bytes))
        with open(filename, "wb") as f:
            f.write(file_bytes[:sample_size])
        print(f"Saved {sample_size} bytes sample to {filename}")
        # Print the first few bytes as hex
        hex_sample = ' '.join([f'{b:02x}' for b in file_bytes[:16]])
        print(f"First 16 bytes: {hex_sample}")
    except Exception as e:
        print(f"Failed to save debug sample: {e}")
def preprocess_dicom(file_bytes):
    """Process DICOM format images for the model with robust error handling."""
    # Create unique temporary filenames to avoid conflicts
    import tempfile
    temp_dir = tempfile.gettempdir()
    uid = str(uuid.uuid4())[:8]
    temp_file = os.path.join(temp_dir, f"temp_dicom_{uid}.dcm")
    temp_img_file = os.path.join(temp_dir, f"temp_dicom_img_{uid}.png")
    try:
        print(f"Processing DICOM file of size {len(file_bytes)} bytes")
        # Write bytes to a temporary file
        with open(temp_file, "wb") as f:
            f.write(file_bytes)
        # Read the DICOM file with force=True to tolerate missing or invalid headers
        try:
            # defer_size is left at None, so all data elements are read immediately
            dicom_data = pydicom.dcmread(temp_file, force=True, defer_size=None)
            # Check transfer syntax
            if hasattr(dicom_data, 'file_meta') and hasattr(dicom_data.file_meta, 'TransferSyntaxUID'):
                ts_uid = str(dicom_data.file_meta.TransferSyntaxUID)
                print(f"DICOM file read successfully. Transfer syntax: {ts_uid}")
            else:
                print("DICOM file read but no transfer syntax found - assuming default Implicit VR Little Endian")
        except Exception as e:
            print(f"Error reading DICOM file: {e}")
            raise ValueError(f"Failed to read DICOM file: {e}")
        # Verify pixel data exists
        if not hasattr(dicom_data, 'PixelData'):
            print("PixelData attribute missing")
            # Check for alternate pixel data representations
            alt_pixel_attrs = ['FloatPixelData', 'DoubleFloatPixelData']
            has_pixel_data = False
            for attr in alt_pixel_attrs:
                if hasattr(dicom_data, attr):
                    has_pixel_data = True
                    print(f"Found alternate pixel data: {attr}")
                    break
            if not has_pixel_data:
                raise ValueError("DICOM file does not contain any pixel data")
        # Print DICOM image properties for diagnosis
        print("DICOM properties:")
        for attr in ['BitsAllocated', 'BitsStored', 'HighBit', 'SamplesPerPixel', 'Rows', 'Columns']:
            if hasattr(dicom_data, attr):
                print(f"  {attr}: {getattr(dicom_data, attr)}")
            else:
                print(f"  {attr}: Not specified")
        # Try multiple methods to extract the pixel data
        img = None
        methods_tried = []
        # Method 1: Direct pixel_array access with exception handling
        if img is None:
            try:
                methods_tried.append("Direct pixel_array")
                img = dicom_data.pixel_array
                if img.size > 0:
                    print(f"Successfully extracted pixel data via pixel_array: shape={img.shape}, dtype={img.dtype}")
                else:
                    img = None
                    raise ValueError("Extracted pixel array is empty")
            except Exception as e:
                print(f"Method 1 (direct pixel_array) failed: {e}")
                img = None
        # Method 2: Save and reload through an intermediate file for compressed images.
        # Note: save_as() writes DICOM bytes regardless of the .png extension, so this
        # fallback only succeeds when the downstream reader can still decode that copy.
        if img is None:
            try:
                methods_tried.append("PNG intermediate")
                print("Trying PNG intermediate method...")
                dicom_data.save_as(temp_img_file)
                # Try IMREAD_UNCHANGED first to preserve bit depth
                img = cv2.imread(temp_img_file, cv2.IMREAD_UNCHANGED)
                if img is None or img.size == 0:
                    # Fall back to IMREAD_GRAYSCALE
                    img = cv2.imread(temp_img_file, cv2.IMREAD_GRAYSCALE)
                if img is not None and img.size > 0:
                    print(f"Successfully extracted pixel data via PNG: shape={img.shape}, dtype={img.dtype}")
                else:
                    img = None
                    raise ValueError("PNG conversion resulted in empty image")
            except Exception as e:
                print(f"Method 2 (PNG intermediate) failed: {e}")
                img = None
        # Method 3: PIL intermediate (same caveat as Method 2)
        if img is None:
            try:
                methods_tried.append("PIL intermediate")
                print("Trying PIL intermediate method...")
                dicom_data.save_as(temp_img_file)
                pil_img = Image.open(temp_img_file)
                img = np.array(pil_img)
                if img is not None and img.size > 0:
                    print(f"Successfully extracted pixel data via PIL: shape={img.shape}, dtype={img.dtype}")
                else:
                    img = None
                    raise ValueError("PIL conversion resulted in empty image")
            except Exception as e:
                print(f"Method 3 (PIL intermediate) failed: {e}")
                img = None
        # If all methods failed, return a diagnostic image
        if img is None:
            print(f"All pixel data extraction methods failed: {', '.join(methods_tried)}")
            # Create a mid-gray diagnostic image with text describing the error
            img_with_text = np.ones((540, 540, 3), dtype=np.uint8) * 128
            error_text = "Failed to extract DICOM pixel data"
            cv2.putText(img_with_text, error_text, (50, 270), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            print("Returning diagnostic image due to extraction failure")
            return img_with_text
        # DICOM images are often 16-bit or higher; normalize to 8-bit for visualization
        print(f"Original image: shape={img.shape}, dtype={img.dtype}, min={np.min(img)}, max={np.max(img)}")
        # 1. Normalize pixel values to the 8-bit range
        if img.dtype != np.uint8:
            try:
                # Calculate data range for proper normalization
                img_min = float(np.min(img))
                img_max = float(np.max(img))
                # Only normalize if we have a non-zero range
                if img_max > img_min:
                    # Convert to float32 first for better precision
                    img = img.astype(np.float32)
                    # Scale to range [0, 255]
                    img = 255.0 * (img - img_min) / (img_max - img_min)
                    # Convert to uint8
                    img = img.astype(np.uint8)
                    print(f"Normalized to 8-bit: new range=[{np.min(img)}, {np.max(img)}]")
                else:
                    # Handle uniform pixel values
                    img = np.full(img.shape, 128, dtype=np.uint8)
                    print("Image has uniform pixel values, using mid-gray")
            except Exception as e:
                print(f"Error during normalization: {e}")
                # Create a valid grayscale image in case of error
                img = np.full(img.shape if len(img.shape) >= 2 else (540, 540), 128, dtype=np.uint8)
        # 2. Handle color conversion based on image dimensions
        try:
            # Check image dimensions
            if len(img.shape) == 2:
                # Single-channel (grayscale) image - stack into 3 channels
                print("Converting grayscale to RGB using manual conversion")
                h, w = img.shape
                rgb_img = np.zeros((h, w, 3), dtype=np.uint8)
                rgb_img[:, :, 0] = img  # R
                rgb_img[:, :, 1] = img  # G
                rgb_img[:, :, 2] = img  # B
                img = rgb_img
            elif len(img.shape) == 3:
                if img.shape[2] == 1:
                    # Single-channel image stored in a 3D array
                    print("Converting single-channel 3D array to RGB")
                    h, w, _ = img.shape
                    img_2d = img.reshape(h, w)
                    rgb_img = np.zeros((h, w, 3), dtype=np.uint8)
                    rgb_img[:, :, 0] = img_2d
                    rgb_img[:, :, 1] = img_2d
                    rgb_img[:, :, 2] = img_2d
                    img = rgb_img
                elif img.shape[2] == 3:
                    # Already 3-channel; no conversion needed
                    print("Image already has 3 channels, ensuring RGB color space")
                elif img.shape[2] == 4:
                    # RGBA image - drop the alpha channel
                    print("Converting RGBA to RGB by removing alpha channel")
                    img = img[:, :, :3]
                else:
                    # Unusual number of channels: convert to grayscale, then RGB
                    print(f"Unusual channel count ({img.shape[2]}), converting to grayscale then RGB")
                    if np.max(img) > 0:  # Skip averaging for an all-zero image
                        # Average across channels
                        gray = np.mean(img, axis=2).astype(np.uint8)
                        h, w = gray.shape
                        rgb_img = np.zeros((h, w, 3), dtype=np.uint8)
                        rgb_img[:, :, 0] = gray
                        rgb_img[:, :, 1] = gray
                        rgb_img[:, :, 2] = gray
                        img = rgb_img
                    else:
                        # Create a valid RGB image if all pixels are zero
                        h, w = img.shape[:2]
                        img = np.full((h, w, 3), 128, dtype=np.uint8)
            else:
                # Invalid dimensions, create a fallback image
                print(f"Invalid image dimensions: {img.shape}")
                img = np.full((540, 540, 3), 128, dtype=np.uint8)
                cv2.putText(img, "Invalid image dimensions", (50, 270),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        except Exception as e:
            print(f"Error during color conversion: {e}")
            # Create a valid RGB image in case of error
            img = np.full((540, 540, 3), 128, dtype=np.uint8)
            cv2.putText(img, f"Error: {str(e)[:30]}", (50, 270),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        # 3. Final validation and cleanup
        print(f"After color conversion: shape={img.shape}, dtype={img.dtype}")
        if img is None or img.size == 0 or len(img.shape) < 2:
            raise ValueError("Image processing resulted in invalid image")
        # Resize for model input
        print(f"Final image shape before resize: {img.shape}")
        img = cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
        print(f"Resized image shape: {img.shape}")
        return img
    except Exception as e:
        print(f"DICOM processing failed: {e}")
        raise
    finally:
        # Clean up temporary files
        for temp_file_path in [temp_file, temp_img_file]:
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                except Exception as e:
                    print(f"Failed to remove temporary file {temp_file_path}: {e}")
def preprocess_image(file_bytes, content_type=None):
    """Process images for the model, handling both DICOM and standard formats."""
    print(f"Preprocessing image with content type: {content_type}, size: {len(file_bytes)} bytes")
    # Save a debug sample of the file bytes
    dump_file_sample(file_bytes)
    # Check whether the file is likely a DICOM file
    is_likely_dicom = False
    # Check content type for DICOM indicators
    if content_type and ('dicom' in content_type.lower() or
                         content_type.lower() == 'application/octet-stream' or
                         content_type.lower() == 'application/dicom'):
        is_likely_dicom = True
    # Also check the file signature (DICOM files usually carry "DICM" at byte offset 128)
    if len(file_bytes) > 132:
        dicom_signature = file_bytes[128:132]
        if dicom_signature == b'DICM':
            is_likely_dicom = True
            print("DICOM signature detected in file")
    if is_likely_dicom:
        try:
            return preprocess_dicom(file_bytes)
        except Exception as e:
            print(f"DICOM processing error: {e}")
            # Fall back to standard image processing if DICOM processing fails
            print("Falling back to standard image processing")
    # Process as a standard image format
    try:
        print("Processing as standard image format")
        img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
        # Validate that the image was successfully decoded
        if img is None or img.size == 0:
            print("Standard image decoding failed - creating fallback image")
            # Create a fallback image for debugging
            img = np.ones((540, 540, 3), dtype=np.uint8) * 128
            # Add a diagnostic cross pattern
            cv2.line(img, (0, 0), (540, 540), (200, 100, 100), 10)
            cv2.line(img, (540, 0), (0, 540), (100, 200, 100), 10)
            return img
        # Valid image: convert color space and resize
        print(f"Standard image decoded successfully: shape={img.shape}, dtype={img.dtype}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
    except Exception as e:
        print(f"Standard image processing error: {e}")
        # Create a fallback image as a last resort
        img = np.ones((540, 540, 3), dtype=np.uint8) * 128
        cv2.putText(img, "Error: " + str(e)[:30], (50, 270), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        return img
# Route path assumed for illustration; no route decorator appears in the listing.
@app.post("/analyze")
async def analyze_image(
    request: Request,
    file: UploadFile = File(...)
):
    # Accept both standard image formats and DICOM files
    content_type = file.content_type or ""
    if not (content_type.startswith('image/') or
            'dicom' in content_type.lower() or
            content_type == 'application/octet-stream'):
        raise HTTPException(400, "Only image or DICOM files accepted")
    if file.size and file.size > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large (max {MAX_FILE_SIZE // 1024 // 1024}MB)")
    try:
        contents = await file.read()
        img = preprocess_image(contents, file.content_type)
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)
        predictions = model.predict(img_array)
        decoded = process_predictions(predictions)
        heatmap = generate_gradcam(img)
        # Convert the heatmap to base64 instead of saving it to a file
        img_byte_arr = io.BytesIO()
        heatmap.save(img_byte_arr, format='PNG')
        heatmap_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        return {
            "predictions": decoded[0],
            "heatmap_image": heatmap_base64,
            "heatmap_format": "base64 encoded PNG"
        }
    except Exception as e:
        error_message = str(e)
        print(f"Analysis failed with error: {error_message}")
        # Return a more detailed error message
        if "empty()" in error_message and "cvtColor" in error_message:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to process image: The image data is empty or corrupt. Please check your DICOM file format. Original error: {error_message}"
            )
        elif "DICOM" in error_message:
            raise HTTPException(
                status_code=422,
                detail=f"DICOM processing error: {error_message}. Please ensure your DICOM file contains valid pixel data."
            )
        else:
            raise HTTPException(500, f"Analysis failed: {error_message}")
# Route path assumed for illustration; no route decorator appears in the listing.
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "features": {
            "dicom_support": True
        }
    }
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)
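# Example requests against a local run. The "/analyze" and "/health" paths are the
# assumed routes added above (not confirmed by the listing), and chest_xray.png is a
# placeholder file name:
#
#   curl -X POST -F "file=@chest_xray.png" http://localhost:7860/analyze
#   curl http://localhost:7860/health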