Spaces:
Running
Running
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoProcessor
# FastAPI application instance.
app = FastAPI()

# Artifact locations. The preprocessing config and the quantized weights come
# from the local "quantized_model" directory; the architecture config uses what
# looks like a Hugging Face Hub repo id.
_PROCESSOR_SOURCE = "quantized_model"
_ARCH_SOURCE = "dima806/deepfake_vs_real_image_detection"
_QUANTIZED_WEIGHTS = "quantized_model/model_quantized.pt"

processor = AutoProcessor.from_pretrained(_PROCESSOR_SOURCE)
config = AutoConfig.from_pretrained(_ARCH_SOURCE)

# Build the float architecture with freshly initialised weights, then swap its
# nn.Linear layers for dynamically quantized int8 versions. The checkpoint is
# loaded into the *converted* model, so it is expected to hold quantized
# parameters whose keys match that structure rather than the float one.
model = AutoModelForImageClassification.from_config(config)
model_quantized = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)
_quantized_state = torch.load(_QUANTIZED_WEIGHTS, map_location="cpu")
model_quantized.load_state_dict(_quantized_state)
del _quantized_state  # drop the extra reference once the weights are copied in

# Switch layers such as dropout to inference behaviour.
model_quantized.eval()
def _classify(contents: bytes) -> dict:
    """Decode *contents* as an image and classify it with the quantized model.

    This is synchronous, CPU-bound work (PIL decoding plus a forward pass), so
    async callers should run it in a worker thread rather than on the event loop.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        ``{"label": <class name>, "confidence": <softmax probability>}``.

    Raises:
        HTTPException: 400 if *contents* cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown or empty data) and truncated-file
        # errors are OSError subclasses; they are client errors, not a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model_quantized(**inputs).logits

    # One image was submitted, so work on the first (only) row of scores.
    scores = logits[0]
    predicted_idx = scores.argmax().item()
    confidence = scores.softmax(-1)[predicted_idx].item()
    label = model_quantized.config.id2label[predicted_idx]
    return {"label": label, "confidence": confidence}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return its predicted label and confidence.

    The upload is read asynchronously. Decoding and inference are offloaded to
    the thread pool so they do not block other requests on the event loop.

    Raises:
        HTTPException: 400 if the upload is not a decodable image.
    """
    contents = await file.read()
    return await run_in_threadpool(_classify, contents)