Spaces:
Sleeping
Sleeping
| import torch | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from transformers import ConvNextForImageClassification, AutoImageProcessor | |
| from PIL import Image | |
| import io | |
| import gradio as gr | |
| from starlette.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from gradio.routes import mount_gradio_app | |
| import tempfile | |
| import os | |
| from typing import Optional | |
# Human-readable labels for the classifier output. The entry at position i is
# the name reported for logit index i (``predict`` indexes this list with the
# arg-max), so the order must match the label indices the checkpoint was
# fine-tuned with. The entries are alphabetical — presumably mirroring a
# folder-per-class training layout; verify against the training script before
# reordering or adding labels.
class_names = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]
# Load the ConvNeXt backbone, replace its pretrained classification head with
# one sized for ``class_names``, then load the fine-tuned weights from the
# local checkpoint. ``load_state_dict`` is strict by default, so the checkpoint
# must also contain weights for the replaced head.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
# Derive the head's shape instead of hard-coding it: the input width is taken
# from the head being replaced (1024 for convnext-base), and the output width
# must stay in sync with the label list used to name predictions.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded.
# NOTE(review): requires torch >= 1.13.
model.load_state_dict(
    torch.load(
        "./models/convnext_base_finetuned.pth",
        map_location="cpu",
        weights_only=True,
    )
)
# Put the model in evaluation mode for inference.
model.eval()
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
# FastAPI application instance; the Gradio UI is mounted onto it further down.
app = FastAPI()
# Open CORS so browser clients on any origin can call the API (demo setting).
# Credentials are deliberately NOT allowed: with a wildcard origin plus
# allow_credentials=True, Starlette echoes back whichever Origin sent a
# credentialed request, effectively letting any site make authenticated
# cross-origin calls. Nothing in this file reads cookies or auth headers, so
# credentialed requests are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for demo purposes
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Function to predict the skin disease from an image
def predict(image: Image.Image):
    """Classify a skin image with the fine-tuned ConvNeXt model.

    Args:
        image: A PIL image in any mode. Non-RGB images (e.g. RGBA PNGs or
            grayscale ``"L"`` images) are converted to RGB first, because the
            image processor expects 3-channel input.

    Returns:
        A ``(index, name)`` tuple: the arg-max class index and its label from
        ``class_names``.
    """
    # Normalise the colour mode here so every caller gets the same
    # preprocessing, not only the upload endpoint (which converts on its own).
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class, class_names[predicted_class]
# FastAPI route for prediction via API
# NOTE(review): the path below is chosen here; confirm it matches what API
# clients call.
@app.post("/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint for image uploads.

    Reads the uploaded file, decodes it as an RGB image and classifies it.

    Returns:
        200 with ``{"predicted_class": int, "predicted_name": str}`` on success;
        400 with ``{"error": str}`` when the upload cannot be decoded as an
        image; 500 with ``{"error": str}`` for any other failure.
    """
    # Local import: only this handler needs it.
    from starlette.concurrency import run_in_threadpool

    try:
        img_bytes = await file.read()
        try:
            # The context manager closes the decoder's underlying buffer once
            # convert() has produced a fully-loaded copy.
            with Image.open(io.BytesIO(img_bytes)) as src:
                img = src.convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable, truncated or oversized uploads are client errors
            # (UnidentifiedImageError is an OSError subclass).
            return JSONResponse(content={"error": str(e)}, status_code=400)
        # predict() is synchronous, CPU-bound model inference; running it in
        # the threadpool keeps this async handler from blocking the event loop
        # (and every other request) for the duration of the forward pass.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)
        return JSONResponse(content={
            "predicted_class": predicted_class,
            "predicted_name": predicted_name
        })
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Redirect root to Gradio interface
@app.get("/")
def redirect_root_to_gradio():
    """Send requests for ``/`` to the Gradio UI mounted at ``/gradio``."""
    return RedirectResponse(url="/gradio")
# Gradio interface for testing
def gradio_interface(image):
    """Gradio function to handle the prediction from image.

    Args:
        image: PIL image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A display string ``"<label> (Class <index>)"``, or a prompt asking for
        an image when none was provided.
    """
    # Gradio passes None for an empty image component; without this guard the
    # image processor would raise and the UI would only show a generic error.
    if image is None:
        return "Please upload an image."
    predicted_class, predicted_name = predict(image)
    return f"{predicted_name} (Class {predicted_class})"
# Gradio UI: one image input, handed to the handler as a PIL image, and a
# plain-text output showing the predicted label.
gradio_app = gr.Interface(
    title="Skin Disease Classifier",
    description=(
        "Upload a skin image to classify the condition "
        "using a fine-tuned ConvNeXt model."
    ),
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
# Serve the Gradio UI from the FastAPI app under /gradio; the app returned by
# mount_gradio_app is rebound to ``app`` so the server runs the combined app.
app = mount_gradio_app(app, gradio_app, path="/gradio")
# Running this module directly serves the combined FastAPI + Gradio app with
# uvicorn.
if __name__ == "__main__":
    import uvicorn

    # 0.0.0.0 listens on every network interface, not just loopback.
    bind_host = "0.0.0.0"
    bind_port = 7860
    uvicorn.run(app, host=bind_host, port=bind_port)