dhairyashah commited on
Commit
d3ce1ea
·
verified ·
1 Parent(s): 9eca54f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -5
app.py CHANGED
@@ -1,10 +1,118 @@
1
- from flask import Flask, jsonify
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = Flask(__name__)
4
 
5
- @app.route('/api/hello', methods=['GET'])
6
- def hello():
7
- return jsonify(message="Hello, World!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  if __name__ == '__main__':
10
- app.run(host='0.0.0.0', port=7860)
 
1
+ from flask import Flask, request, jsonify
2
+ from pyngrok import ngrok
3
+ import os
4
+ from werkzeug.utils import secure_filename
5
+ import cv2
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from facenet_pytorch import MTCNN, InceptionResnetV1
9
+ import numpy as np
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ import os
13
 
14
  app = Flask(__name__)
15
 
16
+ # Configuration
17
+ UPLOAD_FOLDER = 'uploads'
18
+ ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov'}
19
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
20
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
21
+
22
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
23
+
24
+ # Device configuration
25
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
26
+
27
+ mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
28
+
29
+ model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
30
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
31
+ model.load_state_dict(checkpoint['model_state_dict'])
32
+ model.to(DEVICE)
33
+ model.eval()
34
+
35
+ def allowed_file(filename):
36
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
37
+
38
+ def process_frame(frame):
39
+ face = mtcnn(frame)
40
+ if face is None:
41
+ return None, None
42
+
43
+ face = face.unsqueeze(0)
44
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
45
+
46
+ face = face.to(DEVICE)
47
+ face = face.to(torch.float32)
48
+ face = face / 255.0
49
+
50
+ with torch.no_grad():
51
+ output = torch.sigmoid(model(face).squeeze(0))
52
+ prediction = "fake" if output.item() >= 0.5 else "real"
53
+
54
+ return prediction, output.item()
55
+
56
+ def analyze_video(video_path, sample_rate=30):
57
+ cap = cv2.VideoCapture(video_path)
58
+ frame_count = 0
59
+ fake_count = 0
60
+ total_processed = 0
61
+
62
+ while cap.isOpened():
63
+ ret, frame = cap.read()
64
+ if not ret:
65
+ break
66
+
67
+ if frame_count % sample_rate == 0:
68
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
69
+ prediction, confidence = process_frame(rgb_frame)
70
+
71
+ if prediction is not None:
72
+ total_processed += 1
73
+ if prediction == "fake":
74
+ fake_count += 1
75
+
76
+ frame_count += 1
77
+
78
+ cap.release()
79
+
80
+ if total_processed > 0:
81
+ fake_percentage = (fake_count / total_processed) * 100
82
+ return fake_percentage
83
+ else:
84
+ return 0
85
+
86
+ @app.route('/analyze', methods=['POST'])
87
+ def analyze_video_api():
88
+ if 'video' not in request.files:
89
+ return jsonify({'error': 'No video file provided'}), 400
90
+
91
+ file = request.files['video']
92
+
93
+ if file.filename == '':
94
+ return jsonify({'error': 'No selected file'}), 400
95
+
96
+ if file and allowed_file(file.filename):
97
+ filename = secure_filename(file.filename)
98
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
99
+ file.save(filepath)
100
+
101
+ try:
102
+ fake_percentage = analyze_video(filepath)
103
+ os.remove(filepath) # Remove the file after analysis
104
+
105
+ result = {
106
+ 'fake_percentage': round(fake_percentage, 2),
107
+ 'is_likely_deepfake': fake_percentage >= 60
108
+ }
109
+
110
+ return jsonify(result), 200
111
+ except Exception as e:
112
+ os.remove(filepath) # Remove the file if an error occurs
113
+ return jsonify({'error': str(e)}), 500
114
+ else:
115
+ return jsonify({'error': 'Invalid file type'}), 400
116
 
117
  if __name__ == '__main__':
118
+ app.run(host='0.0.0.0', port=7860)