dhairyashah commited on
Commit
657d714
1 Parent(s): b18f491

Create app.py.bak

Browse files
Files changed (1) hide show
  1. app.py.bak +121 -0
app.py.bak ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ from flask import Flask, request, jsonify
3
+ import os
4
+ from werkzeug.utils import secure_filename
5
+ import cv2
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from facenet_pytorch import MTCNN, InceptionResnetV1
9
+ import numpy as np
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ import os
13
+
14
+ app = Flask(__name__)
15
+
16
+ # Configuration
17
+ UPLOAD_FOLDER = 'uploads'
18
+ ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'webm'}
19
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
20
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
21
+
22
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
23
+
24
+ # Device configuration
25
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
26
+
27
+ mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
28
+
29
+ model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
30
+ # Model Credits: https://huggingface.co/spaces/dhairyashah/deepfake-alpha-version/blob/main/CREDITS.md
31
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
32
+ model.load_state_dict(checkpoint['model_state_dict'])
33
+ model.to(DEVICE)
34
+ model.eval()
35
+
36
+ def allowed_file(filename):
37
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
38
+
39
+ @spaces.GPU
40
+ def process_frame(frame):
41
+ face = mtcnn(frame)
42
+ if face is None:
43
+ return None, None
44
+
45
+ face = face.unsqueeze(0)
46
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
47
+
48
+ face = face.to(DEVICE)
49
+ face = face.to(torch.float32)
50
+ face = face / 255.0
51
+
52
+ with torch.no_grad():
53
+ output = torch.sigmoid(model(face).squeeze(0))
54
+ prediction = "fake" if output.item() >= 0.5 else "real"
55
+
56
+ return prediction, output.item()
57
+
58
+ @spaces.GPU
59
+ def analyze_video(video_path, sample_rate=30):
60
+ cap = cv2.VideoCapture(video_path)
61
+ frame_count = 0
62
+ fake_count = 0
63
+ total_processed = 0
64
+
65
+ while cap.isOpened():
66
+ ret, frame = cap.read()
67
+ if not ret:
68
+ break
69
+
70
+ if frame_count % sample_rate == 0:
71
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
72
+ prediction, confidence = process_frame(rgb_frame)
73
+
74
+ if prediction is not None:
75
+ total_processed += 1
76
+ if prediction == "fake":
77
+ fake_count += 1
78
+
79
+ frame_count += 1
80
+
81
+ cap.release()
82
+
83
+ if total_processed > 0:
84
+ fake_percentage = (fake_count / total_processed) * 100
85
+ return fake_percentage
86
+ else:
87
+ return 0
88
+
89
+ @app.route('/analyze', methods=['POST'])
90
+ def analyze_video_api():
91
+ if 'video' not in request.files:
92
+ return jsonify({'error': 'No video file provided'}), 400
93
+
94
+ file = request.files['video']
95
+
96
+ if file.filename == '':
97
+ return jsonify({'error': 'No selected file'}), 400
98
+
99
+ if file and allowed_file(file.filename):
100
+ filename = secure_filename(file.filename)
101
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
102
+ file.save(filepath)
103
+
104
+ try:
105
+ fake_percentage = analyze_video(filepath)
106
+ os.remove(filepath) # Remove the file after analysis
107
+
108
+ result = {
109
+ 'fake_percentage': round(fake_percentage, 2),
110
+ 'is_likely_deepfake': fake_percentage >= 60
111
+ }
112
+
113
+ return jsonify(result), 200
114
+ except Exception as e:
115
+ os.remove(filepath) # Remove the file if an error occurs
116
+ return jsonify({'error': str(e)}), 500
117
+ else:
118
+ return jsonify({'error': f'Invalid file type: {file.filename}'}), 400
119
+
120
+ if __name__ == '__main__':
121
+ app.run(host='0.0.0.0', port=7860)