ImjerryCo commited on
Commit
da8f393
·
verified ·
1 Parent(s): d2904be

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +135 -0
utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import tempfile
3
+ import requests
4
+ import os
5
+ from PIL import Image
6
+ from transformers import pipeline
7
+ import torch
8
+
9
+ # 🔥 SPEED BOOST SETTINGS
10
+ torch.set_grad_enabled(False)
11
+ torch.set_num_threads(2)
12
+
13
+ # 🔥 Faster NSFW model
14
+ classifier = pipeline(
15
+ "image-classification",
16
+ model="AdamCodd/vit-base-nsfw-detector",
17
+ device=-1 # CPU
18
+ )
19
+
20
+ # -----------------------------
21
+ # Download with retry + headers (FIX CATBOX)
22
+ # -----------------------------
23
+ def download_file(url):
24
+ headers = {
25
+ "User-Agent": "Mozilla/5.0",
26
+ "Accept": "*/*",
27
+ "Connection": "keep-alive",
28
+ "Range": "bytes=0-"
29
+ }
30
+
31
+ for _ in range(3): # retry
32
+ try:
33
+ response = requests.get(
34
+ url,
35
+ headers=headers,
36
+ stream=True,
37
+ timeout=10
38
+ )
39
+
40
+ if response.status_code != 200:
41
+ continue
42
+
43
+ tmp = tempfile.NamedTemporaryFile(delete=False)
44
+
45
+ for chunk in response.iter_content(1024 * 1024):
46
+ if chunk:
47
+ tmp.write(chunk)
48
+
49
+ tmp.close()
50
+ return tmp.name
51
+
52
+ except requests.exceptions.RequestException:
53
+ continue
54
+
55
+ raise Exception("Failed to fetch file")
56
+
57
+
58
+ # -----------------------------
59
+ # Video duration
60
+ # -----------------------------
61
+ def get_video_duration(video_path):
62
+ cap = cv2.VideoCapture(video_path)
63
+ fps = cap.get(cv2.CAP_PROP_FPS)
64
+ frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
65
+ cap.release()
66
+ return frames / fps if fps > 0 else 0
67
+
68
+
69
+ # -----------------------------
70
+ # Extract frame
71
+ # -----------------------------
72
+ def extract_frame(video_path, second):
73
+ cap = cv2.VideoCapture(video_path)
74
+ fps = cap.get(cv2.CAP_PROP_FPS)
75
+
76
+ frame_no = int(fps * second)
77
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
78
+
79
+ success, frame = cap.read()
80
+ cap.release()
81
+
82
+ if not success:
83
+ return None
84
+
85
+ tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
86
+ cv2.imwrite(tmp.name, frame)
87
+ return tmp.name
88
+
89
+
90
+ # -----------------------------
91
+ # FAST frame selection
92
+ # -----------------------------
93
+ def get_frame_times(duration):
94
+ if duration <= 3:
95
+ return [1]
96
+
97
+ elif duration <= 10:
98
+ return [2]
99
+
100
+ else:
101
+ return [3, 8] # max 2 frames (FAST)
102
+
103
+
104
+ # -----------------------------
105
+ # Image NSFW check (OPTIMIZED)
106
+ # -----------------------------
107
+ def check_image_nsfw(image_path):
108
+ img = Image.open(image_path).convert("RGB")
109
+
110
+ # 🔥 Resize = BIG SPEED BOOST
111
+ img = img.resize((224, 224))
112
+
113
+ result = classifier(img)
114
+
115
+ for r in result:
116
+ if r["label"].lower() == "nsfw" and r["score"] > 0.5:
117
+ return True
118
+
119
+ return False
120
+
121
+
122
+ # -----------------------------
123
+ # Video NSFW check
124
+ # -----------------------------
125
+ def check_video_nsfw(video_path):
126
+ duration = get_video_duration(video_path)
127
+ times = get_frame_times(duration)
128
+
129
+ for t in times:
130
+ frame = extract_frame(video_path, t)
131
+ if frame:
132
+ if check_image_nsfw(frame):
133
+ return True # 🚨 stop early
134
+
135
+ return False