|
|
import os

import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory of the fine-tuned SigLIP checkpoint loaded by load_model().
# The DEEPFAKE_MODEL_PATH environment variable overrides it so the app can run
# on machines other than the original author's; when the variable is unset,
# the original hard-coded location is used unchanged.
model_path = os.environ.get(
    "DEEPFAKE_MODEL_PATH",
    r"C:\Users\Sharulatha\Documents\hackathon\deepfake-detector-model-v1\checkpoint-625",
)

st.title("🕵️ Deepfake Detector")
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the SigLIP image classifier and its matching image processor.

    Both are read from the checkpoint directory in the module-level
    ``model_path``. The ``st.cache_resource`` decorator makes Streamlit build
    them once and hand back the same objects on later script reruns, instead
    of reloading the weights every time.

    Returns:
        tuple: ``(model, processor)``. The first item is the
        ``SiglipForImageClassification`` instance; the second is the
        ``AutoImageProcessor`` from the same checkpoint.
    """
    checkpoint_dir = model_path
    classifier = SiglipForImageClassification.from_pretrained(checkpoint_dir)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint_dir)
    return classifier, image_processor
|
|
|
|
|
|
model, processor = load_model()

# Display names for the classifier's output indices, keyed by the stringified
# class index.
# NOTE(review): assumes this checkpoint was trained with 0 = fake, 1 = real;
# confirm against the training label mapping.
id2label = {"0": "fake", "1": "real"}

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Normalize to 3-channel RGB; PNG uploads may be RGBA, palette or grayscale.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL.UnidentifiedImageError (not an image) and truncated-file errors
        # are both OSError subclasses. Report the problem in the UI instead of
        # showing a raw traceback, then end this script run.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`. It is kept here so older Streamlit
    # versions still work; confirm the pinned version before switching.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    inputs = processor(images=image, return_tensors="pt")

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # The logits have one row per image in the batch, and only one image was
    # processed. Take row 0 explicitly. Using squeeze() would collapse a
    # single-label output to a bare float, and len() on it would then fail.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

    # Prefer the hand-written names above. For any index they do not cover,
    # fall back to the checkpoint config's label names (or the bare index)
    # instead of raising KeyError.
    config_labels = model.config.id2label or {}
    prediction = {
        id2label.get(str(idx), config_labels.get(idx, str(idx))): round(p, 3)
        for idx, p in enumerate(probs)
    }

    st.subheader("Prediction")
    st.json(prediction)
|
|
|
|