|
import io
import logging
import os
import tempfile
from typing import Optional

import gradio as gr
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from gradio.routes import mount_gradio_app
from PIL import Image
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoImageProcessor, ConvNextForImageClassification
|
|
|
|
|
# Human-readable label for each classifier output: entry i names logit i
# (predict() looks up class_names[argmax]). The list is in alphabetical order;
# it must match the label order used when the checkpoint was fine-tuned.
class_names = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]
|
|
|
|
|
# Build the ConvNeXt-Base backbone, replace its classification head with one
# sized for our label list, then load the fine-tuned weights from disk.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")

# Derive the head dimensions instead of hard-coding them (1024 / 23): the input
# size comes from the backbone's existing head, and the output size always
# matches class_names, so the two cannot silently drift apart.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and refuses to execute arbitrary pickled code.
model.load_state_dict(
    torch.load(
        "./models/convnext_base_finetuned.pth",
        map_location="cpu",
        weights_only=True,
    )
)
model.eval()  # inference only: disable dropout / stochastic-depth behaviour

# Preprocessing (resize / normalize) matching the base checkpoint.
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
|
|
|
|
|
app = FastAPI()

# Allow browser clients from any origin to call the API.
# Credentials are deliberately disabled. Nothing in this module reads cookies
# or auth headers, and the CORS spec forbids `Access-Control-Allow-Origin: *`
# on credentialed requests. With allow_credentials=True Starlette works around
# that by echoing back the caller's Origin, which would let any site make
# credentialed requests to this service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
def predict(image: Image.Image):
    """Classify a single image with the fine-tuned ConvNeXt model.

    Args:
        image: A PIL image. Images that are not already in "RGB" mode
            (e.g. grayscale "L" or "RGBA") are converted to RGB first. The
            upload endpoint already does this; doing it here makes every
            caller (e.g. the Gradio UI) consistent.

    Returns:
        A ``(class_index, class_name)`` tuple for the highest-scoring class,
        where ``class_name`` is ``class_names[class_index]``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # inference_mode is a stricter, cheaper variant of no_grad for pure inference.
    with torch.inference_mode():
        logits = model(**inputs).logits

    # One image in, one row of logits out: take the index of the largest score.
    predicted_class = logits.argmax(dim=-1).item()
    return predicted_class, class_names[predicted_class]
|
|
|
|
|
@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint for image uploads.

    Reads the multipart upload, decodes it with Pillow, converts it to RGB
    and classifies it with ``predict``.

    Responses:
        200: ``{"predicted_class": int, "predicted_name": str}``.
        400: ``{"error": ...}`` when the upload cannot be decoded as an image
             (Pillow raises OSError, e.g. for unknown formats or truncated
             data, and DecompressionBombError for oversized images).
        500: ``{"error": ...}`` for any other failure; the traceback is
             logged server-side.
    """
    logger = logging.getLogger(__name__)

    try:
        img_bytes = await file.read()
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # A bad upload is a client error, not a server fault.
        return JSONResponse(content={"error": f"Invalid image file: {e}"}, status_code=400)
    except Exception as e:
        logger.exception("Failed to read upload %r", file.filename)
        return JSONResponse(content={"error": str(e)}, status_code=500)

    try:
        # predict() is synchronous, CPU-bound model inference. Running it in the
        # thread pool keeps this async handler from blocking the event loop,
        # which also serves the mounted Gradio UI.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)
    except Exception as e:
        logger.exception("Inference failed for upload %r", file.filename)
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={
        "predicted_class": predicted_class,
        "predicted_name": predicted_name,
    })
|
|
|
|
|
@app.get("/")
def redirect_root_to_gradio():
    """Send requests for the bare site root to the Gradio UI mounted at /gradio."""
    gradio_ui_path = "/gradio"
    return RedirectResponse(url=gradio_ui_path)
|
|
|
|
|
def gradio_interface(image):
    """Gradio callback: classify *image* and format the result as text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        ``"<class name> (Class <index>)"``, or a prompt asking for an image
        when none was provided.
    """
    if image is None:
        # An empty image input arrives as None; passing it to predict() would
        # fail inside the image processor and surface as a generic UI error.
        return "Please upload an image."

    predicted_class, predicted_name = predict(image)
    return f"{predicted_name} (Class {predicted_class})"
|
|
|
|
|
# Browser UI: one PIL image in, the formatted prediction string out.
gradio_app = gr.Interface(
    title="Skin Disease Classifier",
    description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model.",
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

# Serve the UI under /gradio on the same FastAPI app that exposes /api/predict;
# mount_gradio_app returns the app to use from here on, so rebind it.
app = mount_gradio_app(app, gradio_app, path="/gradio")
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the API and the Gradio UI with uvicorn.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)
|
|