# garbage-api / app.py
# Uploaded by mohamedtsou — commit 227a247 ("Create app.py"), 956 bytes.
import asyncio
import io
import os

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ASGI application object that the route decorators below attach to.
app = FastAPI()

# Checkpoint identifier passed to ``from_pretrained`` for both the image
# processor and the classification model.
MODEL_NAME = "yangy50/garbage-classification"

# Both objects are created once, at import time, and shared by all requests.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# ``Module.eval()`` returns the module itself, so switching to inference mode
# can be chained onto the load call.
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).eval()
@app.get("/")
def root():
    """Return a fixed status payload (presumably used as a health check)."""
    payload = {"status": "ok"}
    return payload
def _decode_image(data):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: (including ``PIL.UnidentifiedImageError``) if *data* is not a
            readable/complete image, e.g. an empty or non-image upload.
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's
            pixel-count safety limit.
    """
    # ``Image.open`` is lazy; ``convert`` forces the actual decode so that any
    # decoding error is raised here, not later inside the processor.
    return Image.open(io.BytesIO(data)).convert("RGB")


def _classify(image):
    """Run the classifier on *image* and map every label to its probability."""
    inputs = processor(images=image, return_tensors="pt")
    # Grad mode is thread-local in PyTorch, so it must be disabled in the
    # worker thread that actually executes the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    # One image in, so keep row 0 after the softmax over the label axis.
    # ``tolist`` converts the whole row to Python floats in one call.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {model.config.id2label[i]: p for i, p in enumerate(probs)}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return a ``{label: probability}`` map.

    Image decoding and model inference are CPU-bound. They run in the event
    loop's default executor so that a slow prediction does not block every
    other request served by this async endpoint.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image
            (previously this surfaced as a 500 Internal Server Error).
    """
    data = await file.read()
    loop = asyncio.get_running_loop()
    try:
        image = await loop.run_in_executor(None, _decode_image, data)
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    return await loop.run_in_executor(None, _classify, image)
if __name__ == "__main__":
    # Listen on all interfaces; use $PORT when it is set, otherwise 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)