Spaces:
Sleeping
Sleeping
| import io | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
# --- 1. Initialize FastAPI App ---
# Single module-level application object; the CORS middleware is attached to it
# further down in this file.
app = FastAPI(
    version="1.0.0",
    title="X-Ray Denoising API",
    description="An API to classify noise and denoise X-ray images.",
)
# --- 2. Set up CORS ---
# Browsers send the Origin header as "scheme://host[:port]" with no path and no
# trailing slash, and Starlette's CORSMiddleware checks membership in
# allow_origins by exact string comparison. Entries must therefore not end in
# "/"; otherwise they never match any request.
origins = [
    "http://localhost:5173",
    "http://localhost:3000",
    "https://santy171710-classifier.hf.space",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Credentials are allowed only because the origins above are explicit
    # (credentialed CORS cannot be combined with a "*" origin).
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- 3. Load AI Models ---
def load_all_models():
    """Load the noise classifier and every per-noise-type denoiser from ``models/``.

    Returns:
        ``(classifier_model, denoiser_models)``, where ``denoiser_models`` maps
        'gaussian', 'poisson', 'salt_pepper' and 'speckle' to their loaded
        Keras models. If any load (or the success message) raises, the error
        is printed and ``(None, None)`` is returned instead of propagating.
    """
    print("Loading AI models...")
    # Insertion order matters: models are loaded in exactly this sequence.
    denoiser_paths = {
        'gaussian': 'models/gaussian_denoiser_final_model.keras',
        'poisson': 'models/poisson_denoising.keras',
        'salt_pepper': 'models/salt_pepper_denoiser.keras',
        'speckle': 'models/speckle_denoising_final_model.keras',
    }
    try:
        load = tf.keras.models.load_model
        classifier = load('models/xray_noise_classifier_resnet50v2.keras')
        denoisers = {label: load(path) for label, path in denoiser_paths.items()}
        print("β Models loaded successfully!")
        return classifier, denoisers
    except Exception as exc:
        # Best-effort startup: report the failure and let callers see None.
        print(f"β Error loading models: {exc}")
        return None, None
# Loaded once at import time; both are None when any model failed to load.
CLASSIFIER, DENOISERS = load_all_models()
# Labels for the classifier's output vector: the argmax index i of a prediction
# is mapped to NOISE_CLASSES[i], and that label is then used as the key into
# DENOISERS. NOTE(review): this order must match the label order the classifier
# was trained with — confirm against the training code.
NOISE_CLASSES = [
    'gaussian',
    'poisson',
    'salt_pepper',
    'speckle',
]
# --- 4. Define Helper Functions ---
def preprocess_image(image_bytes: bytes, target_size=(224, 224)):
    """Decode raw image bytes into a normalized single-image batch for the classifier.

    The image is converted to 3-channel RGB, resized to ``target_size``
    (224x224 by default, the size the ResNet50-based classifier expects), and
    scaled from [0, 255] to [0, 1].

    Args:
        image_bytes: Encoded image file contents (any format Pillow can open).
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        A float ``np.ndarray`` of shape ``(1, height, width, 3)`` with values in [0, 1].

    Raises:
        HTTPException: status 400 if the bytes cannot be decoded or converted.
    """
    try:
        # The context manager releases Pillow's handle on the source buffer;
        # convert() loads pixel data, so `rgb` remains valid afterwards.
        with Image.open(io.BytesIO(image_bytes)) as img:
            # Force 3 channels so grayscale/RGBA/palette uploads all match the
            # classifier's channel count.
            rgb = img.convert('RGB').resize(target_size)
        pixels = np.array(rgb)
        # Prepend the batch axis: (H, W, 3) -> (1, H, W, 3).
        return pixels[np.newaxis, ...] / 255.0
    except Exception as e:
        # Chain the original error so server-side tracebacks keep the root cause.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image file. Could not preprocess. Error: {e}",
        ) from e
def postprocess_output(denoised_array: np.ndarray):
    """Encode a denoiser's output array as an in-memory PNG file.

    Args:
        denoised_array: Model output whose pixel values are nominally in
            [0, 1]. NOTE(review): assumed to be a single image with size-1
            batch (and, for grayscale, channel) axes, e.g. ``(1, H, W, 1)`` —
            confirm against the denoiser models.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that contains the PNG bytes.
    """
    # Drop every size-1 axis so Pillow receives (H, W) for grayscale or
    # (H, W, 3) for RGB.
    pixels = np.squeeze(denoised_array)
    # Model outputs are not guaranteed to stay inside [0, 1]. Casting an
    # out-of-range float straight to uint8 wraps or is undefined, which shows
    # up as speckles of inverted brightness. Clip first so such values
    # saturate at black/white, and round (rather than truncate) to avoid a
    # systematic darkening bias.
    pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)
    image = Image.fromarray(pixels)
    buffer = io.BytesIO()
    image.save(buffer, 'PNG')
    # Rewind so the response streams from the start of the PNG.
    buffer.seek(0)
    return buffer
# --- 5. Create the API Endpoint ---
def _preprocess_for_denoiser(image_bytes: bytes, size=(256, 256)):
    """Decode image bytes into a ``(1, H, W, 1)`` grayscale batch scaled to [0, 1].

    NOTE(review): assumes every denoiser expects 256x256 single-channel input,
    as the endpoint has always assumed — confirm against the saved models.
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        gray = img.convert('L').resize(size)
    # Add the batch axis in front and the channel axis at the end.
    return np.array(gray)[np.newaxis, ..., np.newaxis] / 255.0


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# function, so FastAPI does not serve it as written. Confirm the intended path
# and register it.
async def denoise_image(image: UploadFile = File(...)):
    """Classify the noise in an uploaded X-ray and return the denoised image.

    Pipeline:
      1. Run CLASSIFIER on a 224x224 RGB copy of the upload.
      2. Pick the denoiser whose DENOISERS key equals the predicted
         NOISE_CLASSES label.
      3. Run that denoiser on a 256x256 grayscale copy.
      4. Stream the result back as a PNG.

    Raises:
        HTTPException: 503 if the models failed to load at startup; 400 if the
            upload cannot be decoded for the classifier; 500 for any other
            processing failure.
    """
    # load_all_models signals failure with None; test for that explicitly
    # instead of relying on a Keras model's truthiness.
    if CLASSIFIER is None or not DENOISERS:
        raise HTTPException(status_code=503, detail="Models are not available on the server.")
    image_bytes = await image.read()
    # NOTE(review): Keras predict() and Pillow decoding are synchronous and run
    # on the event-loop thread here, blocking other requests while they execute.
    try:
        classifier_input = preprocess_image(image_bytes)
        prediction = CLASSIFIER.predict(classifier_input)
        # argmax over the flattened prediction; with a batch of one this is the
        # index of the most probable class.
        noise_type = NOISE_CLASSES[int(np.argmax(prediction))]
        print(f"Detected noise type: {noise_type}")

        denoiser_input = _preprocess_for_denoiser(image_bytes)
        denoised_array = DENOISERS[noise_type].predict(denoiser_input)
        output_image_buffer = postprocess_output(denoised_array)
        return StreamingResponse(output_image_buffer, media_type="image/png")
    except HTTPException:
        # Already carries the correct status (e.g. 400 from preprocessing);
        # a bare raise preserves the original traceback.
        raise
    except Exception as e:
        print(f"An unexpected error occurred during processing: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}") from e
# --- 6. Add a root endpoint for basic health check ---
# Without the route decorator FastAPI never registers this function, so GET /
# would return 404 instead of the health payload.
@app.get("/")
def read_root():
    """Health check: confirm the API process is up and serving requests.

    Returns a static payload. It does not report whether the models loaded.
    """
    return {"status": "ok", "message": "Welcome to the X-Ray Denoising API!"}