# deepscan / app.py — LunaNet Gradio Space (Sattfaceee123, commit d5f04bf)
import gradio as gr
import torch, torch.nn as nn, warnings
from torchvision import transforms
from transformers import EfficientNetModel
from PIL import Image
import numpy as np
# Broad suppression to keep the Space logs clean; hides all warning categories.
warnings.filterwarnings("ignore")
# ── Model ─────────────────────────────────────────────────────────
class FFTBranch(nn.Module):
    """Frequency-domain branch: log-magnitude 2-D FFT -> small CNN -> feature vector."""

    def __init__(self, out_dim=512):
        super().__init__()
        # Layer construction order is kept identical to the checkpointed model
        # so state_dict keys and RNG-dependent init line up.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.proj = nn.Sequential(
            nn.Linear(128 * 4 * 4, out_dim), nn.GELU(), nn.Dropout(0.3)
        )

    def forward(self, x):
        """x: (B, C, H, W) image batch; returns (B, out_dim) spectral features."""
        grey = x.mean(dim=1, keepdim=True)                    # collapse channels to 1
        spectrum = torch.fft.fftshift(torch.fft.fft2(grey))   # centre the DC component
        log_mag = torch.log(torch.abs(spectrum) + 1e-8)       # log-magnitude spectrum
        # Per-image min-max normalisation of the spectrum to [0, 1].
        flat = log_mag.flatten(2)
        lo = flat.min(2)[0][..., None, None]
        hi = flat.max(2)[0][..., None, None]
        normed = (log_mag - lo) / (hi - lo + 1e-8)
        return self.proj(self.cnn(normed).flatten(1))
class CNNFFTDetector(nn.Module):
    """Dual-branch detector: EfficientNet-B0 spatial features + FFT spectral
    features, concatenated and fed to an MLP that emits one raw logit."""

    def __init__(self):
        super().__init__()
        self.cnn = EfficientNetModel.from_pretrained("google/efficientnet-b0")
        # Freeze the first 60% of backbone parameters; fine-tune the rest.
        backbone_params = list(self.cnn.parameters())
        cutoff = int(len(backbone_params) * 0.6)
        for idx, param in enumerate(backbone_params):
            param.requires_grad = idx >= cutoff
        self.cnn_proj = nn.Sequential(nn.Linear(1280, 512), nn.GELU(), nn.Dropout(0.3))
        self.fft = FFTBranch(out_dim=512)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.GELU(), nn.Dropout(0.4),
            nn.Linear(256, 64), nn.GELU(), nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return a (B, 1) raw logit per image; apply sigmoid for P(fake)."""
        spatial = self.cnn_proj(self.cnn(x).pooler_output)
        spectral = self.fft(x)
        fused = torch.cat([spatial, spectral], dim=1)
        return self.classifier(fused)
# ── Load model & checkpoint ──────────────────────────────────────
print("Loading model...")
device = torch.device("cpu")  # Spaces CPU instance; no GPU assumed
model = CNNFFTDetector().to(device)
# NOTE(review): weights_only=False unpickles arbitrary objects from best.pth —
# acceptable only because the file ships with this Space; never point this at
# an untrusted checkpoint.
ckpt = torch.load("best.pth", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state"])
model.eval()
# Fixed mojibake: 'β€”' was a mis-encoded em dash.
print(f"Model ready — {ckpt['best_val_acc']*100:.2f}%")
# Preprocessing pipeline: 224x224 resize, then normalise each channel to [-1, 1].
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
def predict(image):
    """Classify an image as AI-generated/deepfake vs real.

    Args:
        image: PIL.Image, numpy array, or None (as supplied by the Gradio input).

    Returns:
        (label_scores, verdict): a class-probability dict for gr.Label and a
        markdown verdict string for gr.Markdown.
    """
    if image is None:
        return {"AI Generated": 0.0, "Real": 1.0}, "Please upload an image"
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    tensor = tf(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Single-logit head: sigmoid gives P(AI-generated).
        score = torch.sigmoid(model(tensor)).item()
    fake_pct = round(score * 100, 1)
    real_pct = round((1 - score) * 100, 1)
    is_fake = score >= 0.5
    label = "AI Generated / Deepfake" if is_fake else "Real Image"
    # Fixed mojibake: the emoji were mis-encoded as 'πŸ”΄' / '🟒'.
    verdict = (
        f"## {'🔴' if is_fake else '🟢'} {label}\n\n"
        f"**AI/Fake:** {fake_pct}% \n**Real:** {real_pct}% \n"
        f"**Confidence:** {round(max(score, 1 - score) * 100, 1)}%"
    )
    return {"AI Generated": float(score), "Real": float(1 - score)}, verdict
# ── UI ────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(), title="LunaNet") as demo:
    # Header — mojibake ('πŸŒ™', 'β€”', 'Β·') fixed to the intended 🌙 / — / ·.
    gr.Markdown(
        "# 🌙 LunaNet — AI Image & Deepfake Detector\n"
        "**Revealing the Unseen** · CNN (EfficientNetB0) + FFT · 91.47% accuracy"
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("✦ Analyse", variant="primary", size="lg")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2, label="Detection Result")
            md_out = gr.Markdown(label="Verdict")
    # api_name makes it callable as /predict from external frontends.
    btn.click(fn=predict, inputs=img_input, outputs=[label_out, md_out], api_name="predict")
    # Also run automatically as soon as an image is uploaded.
    img_input.upload(fn=predict, inputs=img_input, outputs=[label_out, md_out])
    gr.Markdown("---\n**Training data:** CIFAKE · 140k Faces · OpenForensics · Celeb-DF v2")
demo.launch(ssr_mode=False)