# dgapiv2 / app.py
# Last commit 3f59d69 ("Update app.py", verified) by Hayloo9838; file size 983 bytes.
from io import BytesIO

import requests
import torch
import torchvision.transforms as T
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageClassification, AutoTokenizer
app = FastAPI()

# Load the model once at import time so every request reuses the same weights.
model_name = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
# Inference only. from_pretrained already switches to eval mode; this is kept
# explicit so dropout/batch-norm behaviour is obvious to readers.
model.eval()

# NOTE: no tokenizer is loaded. This is an image classifier, and
# AutoTokenizer.from_pretrained raises for checkpoints whose config has no
# tokenizer mapping. The tokenizer was never used by any endpoint anyway.

# Preprocessing: resize to 224x224, convert to a float tensor in [0, 1], then
# normalise each RGB channel with mean 0.5 / std 0.5 (maps values to [-1, 1]).
# NOTE(review): presumably mirrors the checkpoint's preprocessor_config.json —
# verify if the model is swapped.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
class ImageInput(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # Location of the image to classify; the endpoint downloads it server-side.
    url: str
@app.get("/")
def read_root():
    """Liveness probe: report that the API process is serving requests."""
    status_payload = {"status": "running"}
    return status_payload
# (connect, read) timeouts in seconds for fetching the client-supplied URL, so
# a slow or unresponsive host cannot tie up a worker thread indefinitely.
_FETCH_TIMEOUT = (5, 30)


def _fetch_image(url: str) -> Image.Image:
    """Download *url* and decode the response body as an RGB PIL image.

    Raises:
        HTTPException: 400 if the URL cannot be fetched (network error,
            timeout, non-2xx status) or the body is not a decodable image.
    """
    # NOTE(review): the URL comes straight from the client, so this can be used
    # to make the server request arbitrary (including internal) addresses, and
    # the full response is buffered in memory. Consider an allow-list and a
    # download size cap if the service is publicly reachable.
    try:
        response = requests.get(url, timeout=_FETCH_TIMEOUT)
        response.raise_for_status()
    except requests.RequestException as err:
        raise HTTPException(
            status_code=400, detail="Could not fetch image from URL"
        ) from err
    try:
        # Image.open is lazy; convert() forces the decode so failures surface here.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="URL did not return a decodable image"
        ) from err


@app.post("/predict")
def predict(input: ImageInput):
    """Classify the image found at ``input.url``.

    Returns:
        ``{"class": <predicted class index>, "label": <label string>}``.
        ``"label"`` is looked up in ``model.config.id2label``. It falls back to
        the index as a string when the config has no entry for it.

    Raises:
        HTTPException: 400 if the image cannot be downloaded or decoded.
    """
    img = _fetch_image(input.url)
    # Add a leading batch dimension of size 1 before running the model.
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(img_tensor).logits
    pred = int(torch.argmax(logits, dim=1).item())
    label = model.config.id2label.get(pred, str(pred))
    return {"class": pred, "label": label}