ImjerryCo commited on
Commit
e60e640
·
verified ·
1 Parent(s): 116d186

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +84 -0
utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import tempfile
3
+ import requests
4
+ import os
5
+ from PIL import Image
6
+ from transformers import pipeline
7
+
8
+ classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
9
+
10
+
11
+ def download_file(url):
12
+ response = requests.get(url, stream=True)
13
+ tmp = tempfile.NamedTemporaryFile(delete=False)
14
+ for chunk in response.iter_content(1024):
15
+ tmp.write(chunk)
16
+ tmp.close()
17
+ return tmp.name
18
+
19
+
20
+ def get_video_duration(video_path):
21
+ cap = cv2.VideoCapture(video_path)
22
+ fps = cap.get(cv2.CAP_PROP_FPS)
23
+ frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
24
+ cap.release()
25
+ return frame_count / fps if fps > 0 else 0
26
+
27
+
28
+ def extract_frame(video_path, second):
29
+ cap = cv2.VideoCapture(video_path)
30
+ fps = cap.get(cv2.CAP_PROP_FPS)
31
+ frame_number = int(fps * second)
32
+
33
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
34
+ success, frame = cap.read()
35
+ cap.release()
36
+
37
+ if not success:
38
+ return None
39
+
40
+ tmp_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
41
+ cv2.imwrite(tmp_file.name, frame)
42
+ return tmp_file.name
43
+
44
+
45
+ def get_frame_times(duration, file_size_mb):
46
+ if duration <= 2:
47
+ return [1, 2]
48
+
49
+ elif duration <= 10:
50
+ return [2]
51
+
52
+ elif duration <= 15:
53
+ return [4, 9, 13]
54
+
55
+ if file_size_mb > 14:
56
+ return [4, 9, 13]
57
+
58
+ return [2]
59
+
60
+
61
+ def check_image_nsfw(image_path):
62
+ image = Image.open(image_path).convert("RGB")
63
+ result = classifier(image)
64
+
65
+ for r in result:
66
+ if r["label"] == "nsfw" and r["score"] > 0.5:
67
+ return True
68
+
69
+ return False
70
+
71
+
72
+ def check_video_nsfw(video_path):
73
+ size_mb = os.path.getsize(video_path) / (1024 * 1024)
74
+ duration = get_video_duration(video_path)
75
+
76
+ times = get_frame_times(duration, size_mb)
77
+
78
+ for t in times:
79
+ frame = extract_frame(video_path, t)
80
+ if frame:
81
+ if check_image_nsfw(frame):
82
+ return True # 🚨 return immediately if ANY frame is NSFW
83
+
84
+ return False