import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import warnings

warnings.filterwarnings("ignore")

# Download pretrained weights and load the face detector (MTCNN) and classifier (InceptionResnetV1)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)

checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
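
# Optional sanity check (illustrative sketch, not part of the original app): a dummy
# forward pass confirms the checkpoint loaded and that the model emits a single logit,
# which torch.sigmoid maps to the "fake" probability used below.
# with torch.no_grad():
#     dummy = torch.randn(1, 3, 256, 256, device=DEVICE)
#     print(torch.sigmoid(model(dummy)).item())  # value in (0, 1)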

# Model Inference 
def predict_frame(frame):
    """Predict whether the input frame contains a real or fake face"""
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame)

    face = mtcnn(frame_pil)
    if face is None:
        return None, None  # No face detected

    # Preprocess the face
    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE, dtype=torch.float32) / 255.0

    # Predict
    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
        prediction = "real" if output.item() < 0.5 else "fake"
        
        # Confidence scores (informational; only the discrete label is returned to the caller)
        real_prediction = 1 - output.item()
        fake_prediction = output.item()
        
        confidences = {
            'real': real_prediction,
            'fake': fake_prediction
        }

    # Visualize the decision with Grad-CAM on the final Inception-ResNet block
    target_layers = [model.block8.branch1[-1]]
    use_cuda = torch.cuda.is_available()
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
    targets = [ClassifierOutputTarget(0)]
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # show_cam_on_image expects a float image in [0, 1] and returns a uint8 image in [0, 255],
    # so the heatmap must not be rescaled by 255 again before blending
    visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted((face_np * 255).astype(np.uint8), 1, visualization, 0.5, 0)

    return prediction, face_with_mask
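
# Quick local test (illustrative sketch; "example_face.jpg" is a placeholder path):
# predict_frame expects a BGR frame as read by OpenCV and returns the label plus a
# Grad-CAM overlay, or (None, None) when MTCNN finds no face.
# test_frame = cv2.imread("example_face.jpg")
# label, overlay = predict_frame(test_frame)
# print(label)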

def predict_video(input_video):
    cap = cv2.VideoCapture(input_video)

    frames = []
    predictions = []
    frame_count = 0
    skip_frames = 20  # analyse every 20th frame to keep inference fast

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        if frame_count % skip_frames != 0:  # only process every skip_frames-th frame
            continue

        prediction, frame_with_mask = predict_frame(frame)
        if prediction is None:  # no face detected in this frame
            continue

        frames.append(frame_with_mask)
        predictions.append(prediction)

    cap.release()

    # Majority vote over the sampled frames; defaults to 'real' when no faces were found
    final_prediction = 'fake' if predictions.count('fake') > predictions.count('real') else 'real'

    return final_prediction

# Gradio Interface
interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Input Video")
    ],
    outputs=[
        gr.Label(label="Class")
    ],
    title="Deepfake Video Detection",
    description="Detect whether the input video is real or fake"
)

interface.launch()
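
# launch() also accepts standard Gradio options, e.g. interface.launch(share=True)
# for a temporary public link or interface.launch(server_name="0.0.0.0") to listen
# on all interfaces; the plain call above serves the demo locally.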