# Classifier / app.py
# Hugging Face Space "Classifier" by santy171710
# Commit 5670a94 ("updated app.py file")
import io
import numpy as np
import tensorflow as tf
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# --- 1. Initialize the FastAPI application ---
# Single application instance; every route and middleware below is attached to `app`.
app = FastAPI(
    version="1.0.0",
    title="X-Ray Denoising API",
    description="An API to classify noise and denoise X-ray images.",
)
# --- 2. Set up CORS ---
# Origins allowed to call this API from a browser. Browsers send the Origin
# header as scheme://host[:port] with no path and no trailing slash, and the
# CORS middleware compares it against this list by exact string match, so an
# entry ending in "/" would never match and that front-end would be blocked.
origins = [
    "http://localhost:5173",  # presumably local front-end dev servers
    "http://localhost:3000",
    "https://santy171710-classifier.hf.space",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- 3. Load AI Models ---
def load_all_models(models_dir: str = "models"):
    """Load the noise classifier and one denoiser model per noise type.

    Args:
        models_dir: Directory holding the ``.keras`` model files. Defaults to
            ``"models"``, resolved relative to the current working directory.

    Returns:
        ``(classifier_model, denoiser_models)``, where ``denoiser_models`` maps
        each noise-type key ("gaussian", "poisson", "salt_pepper", "speckle")
        to its Keras model. If any model fails to load, the error is printed
        and ``(None, None)`` is returned. Loading is best-effort on purpose:
        the module still imports, and the denoise endpoint answers 503.
    """
    import os

    # Noise-type key -> denoiser file name inside `models_dir`.
    denoiser_files = {
        "gaussian": "gaussian_denoiser_final_model.keras",
        "poisson": "poisson_denoising.keras",
        "salt_pepper": "salt_pepper_denoiser.keras",
        "speckle": "speckle_denoising_final_model.keras",
    }

    print("Loading AI models...")
    try:
        classifier_model = tf.keras.models.load_model(
            os.path.join(models_dir, "xray_noise_classifier_resnet50v2.keras")
        )
        denoiser_models = {
            noise_type: tf.keras.models.load_model(os.path.join(models_dir, file_name))
            for noise_type, file_name in denoiser_files.items()
        }
        print("✅ Models loaded successfully!")
        return classifier_model, denoiser_models
    except Exception as e:
        # Deliberate catch-all: a missing or corrupt model must not crash the
        # server at import time. The endpoint checks for None and returns 503.
        print(f"❌ Error loading models: {e}")
        return None, None
# Index i of the classifier's output maps to NOISE_CLASSES[i] (via argmax in
# the endpoint). Each name is also used as a key into DENOISERS, so order and
# spelling must match the trained classifier and the denoiser dict keys.
NOISE_CLASSES = ["gaussian", "poisson", "salt_pepper", "speckle"]

# Models are loaded once, at import time. Both names are None if loading failed.
CLASSIFIER, DENOISERS = load_all_models()
# --- 4. Define Helper Functions ---
def preprocess_image(image_bytes: bytes, size: tuple = (224, 224), mode: str = "RGB"):
    """Decode raw image bytes into a batched, normalized NumPy array.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).
        size: Target size passed to ``Image.resize`` (PIL orders it as
            width, height). Defaults to 224x224, the size the classifier
            input is built at.
        mode: PIL mode to convert to before resizing. ``"RGB"`` (default)
            gives 3 channels. Single-channel modes such as ``"L"`` get an
            explicit trailing channel axis, so the result always has shape
            ``(1, height, width, channels)``.

    Returns:
        Array of shape (1, height, width, channels), with pixel values divided
        by 255 (i.e. in [0, 1] for 8-bit modes).

    Raises:
        HTTPException: 400 if the bytes cannot be decoded, converted, or resized.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes)).convert(mode).resize(size)
    except Exception as e:
        # Chain the cause so server logs keep the original PIL traceback.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image file. Could not preprocess. Error: {e}",
        ) from e

    img_array = np.array(img)
    if img_array.ndim == 2:
        # Single-channel modes decode to (H, W); add the channel axis models expect.
        img_array = img_array[..., np.newaxis]
    # Add the batch dimension and scale to [0, 1].
    return img_array[np.newaxis, ...] / 255.0
def postprocess_output(denoised_array: np.ndarray):
    """Encode a model output array as an in-memory PNG file.

    Args:
        denoised_array: Model output with values nominally in [0, 1]. All
            length-1 axes (such as the batch axis and a single-channel axis)
            are squeezed away before encoding.

    Returns:
        ``io.BytesIO`` containing the PNG bytes, rewound to offset 0 so it can
        be streamed directly.
    """
    processed_array = np.squeeze(np.asarray(denoised_array))
    # Clip before casting. Nothing guarantees the model output stays inside
    # [0, 1], and NumPy's float -> uint8 cast of out-of-range values is not
    # saturating (values commonly wrap around), which would create speckle
    # artifacts. Round instead of truncating so e.g. 0.999 maps to 255, not 254.
    processed_array = np.rint(np.clip(processed_array, 0.0, 1.0) * 255).astype(np.uint8)
    image = Image.fromarray(processed_array)
    img_io = io.BytesIO()
    image.save(img_io, "PNG")
    img_io.seek(0)
    return img_io
# --- 5. Create the API Endpoint ---
def _classify_and_denoise(image_bytes: bytes) -> io.BytesIO:
    """Run the blocking classify -> denoise -> PNG-encode pipeline.

    Kept synchronous so the endpoint can run it off the event loop.

    Returns:
        In-memory PNG of the denoised image, rewound to offset 0.

    Raises:
        HTTPException: 400 (from preprocess_image) if the bytes are not a
            decodable image.
    """
    # Classify the noise type. argmax over the class scores indexes NOISE_CLASSES.
    classifier_input = preprocess_image(image_bytes)
    prediction = CLASSIFIER.predict(classifier_input, verbose=0)
    noise_type = NOISE_CLASSES[int(np.argmax(prediction))]
    print(f"Detected noise type: {noise_type}")

    # The denoisers get their own input. NOTE(review): assumes they expect
    # grayscale 256x256 with a trailing channel axis. Confirm against the models.
    img_for_denoiser = Image.open(io.BytesIO(image_bytes)).convert('L').resize((256, 256))
    denoiser_input = np.array(img_for_denoiser)[np.newaxis, ..., np.newaxis] / 255.0

    denoised_array = DENOISERS[noise_type].predict(denoiser_input, verbose=0)
    return postprocess_output(denoised_array)


@app.post("/api/denoise", response_class=StreamingResponse)
async def denoise_image(image: UploadFile = File(...)):
    """
    Receives an X-ray image, classifies the noise, applies the correct
    denoiser model, and returns the cleaned image as a PNG.

    Raises:
        HTTPException: 503 if the models failed to load, 400 if the upload is
            not a decodable image, 500 for any other processing error.
    """
    from fastapi.concurrency import run_in_threadpool

    # Compare to None explicitly rather than relying on a Keras model's truthiness.
    if CLASSIFIER is None or not DENOISERS:
        raise HTTPException(status_code=503, detail="Models are not available on the server.")
    image_bytes = await image.read()
    try:
        # Model inference is CPU-bound and blocking. Run it in the threadpool
        # so other requests (e.g. the health check) are not stalled meanwhile.
        # NOTE(review): this allows concurrent predict() calls on the shared
        # models. Add a lock if the TF/Keras version in use is unsafe for that.
        output_image_buffer = await run_in_threadpool(_classify_and_denoise, image_bytes)
    except HTTPException:
        raise
    except Exception as e:
        print(f"An unexpected error occurred during processing: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}") from e
    return StreamingResponse(output_image_buffer, media_type="image/png")
# --- 6. Root endpoint doubling as a basic liveness check ---
@app.get("/")
def read_root():
    """Report that the service is up, with a short welcome message."""
    payload = dict(
        status="ok",
        message="Welcome to the X-Ray Denoising API!",
    )
    return payload