|
import torch |
|
from fastapi import FastAPI, File, UploadFile |
|
from fastapi.responses import JSONResponse |
|
from transformers import ConvNextForImageClassification, AutoImageProcessor |
|
from PIL import Image |
|
import io |
|
|
|
|
|
# Human-readable label for each classifier output index: position i in this
# list is the name reported for predicted class i.
class_names = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]
|
|
|
|
|
# Hub checkpoint used for both the ConvNeXt backbone and the image processor,
# so the two are always loaded from the same source.
_BASE_CHECKPOINT = "facebook/convnext-base-224"

model = ConvNextForImageClassification.from_pretrained(_BASE_CHECKPOINT)

# Replace the pretrained classification head with one sized for this
# project's labels. The input width is taken from the head being replaced and
# the output width from class_names, so the head cannot silently drift out of
# sync with the label list.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code if the checkpoint is not trusted. If the installed torch supports it,
# consider passing weights_only=True here -- confirm against the torch version
# in use. The path is relative, so it is resolved against the process's
# current working directory.
model.load_state_dict(
    torch.load("./models/convnext_base_finetuned.pth", map_location="cpu")
)

# Put the module in evaluation mode before it is used for inference.
model.eval()

processor = AutoImageProcessor.from_pretrained(_BASE_CHECKPOINT)

# FastAPI application that the route below is registered on.
app = FastAPI()
|
|
|
|
|
def predict(image: Image.Image):
    """Classify a single image with the module-level model.

    Args:
        image: The PIL image to classify.

    Returns:
        A ``(class_index, class_label)`` tuple, where ``class_label`` is the
        entry of ``class_names`` at ``class_index``.
    """
    batch = processor(images=image, return_tensors="pt")

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**batch).logits

    # .item() requires a single-element tensor, so this handles exactly one
    # image per call.
    class_idx = logits.argmax(dim=1).item()
    return class_idx, class_names[class_idx]
|
|
|
|
|
@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as JSON.

    Args:
        file: The uploaded file to classify.

    Returns:
        A ``JSONResponse`` carrying ``predicted_class`` and ``predicted_name``
        on success. On failure the body is ``{"error": <message>}`` with
        status 400 when the upload cannot be decoded as an image, or status
        500 for any other error.
    """
    try:
        img_bytes = await file.read()

        try:
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # A payload that is empty, truncated, oversized or not an image at
            # all is the client's fault, so report it as a bad request rather
            # than a server error. (PIL's UnidentifiedImageError is an OSError.)
            return JSONResponse(content={"error": str(e)}, status_code=400)

        # NOTE(review): predict() is a synchronous call made directly inside
        # this coroutine, so the event loop is blocked while it runs. Consider
        # offloading it to a worker thread if concurrent requests matter.
        predicted_class, predicted_name = predict(img)

        return JSONResponse(content={
            "predicted_class": predicted_class,
            "predicted_name": predicted_name,
        })
    except Exception as e:
        # Top-level request boundary: convert any remaining failure into a
        # JSON error response instead of letting it propagate.
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
|
|
|
# The module-level `app` defined above is the FastAPI application the routes
# are registered on; rebinding it to itself is a no-op, so no statement is
# needed here.
|
|