File size: 1,233 Bytes
8869a1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoProcessor

app = FastAPI()

# Preprocessing comes from the local export directory, while the architecture
# config comes from the upstream Hub checkpoint.
# NOTE(review): the two sources must agree (image size, normalization, label
# mapping) — confirm whenever either one is updated.
processor = AutoProcessor.from_pretrained("quantized_model")
config = AutoConfig.from_pretrained("dima806/deepfake_vs_real_image_detection")

# Rebuild the quantized module structure before loading weights.
# load_state_dict is strict by default, so the module tree must match the
# checkpoint. Given this quantize-then-load order, the checkpoint is presumably
# a model dynamically quantized the same way (Linear layers, qint8) — verify
# against the export script. from_config gives random float weights that are
# only a scaffold.
model = AutoModelForImageClassification.from_config(config)
# inplace=True converts `model` itself instead of deep-copying it. Otherwise
# the throwaway float weights would stay referenced by `model` for the whole
# process lifetime alongside the quantized copy. After this call `model` and
# `model_quantized` are the same object.
model_quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)
# torch.load unpickles the file; only load checkpoints from a trusted source.
model_quantized.load_state_dict(
    torch.load("quantized_model/model_quantized.pt", map_location="cpu")
)
model_quantized.eval()

def _classify_image_bytes(contents: bytes) -> dict:
    """Decode *contents* as an image and classify it with the quantized model.

    Synchronous and CPU-bound; ``predict`` runs it in a worker thread.

    Raises:
        HTTPException: 400 if *contents* cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (not an image, or an empty upload) and
        # truncated-file errors are OSError subclasses. Without this handler
        # they would surface to the client as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = processor(images=image, return_tensors="pt")
    # torch.no_grad is thread-local, so it must be entered in the thread that
    # runs the forward pass.
    with torch.no_grad():
        logits = model_quantized(**inputs).logits
    # Index 0: the processor was given a single image, so the batch has one row.
    predicted_idx = logits.argmax(-1).item()
    confidence = logits.softmax(-1)[0][predicted_idx].item()
    label = model_quantized.config.id2label[predicted_idx]
    return {"label": label, "confidence": confidence}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the quantized model.

    The upload is read asynchronously. Decoding and inference are CPU-bound,
    so they are run in FastAPI's worker thread pool rather than on the event
    loop, which would otherwise be stalled for every other request during the
    forward pass.

    Returns:
        ``{"label": ..., "confidence": float}``. ``label`` comes from
        ``config.id2label``; ``confidence`` is the softmax probability of the
        predicted class.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    return await run_in_threadpool(_classify_image_bytes, contents)