# Source: commit 8869a1f "Add Application File" by kautilya286 (1.23 kB).
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoProcessor
# Local directory that holds the saved preprocessing config for the processor.
_QUANTIZED_DIR = "quantized_model"
# Hub checkpoint whose config defines the model architecture and id2label mapping.
_BASE_CHECKPOINT = "dima806/deepfake_vs_real_image_detection"
# State dict of the quantized model, loaded onto CPU below.
_QUANTIZED_WEIGHTS = "quantized_model/model_quantized.pt"

# Preprocessing comes from the local directory; the architecture config comes
# from the original Hub checkpoint.
processor = AutoProcessor.from_pretrained(_QUANTIZED_DIR)
config = AutoConfig.from_pretrained(_BASE_CHECKPOINT)

# Build the float model from config only (no pretrained weights), then apply
# dynamic qint8 quantization to its Linear layers *before* loading weights.
# The saved state dict presumably came from a model quantized the same way,
# so the module structure must match first. TODO confirm against the export script.
model = AutoModelForImageClassification.from_config(config)
model_quantized = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)
# NOTE(review): torch.load unpickles the file; only ship weights from a trusted source.
model_quantized.load_state_dict(
    torch.load(_QUANTIZED_WEIGHTS, map_location="cpu")
)
# Inference-only service: switch off training-mode behaviour (e.g. dropout).
model_quantized.eval()

app = FastAPI()
def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when the bytes are empty, not a recognised image
            format, truncated/corrupt, or trip PIL's decompression-bomb guard.
    """
    try:
        with Image.open(io.BytesIO(contents)) as img:
            # convert() loads the pixel data, so corrupt payloads fail here
            # rather than later inside the processor.
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL's UnidentifiedImageError (unknown/empty input) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _classify_bytes(contents: bytes) -> dict:
    """Synchronously decode *contents* and return the top-1 prediction.

    Returns:
        ``{"label": ..., "confidence": ...}``. ``label`` is the entry of
        ``model_quantized.config.id2label`` for the arg-max class, and
        ``confidence`` is that class's softmax probability as a Python float.
    """
    image = _decode_image(contents)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model_quantized(**inputs).logits
    # A single image is submitted, so the batch has one row.
    predicted_idx = logits.argmax(-1).item()
    confidence = logits.softmax(-1)[0][predicted_idx].item()
    label = model_quantized.config.id2label[predicted_idx]
    return {"label": label, "confidence": confidence}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the quantized model.

    The upload is read asynchronously. Image decoding and CPU-bound inference
    run in a worker thread so they do not block the event loop while other
    requests are waiting.

    Returns:
        ``{"label": ..., "confidence": ...}`` for the top-1 class.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    return await run_in_threadpool(_classify_bytes, contents)