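# NOTE: the text "Spaces: Sleeping" was captured from the hosting page together
# with this source; it appears to be page-status residue, not part of the program.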
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import EfficientNetModel

# Silence all library warnings for the lifetime of the process.
warnings.filterwarnings("ignore")
# ── Model ─────────────────────────────────────────────────────────
class FFTBranch(nn.Module):
    """Frequency-domain branch: embed the log-magnitude spectrum of each image.

    The input batch is averaged over its channel axis, transformed with a 2-D
    FFT, converted to a log-magnitude map that is min-max normalised per
    sample, and passed through a small CNN followed by a linear projection.

    Args:
        out_dim: Size of the feature vector produced for each sample.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        # Layer order must not change: checkpoint keys are positional
        # (``cnn.0``, ``cnn.1``, ..., ``proj.0``).
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.proj = nn.Sequential(
            nn.Linear(128 * 4 * 4, out_dim), nn.GELU(), nn.Dropout(0.3)
        )

    def forward(self, x):
        """Return a ``(batch, out_dim)`` feature tensor for a ``(batch, C, H, W)`` input."""
        # Collapse the channel axis to one plane per sample.
        gray = x.mean(dim=1, keepdim=True)

        # Shift ONLY the two spatial axes. Without ``dim`` fftshift rolls every
        # axis, including the batch axis, which hands each sample the spectrum
        # of a different sample whenever the batch holds more than one image.
        spectrum = torch.fft.fftshift(torch.fft.fft2(gray), dim=(-2, -1))

        # Log-magnitude; the epsilon keeps log() finite for zero bins.
        mag = torch.log(torch.abs(spectrum) + 1e-8)

        # Per-sample min-max normalisation over the spatial plane.
        flat = mag.flatten(2)
        lo = flat.min(2)[0].unsqueeze(-1).unsqueeze(-1)
        hi = flat.max(2)[0].unsqueeze(-1).unsqueeze(-1)
        mag = (mag - lo) / (hi - lo + 1e-8)

        return self.proj(self.cnn(mag).flatten(1))
class CNNFFTDetector(nn.Module):
    """Two-branch detector: EfficientNet-B0 image features plus FFT-spectrum features.

    Each branch is projected to 512 features; the two vectors are concatenated
    (1024 features) and fed to an MLP that ends in a single output unit. No
    activation is applied after the final layer, so ``forward`` returns a raw
    score per sample.
    """

    def __init__(self):
        super().__init__()
        self.cnn = EfficientNetModel.from_pretrained("google/efficientnet-b0")
        # Freeze the leading 60% of the backbone's parameter tensors.
        self._freeze_leading_parameters(self.cnn.parameters(), 0.6)
        self.cnn_proj = nn.Sequential(nn.Linear(1280, 512), nn.GELU(), nn.Dropout(0.3))
        self.fft = FFTBranch(out_dim=512)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.GELU(), nn.Dropout(0.4),
            nn.Linear(256, 64), nn.GELU(), nn.Linear(64, 1),
        )

    @staticmethod
    def _freeze_leading_parameters(parameters, frozen_fraction):
        """Disable gradients for the first ``frozen_fraction`` of ``parameters``.

        The cutoff index is ``int(count * frozen_fraction)``; tensors at or
        after that index are marked trainable, the ones before it are frozen.
        """
        params = list(parameters)
        first_trainable = int(len(params) * frozen_fraction)
        for index, param in enumerate(params):
            param.requires_grad = index >= first_trainable

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of raw scores for the image batch ``x``."""
        cnn_features = self.cnn_proj(self.cnn(x).pooler_output)
        fft_features = self.fft(x)
        return self.classifier(torch.cat([cnn_features, fft_features], dim=1))
# Build the network, restore the trained weights and switch to inference mode.
print("Loading model...")
device = torch.device("cpu")
model = CNNFFTDetector().to(device)

# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load a "best.pth" from a trusted
# source. NOTE(review): switch to weights_only=True if the checkpoint holds
# nothing but tensors and plain Python values.
ckpt = torch.load("best.pth", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state"])
model.eval()  # disables dropout and uses BatchNorm running statistics
print(f"Model ready — {ckpt['best_val_acc']*100:.2f}%")

# Preprocessing applied to every uploaded image: fixed 224x224 size, tensor
# conversion, then per-channel normalisation with mean 0.5 and std 0.5.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
def predict(image):
    """Classify an image as AI-generated or real.

    Args:
        image: A PIL image, a NumPy array, or ``None`` when nothing was
            uploaded.

    Returns:
        A ``(scores, verdict)`` tuple. ``scores`` maps ``"AI Generated"`` and
        ``"Real"`` to probabilities that sum to 1; ``verdict`` is a markdown
        summary. When ``image`` is ``None`` a placeholder result and a prompt
        to upload an image are returned instead.
    """
    if image is None:
        return {"AI Generated": 0.0, "Real": 1.0}, "Please upload an image"

    # Array inputs are converted so the PIL-based transform below applies.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    tensor = tf(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Sigmoid turns the model's raw score into P(AI generated).
        score = torch.sigmoid(model(tensor)).item()

    is_fake = score >= 0.5
    fake_pct = round(score * 100, 1)
    real_pct = round((1 - score) * 100, 1)
    confidence_pct = round(max(score, 1 - score) * 100, 1)
    label = "AI Generated / Deepfake" if is_fake else "Real Image"
    icon = "🔴" if is_fake else "🟢"

    # Two trailing spaces before the newline form a markdown hard line break,
    # so each statistic is rendered on its own line.
    stats = "  \n".join([
        f"**AI/Fake:** {fake_pct}%",
        f"**Real:** {real_pct}%",
        f"**Confidence:** {confidence_pct}%",
    ])
    verdict = f"## {icon} {label}\n\n{stats}"
    return {"AI Generated": float(score), "Real": float(1 - score)}, verdict
# ── UI ────────────────────────────────────────────────────────────
# Gradio front end: image upload on the left, scores and verdict on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="LunaNet") as demo:
    # NOTE(review): the leading emoji was lost in extraction; a moon glyph is
    # assumed from the project name. Confirm against the deployed app.
    gr.Markdown("# 🌙 LunaNet — AI Image & Deepfake Detector\n**Revealing the Unseen** · CNN (EfficientNetB0) + FFT · 91.47% accuracy")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("✦ Analyse", variant="primary", size="lg")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2, label="Detection Result")
            md_out = gr.Markdown(label="Verdict")
    # api_name makes it callable as /predict from external frontends
    btn.click(fn=predict, inputs=img_input, outputs=[label_out, md_out], api_name="predict")
    # Also run the detector as soon as an image finishes uploading.
    img_input.upload(fn=predict, inputs=img_input, outputs=[label_out, md_out])
    gr.Markdown("---\n**Training data:** CIFAKE · 140k Faces · OpenForensics · Celeb-DF v2")

demo.launch(ssr_mode=False)