Upload 17 files
- app.py +195 -0
- classifiers.py +136 -0
- networks/drn.py +416 -0
- networks/drn_seg.py +95 -0
- pipeline.py +172 -0
- requirements.txt +40 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-312.pyc +0 -0
- utils/__pycache__/preprocessing.cpython-312.pyc +0 -0
- utils/__pycache__/tools.cpython-312.pyc +0 -0
- utils/__pycache__/visualization.cpython-312.pyc +0 -0
- utils/__pycache__/visualize.cpython-312.pyc +0 -0
- utils/download_weights.py +45 -0
- utils/preprocessing.py +98 -0
- utils/tools.py +143 -0
- utils/visualization.py +14 -0
- utils/visualize.py +61 -0
app.py
ADDED
@@ -0,0 +1,195 @@
import streamlit as st
from PIL import Image
from models import mesonet, mesoinception, fal_detector, local_detector
from utils.visualization import display_results
from utils.preprocessing import preprocess_image, preprocess_video, generate_local_image
import numpy as np
import cv2
import tempfile
import os

# Initialize session state for tabs and uploaded files
if "active_tab" not in st.session_state:
    st.session_state["active_tab"] = "Face Photoshop Detection"

if "uploaded_file" not in st.session_state:
    st.session_state["uploaded_file"] = None

# Load models
models = {
    "MesoNet": mesonet.load_mesonet("models/weights/Meso4_DF.h5"),
    "MesoInception": mesoinception.load_mesonetInception("models/weights/MesoInception_DF.h5"),
    "Photoshop FALdetector Global": fal_detector.load_fal_detector("models/weights/global.pth"),
    "Photoshop FALdetector Local": local_detector.load_local_detector("models/weights/local.pth", gpu_id=-1),
}


st.title("DeepSAIF")
# Create tabs for different functionalities
tab1, tab2, tab3 = st.tabs(["Face Photoshop Detection", "DeepFake Detection for Images", "DeepFake Detection for Videos"])

# Tab 1: Photoshop Detection
with tab1:
    if st.session_state["active_tab"] != "Face Photoshop Detection":
        st.session_state["uploaded_file"] = None
        st.session_state["active_tab"] = "Face Photoshop Detection"

    st.header("Face Photoshop Detection")
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png"], key="photoshop")

    if uploaded_file:
        st.session_state["uploaded_file"] = uploaded_file
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        local_image = generate_local_image(image)

        # Run inference on all models
        results = {}
        for model_name, model in models.items():
            if model_name == "Photoshop FALdetector Global":
                results[model_name] = fal_detector.predict_fal_detector(model, image)
            elif model_name == "Photoshop FALdetector Local":
                heatmap_path, prediction = local_detector.predict_and_generate_heatmap(model, image)
                if heatmap_path:
                    # Display the heatmap using Streamlit
                    st.image(heatmap_path, caption=f"Heatmap for {model_name}", use_container_width=True)

                    # Delete the temporary heatmap file after display
                    os.remove(heatmap_path)
                    os.remove('cropped_input.jpg')
                    os.remove('warped.jpg')
                else:
                    st.error(f"Failed to generate heatmap for {model_name}")
                results[model_name] = prediction
            # elif model_name == "Global Classifier":
            #     results[model_name] = global_classifier.classify_fake(model, image)

        # Display results
        display_results(results)

# Tab 2: DeepFake Detection for Images
with tab2:
    if st.session_state["active_tab"] != "DeepFake Detection for Images":
        st.session_state["uploaded_file"] = None
        st.session_state["active_tab"] = "DeepFake Detection for Images"

    st.header("DeepFake Detection for Images")

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png"], key="deepfake_image")

    if uploaded_file:
        st.session_state["uploaded_file"] = uploaded_file
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        local_image = generate_local_image(image)
        # Preprocess the image
        # preprocessed_image = preprocess_image(uploaded_file)

        # Run inference on all models
        results = {}
        for model_name, model in models.items():
            if model_name == "MesoNet":
                results[model_name] = mesonet.predict_mesonet(model, image)
            elif model_name == "MesoInception":
                results[model_name] = mesoinception.predict_mesonetInception(model, image)

        # Display results
        display_results(results)


def confident_strategy(pred, t=0.8):
    """
    Implements the confident averaging strategy for predictions.
    Args:
        pred (list[float]): List of predictions for each frame.
        t (float): Threshold for high-confidence fake detection.
    Returns:
        float: Final confidence score for the video.
    """
    if len(pred) == 0:
        return np.nan
    pred = np.array(pred)
    sz = len(pred)
    fakes = np.count_nonzero(pred > t)
    if fakes > sz // 2.5 and fakes > 11:
        return np.mean(pred[pred > t])
    elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
        return np.mean(pred[pred < 0.2])
    else:
        return np.mean(pred)

# Tab 3: DeepFake Detection for Videos
with tab3:
    if st.session_state["active_tab"] != "DeepFake Detection for Videos":
        st.session_state["uploaded_file"] = None
        st.session_state["active_tab"] = "DeepFake Detection for Videos"

    st.header("DeepFake Detection for Videos")
    uploaded_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"], key="deepfake_video")
    if uploaded_file:
        st.session_state["uploaded_file"] = uploaded_file
        with st.spinner("Processing video..."):
            # Save uploaded file to a temporary location
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_file.write(uploaded_file.getbuffer())
                video_path = temp_file.name

            try:
                # Test video accessibility
                cap = cv2.VideoCapture(video_path)
                if not cap.isOpened():
                    st.error("Failed to open video file.")
                else:
                    st.success("Video file opened successfully!")

                    # Extract frames from video
                    frames = preprocess_video(video_path, frame_count=32)
                    if len(frames) == 0:
                        st.error("Failed to extract frames from the video.")
                    else:
                        # st.success(f"Extracted {len(frames)} frames.")
                        # for frame in frames[:5]:  # Display first 5 frames
                        #     st.image(frame, caption="Extracted Frame")

                        # Dictionary to store model predictions
                        model_results = {
                            "MesoNet": [],
                            "Photoshop FALdetector Global": []
                        }

                        # Iterate over frames and make predictions for each model
                        for frame in frames:
                            preprocessed_frame = preprocess_image(frame)  # Preprocess frame
                            local_image = generate_local_image(preprocessed_frame)

                            # Predictions for MesoNet and Photoshop FALdetector Global
                            model_results["MesoNet"].append(
                                mesonet.predict_mesonet(models["MesoNet"], preprocessed_frame)
                            )
                            model_results["Photoshop FALdetector Global"].append(
                                fal_detector.predict_fal_detector(models["Photoshop FALdetector Global"], local_image)
                            )

                        # Apply the confident averaging strategy for each model
                        final_results = {}
                        for model_name, predictions in model_results.items():
                            final_results[model_name] = confident_strategy(predictions)

                        # Display results
                        st.write("### Video Analysis Results")
                        display_results(final_results)

                        # Optionally show detailed frame predictions per model
                        if st.checkbox("Show Detailed Frame Predictions"):
                            for model_name, predictions in model_results.items():
                                st.write(f"### Predictions for {model_name}")
                                st.bar_chart(predictions)

            finally:
                # Clean up temporary file
                os.remove(video_path)
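Note: app.py imports a `models` package (mesonet, mesoinception, fal_detector, local_detector) that is not part of this upload. Below is a minimal sketch of the loader/predictor interface those calls assume, built on the Meso4 class from classifiers.py; the function names match the calls in app.py, but the module body itself is an illustrative assumption, not the committed code.

# models/mesonet.py -- hypothetical wrapper matching the calls made in app.py
import numpy as np
from classifiers import Meso4

def load_mesonet(weights_path):
    """Build a Meso4 classifier and load pretrained weights (path layout assumed from app.py)."""
    model = Meso4()
    model.load(weights_path)
    return model

def predict_mesonet(model, image):
    """Resize a PIL image to 256x256, scale to [0, 1], and return a fake probability in percent."""
    arr = np.asarray(image.resize((256, 256)), dtype=np.float32) / 255.0
    prob = float(model.predict(np.expand_dims(arr, axis=0))[0][0])
    return round(prob * 100, 2)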
classifiers.py
ADDED
@@ -0,0 +1,136 @@
# -*- coding:utf-8 -*-

from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Dropout, Reshape, Concatenate, LeakyReLU
from tensorflow.keras.optimizers import Adam

IMGWIDTH = 256

class Classifier:
    def __init__(self):
        self.model = 0

    def predict(self, x):
        if x.size == 0:
            return []
        return self.model.predict(x)

    def fit(self, x, y):
        return self.model.train_on_batch(x, y)

    def get_accuracy(self, x, y):
        return self.model.test_on_batch(x, y)

    def load(self, path):
        self.model.load_weights(path)


class Meso1(Classifier):
    """
    Feature extraction + Classification
    """
    def __init__(self, learning_rate = 0.001, dl_rate = 1):
        self.model = self.init_model(dl_rate)
        optimizer = Adam(learning_rate = learning_rate)
        self.model.compile(optimizer = optimizer, loss = 'mean_squared_error', metrics = ['accuracy'])

    def init_model(self, dl_rate):
        x = Input(shape = (IMGWIDTH, IMGWIDTH, 3))

        x1 = Conv2D(16, (3, 3), dilation_rate = dl_rate, strides = 1, padding='same', activation = 'relu')(x)
        x1 = Conv2D(4, (1, 1), padding='same', activation = 'relu')(x1)
        x1 = BatchNormalization()(x1)
        x1 = MaxPooling2D(pool_size=(8, 8), padding='same')(x1)

        y = Flatten()(x1)
        y = Dropout(0.5)(y)
        y = Dense(1, activation = 'sigmoid')(y)
        return KerasModel(inputs = x, outputs = y)


class Meso4(Classifier):
    def __init__(self, learning_rate = 0.001):
        self.model = self.init_model()
        optimizer = Adam(learning_rate = learning_rate)
        self.model.compile(optimizer = optimizer, loss = 'mean_squared_error', metrics = ['accuracy'])

    def init_model(self):
        x = Input(shape = (IMGWIDTH, IMGWIDTH, 3))

        x1 = Conv2D(8, (3, 3), padding='same', activation = 'relu')(x)
        x1 = BatchNormalization()(x1)
        x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)

        x2 = Conv2D(8, (5, 5), padding='same', activation = 'relu')(x1)
        x2 = BatchNormalization()(x2)
        x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)

        x3 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x2)
        x3 = BatchNormalization()(x3)
        x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)

        x4 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x3)
        x4 = BatchNormalization()(x4)
        x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)

        y = Flatten()(x4)
        y = Dropout(0.5)(y)
        y = Dense(16)(y)
        y = LeakyReLU(negative_slope=0.1)(y)
        y = Dropout(0.5)(y)
        y = Dense(1, activation = 'sigmoid')(y)

        return KerasModel(inputs = x, outputs = y)


class MesoInception4(Classifier):
    def __init__(self, learning_rate = 0.001):
        self.model = self.init_model()
        optimizer = Adam(learning_rate = learning_rate)
        self.model.compile(optimizer = optimizer, loss = 'mean_squared_error', metrics = ['accuracy'])

    def InceptionLayer(self, a, b, c, d):
        def func(x):
            x1 = Conv2D(a, (1, 1), padding='same', activation='relu')(x)

            x2 = Conv2D(b, (1, 1), padding='same', activation='relu')(x)
            x2 = Conv2D(b, (3, 3), padding='same', activation='relu')(x2)

            x3 = Conv2D(c, (1, 1), padding='same', activation='relu')(x)
            x3 = Conv2D(c, (3, 3), dilation_rate = 2, strides = 1, padding='same', activation='relu')(x3)

            x4 = Conv2D(d, (1, 1), padding='same', activation='relu')(x)
            x4 = Conv2D(d, (3, 3), dilation_rate = 3, strides = 1, padding='same', activation='relu')(x4)

            y = Concatenate(axis = -1)([x1, x2, x3, x4])

            return y
        return func

    def init_model(self):
        x = Input(shape = (IMGWIDTH, IMGWIDTH, 3))

        x1 = self.InceptionLayer(1, 4, 4, 2)(x)
        x1 = BatchNormalization()(x1)
        x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)

        x2 = self.InceptionLayer(2, 4, 4, 2)(x1)
        x2 = BatchNormalization()(x2)
        x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)

        x3 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x2)
        x3 = BatchNormalization()(x3)
        x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)

        x4 = Conv2D(16, (5, 5), padding='same', activation = 'relu')(x3)
        x4 = BatchNormalization()(x4)
        x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)

        y = Flatten()(x4)
        y = Dropout(0.5)(y)
        y = Dense(16)(y)
        y = LeakyReLU(negative_slope=0.1)(y)
        y = Dropout(0.5)(y)
        y = Dense(1, activation = 'sigmoid')(y)

        return KerasModel(inputs = x, outputs = y)
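For reference, a minimal usage sketch of these classifiers; the weight path follows the layout assumed by app.py and utils/download_weights.py and is otherwise an assumption.

import numpy as np
from classifiers import Meso4

# Instantiate the network and load pretrained weights
# (path assumed from app.py; adjust to wherever the .h5 files actually live).
classifier = Meso4()
classifier.load('models/weights/Meso4_DF.h5')

# Predict on a batch of 256x256 RGB images scaled to [0, 1].
batch = np.random.rand(4, 256, 256, 3).astype(np.float32)
scores = classifier.predict(batch)  # shape (4, 1), sigmoid outputs in [0, 1]
print(scores.ravel())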
networks/drn.py
ADDED
@@ -0,0 +1,416 @@
import pdb

import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

torch.backends.cudnn.benchmark = True
BatchNorm = nn.BatchNorm2d


# __all__ = ['DRN', 'drn26', 'drn42', 'drn58']


webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/'

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
    'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
    'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
    'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
    'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
    'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
    'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
}


def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=padding, bias=False, dilation=dilation)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class DRN(nn.Module):

    def __init__(self, block, layers, num_classes=1000,
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 out_map=False, out_middle=False, pool_size=28, arch='D'):
        super(DRN, self).__init__()
        self.inplanes = channels[0]
        self.out_map = out_map
        self.out_dim = channels[-1]
        self.out_middle = out_middle
        self.arch = arch

        if arch == 'C':
            self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                                   padding=3, bias=False)
            self.bn1 = BatchNorm(channels[0])
            self.relu = nn.ReLU(inplace=True)

            self.layer1 = self._make_layer(
                BasicBlock, channels[0], layers[0], stride=1)
            self.layer2 = self._make_layer(
                BasicBlock, channels[1], layers[1], stride=2)
        elif arch == 'D':
            self.layer0 = nn.Sequential(
                nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
                          bias=False),
                BatchNorm(channels[0]),
                nn.ReLU(inplace=True)
            )

            self.layer1 = self._make_conv_layers(
                channels[0], layers[0], stride=1)
            self.layer2 = self._make_conv_layers(
                channels[1], layers[1], stride=2)

        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
        self.layer5 = self._make_layer(block, channels[4], layers[4],
                                       dilation=2, new_level=False)
        self.layer6 = None if layers[5] == 0 else \
            self._make_layer(block, channels[5], layers[5], dilation=4,
                             new_level=False)

        if arch == 'C':
            self.layer7 = None if layers[6] == 0 else \
                self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
                                 new_level=False, residual=False)
            self.layer8 = None if layers[7] == 0 else \
                self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
                                 new_level=False, residual=False)
        elif arch == 'D':
            self.layer7 = None if layers[6] == 0 else \
                self._make_conv_layers(channels[6], layers[6], dilation=2)
            self.layer8 = None if layers[7] == 0 else \
                self._make_conv_layers(channels[7], layers[7], dilation=1)

        if num_classes > 0:
            self.avgpool = nn.AvgPool2d(pool_size)
            self.fc = nn.Conv2d(self.out_dim, num_classes, kernel_size=1,
                                stride=1, padding=0, bias=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True):
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = list()
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation)))

        return nn.Sequential(*layers)

    def _make_conv_layers(self, channels, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(self.inplanes, channels, kernel_size=3,
                          stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(channels),
                nn.ReLU(inplace=True)])
            self.inplanes = channels
        return nn.Sequential(*modules)

    def forward(self, x):
        y = list()

        if self.arch == 'C':
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        elif self.arch == 'D':
            x = self.layer0(x)

        x = self.layer1(x)
        y.append(x)
        x = self.layer2(x)
        y.append(x)

        x = self.layer3(x)
        y.append(x)

        x = self.layer4(x)
        y.append(x)

        x = self.layer5(x)
        y.append(x)

        if self.layer6 is not None:
            x = self.layer6(x)
            y.append(x)

        if self.layer7 is not None:
            x = self.layer7(x)
            y.append(x)

        if self.layer8 is not None:
            x = self.layer8(x)
            y.append(x)

        if self.out_map:
            x = self.fc(x)
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)

        if self.out_middle:
            return x, y
        else:
            return x


class DRN_A(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(DRN_A, self).__init__()
        self.out_dim = 512 * block.expansion
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=4)
        self.avgpool = nn.AvgPool2d(28, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #     elif isinstance(m, nn.BatchNorm2d):
        #         nn.init.constant_(m.weight, 1)
        #         nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                dilation=(dilation, dilation)))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def drn_a_50(pretrained=False, **kwargs):
    model = DRN_A(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def drn_c_26(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-c-26']))
    return model


def drn_c_42(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-c-42']))
    return model


def drn_c_58(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-c-58']))
    return model


def drn_d_22(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-22']))
    return model


def drn_d_24(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-24']))
    return model


def drn_d_38(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-38']))
    return model


def drn_d_40(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-40']))
    return model


def drn_d_54(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-54']))
    return model


def drn_d_56(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-56']))
    return model


def drn_d_105(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-105']))
    return model


def drn_d_107(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-107']))
    return model
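A quick smoke test of the constructors above. With the default pool_size=28, the classification head expects 224x224 inputs; no pretrained weights are downloaded here.

import torch
from networks.drn import drn_c_26

# Build a DRN-C-26 with the default 1000-class head and run a dummy forward pass.
model = drn_c_26(pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])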
networks/drn_seg.py
ADDED
@@ -0,0 +1,95 @@
import math
import torch
import torch.nn as nn
from networks.drn import drn_c_26


def fill_up_weights(up):
    w = up.weight.data
    f = math.ceil(w.size(2) / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    for i in range(w.size(2)):
        for j in range(w.size(3)):
            w[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    for c in range(1, w.size(0)):
        w[c, 0, :, :] = w[0, 0, :, :]


class DRNSeg(nn.Module):
    def __init__(self, classes, pretrained_drn=False,
                 pretrained_model=None, use_torch_up=False):
        super(DRNSeg, self).__init__()

        model = drn_c_26(pretrained=pretrained_drn)
        self.base = nn.Sequential(*list(model.children())[:-2])
        if pretrained_model:
            self.load_pretrained(pretrained_model)

        self.seg = nn.Conv2d(model.out_dim, classes,
                             kernel_size=1, bias=True)

        m = self.seg
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        m.bias.data.zero_()
        if use_torch_up:
            self.up = nn.UpsamplingBilinear2d(scale_factor=8)
        else:
            up = nn.ConvTranspose2d(classes, classes, 16, stride=8, padding=4,
                                    output_padding=0, groups=classes,
                                    bias=False)
            fill_up_weights(up)
            up.weight.requires_grad = False
            self.up = up

    def forward(self, x):
        x = self.base(x)
        x = self.seg(x)
        y = self.up(x)
        return y

    def optim_parameters(self, memo=None):
        for param in self.base.parameters():
            yield param
        for param in self.seg.parameters():
            yield param

    def load_pretrained(self, pretrained_model):
        print("loading the pretrained drn model from %s" % pretrained_model)
        state_dict = torch.load(pretrained_model, map_location='cpu')
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        # filter out unnecessary keys
        pretrained_dict = state_dict['model']
        pretrained_dict = {k[5:]: v for k, v in pretrained_dict.items() if k.split('.')[0] == 'base'}

        # load the pretrained state dict
        self.base.load_state_dict(pretrained_dict)


class DRNSub(nn.Module):
    def __init__(self, num_classes, pretrained_model=None, fix_base=False):
        super(DRNSub, self).__init__()

        drnseg = DRNSeg(2)
        if pretrained_model:
            print("loading the pretrained drn model from %s" % pretrained_model)
            state_dict = torch.load(pretrained_model, map_location='cpu')
            drnseg.load_state_dict(state_dict['model'])

        self.base = drnseg.base
        if fix_base:
            for param in self.base.parameters():
                param.requires_grad = False

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.base(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
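DRNSub appears to be the image-level classifier used for the global FALdetector weights referenced in app.py and utils/download_weights.py; a minimal instantiation sketch (the checkpoint argument is optional and omitted here):

import torch
from networks.drn_seg import DRNSeg, DRNSub

# Binary real/fake classifier built on the DRN-C-26 backbone.
classifier = DRNSub(num_classes=2)
classifier.eval()
with torch.no_grad():
    out = classifier(torch.randn(1, 3, 400, 400))
print(out.shape)  # torch.Size([1, 2])

# Per-pixel head: 2 output channels, upsampled 8x by the transposed convolution.
seg = DRNSeg(2)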
pipeline.py
ADDED
@@ -0,0 +1,172 @@
from os import listdir
from os.path import isfile, join

import numpy as np
from math import floor
from scipy.ndimage import zoom, rotate

import imageio
import cv2


## Face extraction

class Video:
    def __init__(self, path):
        self.path = path
        self.container = imageio.get_reader(path, 'ffmpeg')
        self.length = self.container.count_frames()
        self.fps = self.container.get_meta_data()['fps']

    def init_head(self):
        self.container.set_image_index(0)

    def next_frame(self):
        self.container.get_next_data()

    def get(self, key):
        return self.container.get_data(key)

    def __call__(self, key):
        return self.get(key)

    def __len__(self):
        return self.length


class FaceFinder(Video):
    def __init__(self, path, load_first_face=True):
        super().__init__(path)
        self.faces = {}
        self.coordinates = {}  # stores the face (locations center, rotation, length)
        self.last_frame = self.get(0)
        self.frame_shape = self.last_frame.shape[:2]
        self.last_location = (0, 200, 200, 0)

        # Initialize OpenCV's Haar Cascade for face detection
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")

        if load_first_face:
            face_positions = self.detect_faces(self.last_frame)
            if len(face_positions) > 0:
                self.last_location = self.expand_location_zone(face_positions[0])

    def detect_faces(self, frame, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)):
        """Detect faces using Haar Cascade."""
        gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        faces = self.face_cascade.detectMultiScale(gray_frame, scaleFactor=scaleFactor, minNeighbors=minNeighbors, minSize=minSize)
        return faces

    def expand_location_zone(self, loc, margin=0.2):
        """Adds a margin around a frame slice."""
        x, y, w, h = loc
        offset_x = round(margin * w)
        offset_y = round(margin * h)
        y0 = max(y - offset_y, 0)
        x1 = min(x + w + offset_x, self.frame_shape[1])
        y1 = min(y + h + offset_y, self.frame_shape[0])
        x0 = max(x - offset_x, 0)
        return (y0, x1, y1, x0)

    def find_faces(self, resize=0.5, stop=0, skipstep=0, cut_left=0, cut_right=-1):
        """The core function to extract faces from frames."""
        # Frame iteration setup
        if stop != 0:
            finder_frameset = range(0, min(self.length, stop), skipstep + 1)
        else:
            finder_frameset = range(0, self.length, skipstep + 1)

        # Loop through frames
        for i in finder_frameset:
            frame = self.get(i)
            if cut_left != 0 or cut_right != -1:
                frame[:, :cut_left] = 0
                frame[:, cut_right:] = 0

            # Detect faces in the current frame
            face_positions = self.detect_faces(frame)
            if len(face_positions) > 0:
                # Use the largest detected face
                largest_face = max(face_positions, key=lambda f: f[2] * f[3])
                self.faces[i] = self.expand_location_zone(largest_face)
                self.last_location = self.faces[i]
            else:
                print(f"No face detected in frame {i}")

        print(f"Face extraction completed: {len(self.faces)} faces detected.")

    def get_face(self, i):
        """Extract the face region for the given frame index."""
        frame = self.get(i)
        if i in self.faces:
            y0, x1, y1, x0 = self.faces[i]
            return frame[y0:y1, x0:x1]
        return frame

## Face prediction

class FaceBatchGenerator:
    '''
    Made to deal with framesubsets of video.
    '''
    def __init__(self, face_finder, target_size = 256):
        self.finder = face_finder
        self.target_size = target_size
        self.head = 0
        self.length = int(face_finder.length)

    def resize_patch(self, patch):
        m, n = patch.shape[:2]
        return zoom(patch, (self.target_size / m, self.target_size / n, 1))

    def next_batch(self, batch_size = 50):
        batch = np.zeros((1, self.target_size, self.target_size, 3))
        stop = min(self.head + batch_size, self.length)
        i = 0
        while (i < batch_size) and (self.head < self.length):
            # Use the Haar-cascade face crops stored in finder.faces
            # (the refactored FaceFinder no longer provides aligned faces).
            if self.head in self.finder.faces:
                patch = self.finder.get_face(self.head)
                batch = np.concatenate((batch, np.expand_dims(self.resize_patch(patch), axis = 0)),
                                       axis = 0)
                i += 1
            self.head += 1
        return batch[1:]


def predict_faces(generator, classifier, batch_size = 50, output_size = 1):
    '''
    Compute predictions for a face batch generator
    '''
    n = len(generator.finder.faces)
    profile = np.zeros((1, output_size))
    for epoch in range(n // batch_size + 1):
        face_batch = generator.next_batch(batch_size = batch_size)
        prediction = classifier.predict(face_batch)
        if (len(prediction) > 0):
            profile = np.concatenate((profile, prediction))
    return profile[1:]


def compute_accuracy(classifier, dirname, frame_subsample_count = 30):
    '''
    Extraction + Prediction over a video
    '''
    filenames = [f for f in listdir(dirname) if isfile(join(dirname, f)) and ((f[-4:] == '.mp4') or (f[-4:] == '.avi') or (f[-4:] == '.mov'))]
    predictions = {}

    for vid in filenames:
        print('Dealing with video ', vid)

        # Compute face locations and store them in the face finder
        face_finder = FaceFinder(join(dirname, vid), load_first_face = False)
        skipstep = max(floor(face_finder.length / frame_subsample_count), 0)
        face_finder.find_faces(resize=0.5, skipstep = skipstep)

        print('Predicting ', vid)
        gen = FaceBatchGenerator(face_finder)
        p = predict_faces(gen, classifier)

        predictions[vid[:-4]] = (np.mean(p > 0.5), p)
    return predictions
requirements.txt
ADDED
@@ -0,0 +1,40 @@
absl-py
altair
beautifulsoup4
cachetools
certifi
charset-normalizer
click
decorator
ffmpeg
flatbuffers
fsspec
gdown
gitpython
grpcio
h5py
huggingface-hub
idna
jinja2
jsonschema
keras
matplotlib
numpy
opencv-python
packaging
pandas
pillow
protobuf
pytz
PyYAML
requests
scipy
streamlit
tensorboard
tensorflow
torch
torchaudio
torchvision
tqdm
typing_extensions
urllib3
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (143 Bytes)

utils/__pycache__/preprocessing.cpython-312.pyc
ADDED
Binary file (4.33 kB)

utils/__pycache__/tools.cpython-312.pyc
ADDED
Binary file (6.62 kB)

utils/__pycache__/visualization.cpython-312.pyc
ADDED
Binary file (1.07 kB)

utils/__pycache__/visualize.cpython-312.pyc
ADDED
Binary file (3.81 kB)
utils/download_weights.py
ADDED
@@ -0,0 +1,45 @@
import os
import gdown

def download_weights():
    """
    Downloads the required model weights into the 'models/weights/' directory.
    """

    # Directory for storing weights
    output_dir = "models/weights"
    os.makedirs(output_dir, exist_ok=True)

    # URLs for the weights
    weights = {
        "MesoNet": {
            "Meso4_DF": "https://github.com/DariusAf/MesoNet/raw/master/weights/Meso4_DF.h5",
            "MesoInception_DF": "https://github.com/DariusAf/MesoNet/raw/master/weights/MesoInception_DF.h5",
        },
        "EfficientNet (DFDC)": {
            "EfficientNet-B0": "https://drive.google.com/uc?id=1LqRbCDNf9Ob7DFexCtE230FW6hhtLw0M",
        },
        "FALdetector": {
            "global": "https://www.dropbox.com/s/rb8zpvrbxbbutxc/global.pth?dl=0",
            "local": "https://www.dropbox.com/s/pby9dhpr6cqziyl/local.pth?dl=0",
        },

        "Vision Transformer (CViT)": {
            "CViT": "https://github.com/erprogs/CViT/blob/main/weight/deepdeepfake_cvit_gpu_ep50.pkl",
        },
    }

    # Download each weight file
    for model_name, files in weights.items():
        print(f"Downloading weights for {model_name}...")
        for weight_name, url in files.items():
            # Keep the original file extension (.h5, .pth, .pkl) so the saved
            # names match what app.py loads; fall back to .pth when the URL
            # does not expose one (e.g. Google Drive links).
            ext = os.path.splitext(url.split("?")[0])[1] or ".pth"
            output_path = os.path.join(output_dir, f"{weight_name}{ext}")
            if not os.path.exists(output_path):
                print(f" - Downloading {weight_name}...")
                gdown.download(url, output_path, quiet=False)
            else:
                print(f" - {weight_name} already exists. Skipping download.")
    print("All weights downloaded successfully.")

if __name__ == "__main__":
    download_weights()
utils/preprocessing.py
ADDED
@@ -0,0 +1,98 @@
import cv2
from PIL import Image
import numpy as np


def preprocess_image(image):
    """
    Preprocesses an input image for prediction.

    Args:
        image (Union[str, numpy.ndarray]): File path to an image or a numpy array.

    Returns:
        PIL.Image: Preprocessed image.
    """
    if isinstance(image, str):  # Handle file path
        image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):  # Handle numpy array
        image = Image.fromarray(image).convert("RGB")
    else:
        raise ValueError("Unsupported image type. Must be a file path or numpy array.")
    return image


def resize_shorter_side(img, min_length):
    """
    Resize the shorter side of img to min_length while
    preserving the aspect ratio.
    """
    ow, oh = img.size
    mult = 8
    if ow < oh:
        if ow == min_length and oh % mult == 0:
            return img, (ow, oh)
        w = min_length
        h = int(min_length * oh / ow)
    else:
        if oh == min_length and ow % mult == 0:
            return img, (ow, oh)
        h = min_length
        w = int(min_length * ow / oh)
    return img.resize((w, h), Image.BICUBIC), (w, h)


def generate_local_image(image):
    """
    Detects the face in the input image and extracts it as a 'local image'.
    If no face is detected, returns the global image as the local image.

    Args:
        image (Union[PIL.Image, numpy.ndarray]): The input image.

    Returns:
        PIL.Image: The cropped face or the original image if no face is detected.
    """
    # Convert numpy array to PIL.Image if necessary
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Convert PIL image to OpenCV format for face detection
    image_cv = np.array(image)
    image_gray = cv2.cvtColor(image_cv, cv2.COLOR_RGB2GRAY)

    # Load OpenCV's pre-trained Haar Cascade for face detection
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
    faces = face_cascade.detectMultiScale(image_gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

    if len(faces) == 0:
        print("No face detected. Using the global image as the local image.")
        return image  # Return the global image as fallback

    # Use the first detected face
    x, y, w, h = faces[0]

    # Crop the face region
    face_image = image.crop((x, y, x + w, y + h))
    return face_image


def preprocess_video(video_path, frame_count=32):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []  # Return an empty list if video can't be opened

    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames == 0:
        return []  # Handle videos with no frames

    interval = max(1, total_frames // frame_count)
    for i in range(frame_count):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
        ret, frame = cap.read()
        if ret:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames
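A short usage sketch of these helpers; the input file paths are placeholders.

from PIL import Image
from utils.preprocessing import preprocess_image, generate_local_image, preprocess_video

# Crop a face region out of a still image; falls back to the full image
# when no face is found.
image = Image.open('example.jpg').convert('RGB')  # hypothetical input file
face = generate_local_image(image)

# Sample up to 32 evenly spaced RGB frames from a video, then wrap each as a PIL image.
frames = preprocess_video('example.mp4', frame_count=32)  # hypothetical input file
pil_frames = [preprocess_image(frame) for frame in frames]
print(len(pil_frames))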
utils/tools.py
ADDED
@@ -0,0 +1,143 @@
import os
import cv2
import torch
import numpy as np
from PIL import Image
# from dlib import cnn_face_detection_model_v1 as face_detect_model
from utils.preprocessing import generate_local_image as face_detect_model


def center_crop(im, length):
    w, h = im.size
    left = w//2 - length//2
    right = w//2 + length//2
    top = h//2 - length//2
    bottom = h//2 + length//2
    return im.crop((left, top, right, bottom)), (left, top)


def remove_boundary(img):
    """
    Remove boundary artifacts that FAL causes.
    """
    w, h = img.size
    left = w//80
    top = h//50
    right = w*79//80
    bottom = h*24//25
    return img.crop((left, top, right, bottom))


def resize_shorter_side(img, min_length):
    """
    Resize the shorter side of img to min_length while
    preserving the aspect ratio.
    """
    ow, oh = img.size
    mult = 8
    if ow < oh:
        if ow == min_length and oh % mult == 0:
            return img, (ow, oh)
        w = min_length
        h = int(min_length * oh / ow)
    else:
        if oh == min_length and ow % mult == 0:
            return img, (ow, oh)
        h = min_length
        w = int(min_length * ow / oh)
    return img.resize((w, h), Image.BICUBIC), (w, h)


def flow_resize(flow, sz):
    oh, ow, _ = flow.shape
    w, h = sz
    u_ = cv2.resize(flow[:,:,0], (w, h))
    v_ = cv2.resize(flow[:,:,1], (w, h))
    u_ *= w / float(ow)
    v_ *= h / float(oh)
    return np.dstack((u_,v_))


def warp(im, flow, alpha=1, interp=cv2.INTER_CUBIC):
    height, width, _ = flow.shape
    cart = np.dstack(np.meshgrid(np.arange(width), np.arange(height)))
    pixel_map = (cart + alpha * flow).astype(np.float32)
    warped = cv2.remap(
        im,
        pixel_map[:, :, 0],
        pixel_map[:, :, 1],
        interp,
        borderMode=cv2.BORDER_REPLICATE)
    return warped


cnn_face_detector = None
def face_detection(
        img_path,
        verbose=False,
        model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):
    """
    Detects faces using dlib cnn face detection, and extend the bounding box
    to include the entire face.
    """
    def shrink(img, max_length=2048):
        ow, oh = img.size
        if max_length >= max(ow, oh):
            return img, 1.0

        if ow > oh:
            mult = max_length / ow
        else:
            mult = max_length / oh
        w = int(ow * mult)
        h = int(oh * mult)
        return img.resize((w, h), Image.BILINEAR), mult

    global cnn_face_detector
    if cnn_face_detector is None:
        cnn_face_detector = face_detect_model(model_file)

    img = Image.open(img_path).convert('RGB')
    w, h = img.size
    img_shrinked, mult = shrink(img)

    im = np.asarray(img_shrinked)
    if len(im.shape) != 3 or im.shape[2] != 3:
        return []

    crop_ims = []
    dets = cnn_face_detector(im, 0)
    for k, d in enumerate(dets):
        top = d.rect.top() / mult
        bottom = d.rect.bottom() / mult
        left = d.rect.left() / mult
        right = d.rect.right() / mult

        wid = right - left
        left = max(0, left - wid // 2.5)
        top = max(0, top - wid // 1.5)
        right = min(w - 1, right + wid // 2.5)
        bottom = min(h - 1, bottom + wid // 2.5)

        if d.confidence > 1:
            if verbose:
                print("%d-th face detected: (%d, %d, %d, %d)" %
                      (k, left, top, right, bottom))
            crop_im = img.crop((left, top, right, bottom))
            crop_ims.append((crop_im, (left, top, right, bottom)))

    return crop_ims


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)
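A small sketch of warping an image with a dense flow field using the helpers above; the inputs are dummy arrays.

import numpy as np
from utils.tools import warp, flow_resize

# A dummy 2-channel flow field: a constant horizontal displacement of 3 px.
img = np.zeros((120, 160, 3), dtype=np.uint8)
flow = np.zeros((120, 160, 2), dtype=np.float32)
flow[..., 0] = 3.0

warped = warp(img, flow)             # remap the image along the flow field
half = flow_resize(flow, (80, 60))   # resize the flow to 80x60 and rescale its magnitudes
print(warped.shape, half.shape)      # (120, 160, 3) (60, 80, 2)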
utils/visualization.py
ADDED
@@ -0,0 +1,14 @@
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st

def display_results(results):
    st.write("### Detection Results")
    df = pd.DataFrame(results.items(), columns=["Model", "Probability (%)"])
    st.table(df)

    st.write("### Visualization")
    fig, ax = plt.subplots()
    df.plot.bar(x="Model", y="Probability (%)", ax=ax, legend=False)
    ax.set_ylabel("Probability (%)")
    st.pyplot(fig)
utils/visualize.py
ADDED
@@ -0,0 +1,61 @@
import os
import cv2
import torch
import numpy as np
import torchvision
from PIL import Image


def unnormalize(tens, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    # assume tensor of shape NxCxHxW
    return tens * torch.Tensor(std)[None, :, None, None] + torch.Tensor(
        mean)[None, :, None, None]


def get_heatmap_cv(img, magn, max_flow_mag):
    min_flow_mag = .5
    cv_magn = np.clip(
        255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag),
        a_min=0,
        a_max=255).astype(np.uint8)
    if img.dtype != np.uint8:
        img = (255 * img).astype(np.uint8)

    heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET)
    heatmap_img = heatmap_img[..., ::-1]

    h, w = magn.shape
    img_alpha = np.ones((h, w), dtype=np.double)[:, :, None]
    heatmap_alpha = np.clip(
        magn / max_flow_mag, a_min=0, a_max=1)[:, :, None]**.7
    heatmap_alpha[heatmap_alpha < .2]**.5
    pm_hm = heatmap_img * heatmap_alpha
    pm_img = img * img_alpha
    cv_out = pm_hm + pm_img * (1 - heatmap_alpha)
    cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8)

    return cv_out


def get_heatmap_batch(img_batch, pred_batch):
    imgrid = torchvision.utils.make_grid(img_batch).cpu()
    magn_batch = torch.norm(pred_batch, p=2, dim=1, keepdim=True)
    magngrid = torchvision.utils.make_grid(magn_batch)
    magngrid = magngrid[0, :, :]
    imgrid = unnormalize(imgrid).squeeze_()

    cv_magn = magngrid.detach().cpu().numpy()
    cv_img = imgrid.permute(1, 2, 0).detach().cpu().numpy()
    cv_out = get_heatmap_cv(cv_img, cv_magn, max_flow_mag=9)
    out = np.asarray(cv_out).astype(np.double) / 255.0

    out = torch.from_numpy(out).permute(2, 0, 1)
    return out


def save_heatmap_cv(img, magn, path, max_flow_mag=7):
    cv_out = get_heatmap_cv(img, magn, max_flow_mag)
    out = Image.fromarray(cv_out)
    out.save(path, quality=95)