|
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoModelForImageClassification, AutoProcessor
|
|
|
# Hugging Face Hub checkpoint whose image processor and classification head
# are served by this API.
model_name = "jazzmacedo/fruits-and-vegetables-detector-36"

# Loaded once at import time and shared by every request handled below.
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# ASGI application; route handlers register themselves on it via decorators.
app = FastAPI()
|
|
|
def _classify_image(upload_stream) -> str:
    """Decode an uploaded image and return the model's top-1 class label.

    Runs synchronously: PIL decoding, preprocessing and the forward pass are
    blocking CPU work, so callers on the event loop should offload this to a
    worker thread.

    Args:
        upload_stream: Binary file-like object holding the uploaded bytes.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring logit.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        # The context manager releases PIL's resources for the opened image;
        # convert() forces the full decode, so decode errors surface here.
        with Image.open(upload_stream) as raw_image:
            image = raw_image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (unknown format) and truncated-file errors
        # are OSError subclasses. Treat them, and oversized images, as bad
        # client input rather than letting them become a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    inputs = processor(images=image, return_tensors="pt")

    # Grad mode is thread-local in PyTorch, so no_grad must be entered in the
    # same (worker) thread that runs the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label.

    Args:
        file: Multipart-uploaded image file.

    Returns:
        ``{"prediction": <label>}`` with the model's top-1 class name.

    Raises:
        HTTPException: 400 if the upload is not a decodable image.
    """
    # Inference is blocking; run it in the threadpool so the event loop stays
    # free to serve other requests while the model runs.
    predicted_class_name = await run_in_threadpool(_classify_image, file.file)
    return {"prediction": predicted_class_name}
|
|