LEGIONM36 committed
Commit d31a75f · verified · 1 Parent(s): 84e0164

Upload 4 files

Files changed (4)
  1. best_model_fusion.pth +3 -0
  2. model.py +92 -0
  3. readme.md +20 -0
  4. train.py +230 -0
best_model_fusion.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:594c76c254dd74e4ce7bb8e051394c9991c40ceaf637570b9c6de9d4f9482134
+ size 139372555
model.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models.video as models
+
+ class TimeSformerBlock(nn.Module):
+     def __init__(self, dim, num_heads, num_frames):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn_time = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+         self.norm2 = nn.LayerNorm(dim)
+         self.attn_space = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+         self.norm3 = nn.LayerNorm(dim)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, dim * 4),
+             nn.GELU(),
+             nn.Linear(dim * 4, dim)
+         )
+         self.num_frames = num_frames
+
+     def forward(self, x):
+         B, TP, D = x.shape
+         T = self.num_frames
+         P = TP // T
+
+         # Temporal Attention
+         xt = x.view(B, T, P, D).permute(0, 2, 1, 3).reshape(B * P, T, D)
+         xt_res = xt
+         xt = self.norm1(xt)
+         xt, _ = self.attn_time(xt, xt, xt)
+         xt = xt + xt_res
+         x = xt.view(B, P, T, D).permute(0, 2, 1, 3).reshape(B, TP, D)
+
+         # Spatial Attention
+         xs = x.view(B, T, P, D).reshape(B * T, P, D)
+         xs_res = xs
+         xs = self.norm2(xs)
+         xs, _ = self.attn_space(xs, xs, xs)
+         xs = xs + xs_res
+         x = xs.view(B, T, P, D).reshape(B, TP, D)
+
+         x = x + self.mlp(self.norm3(x))
+         return x
+
+ class FeatureFusionNetwork(nn.Module):
+     def __init__(self):
+         super(FeatureFusionNetwork, self).__init__()
+
+         # Branch 1: Backbone CNN (ResNet3D)
+         self.cnn = models.r3d_18(weights=None)
+         self.cnn.fc = nn.Identity()  # Output 512
+
+         # Branch 2: TimeSformer Backbone
+         self.patch_size = 16
+         self.embed_dim = 256
+         self.img_size = 112
+         self.num_patches = (self.img_size // self.patch_size) ** 2
+         self.num_frames = 16  # Default SEQ_LEN
+
+         self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, self.num_frames * self.num_patches + 1, self.embed_dim))
+
+         self.transformer_layer = TimeSformerBlock(self.embed_dim, num_heads=4, num_frames=self.num_frames)
+
+         self.fusion_fc = nn.Sequential(
+             nn.Linear(512 + self.embed_dim, 256),
+             nn.ReLU(),
+             nn.Dropout(0.5),
+             nn.Linear(256, 2)
+         )
+
+     def forward(self, x):
+         # CNN Pathway
+         cnn_feat = self.cnn(x)  # (B, 512)
+
+         # Transformer Pathway
+         b, c, t, h, w = x.shape
+         x_uv = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+         patches = self.patch_embed(x_uv).flatten(2).transpose(1, 2)
+         patches = patches.reshape(b, t * self.num_patches, self.embed_dim)
+
+         cls_tokens = self.cls_token.expand(b, -1, -1)
+         x_trans = torch.cat((cls_tokens, patches), dim=1)
+         x_trans = x_trans + self.pos_embed[:, :x_trans.size(1), :]
+
+         patch_tokens = x_trans[:, 1:, :]
+         out_patches = self.transformer_layer(patch_tokens)
+         trans_feat = out_patches.mean(dim=1)  # (B, D)
+
+         combined = torch.cat((cnn_feat, trans_feat), dim=1)
+         out = self.fusion_fc(combined)
+         return out
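A quick shape check (a sketch, not part of the commit) confirms the two pathways line up with the defaults above: 16 frames at 112×112 with 16×16 patches, 512 CNN features plus 256 transformer features, and two logits per clip from the head.

```python
# Sketch: smoke-test FeatureFusionNetwork on random data (assumes model.py is importable).
import torch
from model import FeatureFusionNetwork

model = FeatureFusionNetwork().eval()
clip = torch.randn(2, 3, 16, 112, 112)  # (batch, channels, frames, height, width)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # torch.Size([2, 2]) -> one (no-violence, violence) logit pair per clip
```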
readme.md ADDED
@@ -0,0 +1,20 @@
+ # Feature Fusion Network
+
+ ## Model Architecture
+ - **Type**: Multi-Modal Hybrid (CNN + Transformer)
+ - **Pathway 1 (Spatial)**: ResNet3D (r3d_18) for robust localized feature extraction.
+ - **Pathway 2 (Spatiotemporal)**: TimeSformer-style Transformer block attending over patch and frame tokens to capture long-range dependencies.
+ - **Fusion**: Late fusion via concatenation of the two feature vectors (512 features from the CNN + 256 from the Transformer).
+ - **Classification Head**: MLP mapping the fused features to two classes (violence / no-violence).
+
+ ## Dataset Structure
+ Expects a `Dataset` folder in the parent directory:
+ ```
+ Dataset/
+ ├── violence/
+ └── no-violence/
+ ```
+
+ ## How to Run
+ 1. Install dependencies: `torch`, `opencv-python`, `scikit-learn`, `numpy`, `torchvision`.
+ 2. Run `python train.py`.
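For inference with the uploaded checkpoint, a minimal sketch (not part of the commit) is shown below. It assumes a clip preprocessed as in `train.py` (16 frames, 112×112, scaled to [0, 1], laid out as (C, T, H, W)) and that `model.py` is importable, since `best_model_fusion.pth` stores the full module; recent PyTorch versions also need `weights_only=False` for full-module checkpoints.

```python
# Sketch: classify one preprocessed clip with the uploaded checkpoint.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Full-module checkpoint: requires weights_only=False on PyTorch >= 2.6 and model.py on the path.
model = torch.load("best_model_fusion.pth", map_location=device, weights_only=False)
model.eval()

clip = torch.randn(1, 3, 16, 112, 112).to(device)  # placeholder for a real preprocessed clip
with torch.no_grad():
    pred = model(clip).argmax(dim=1).item()
print("Violence" if pred == 1 else "No Violence")  # label 1 = violence, 0 = no-violence (as in train.py)
```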
train.py ADDED
@@ -0,0 +1,230 @@
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
+ import torchvision.models.video as models
+ import time
+ from model import FeatureFusionNetwork
+
+ # --- Configuration ---
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ DATASET_DIR = os.path.join(BASE_DIR, "Dataset")
+ MODEL_SAVE_PATH = "best_model_fusion.pth"
+
+ IMG_SIZE = 112
+ SEQ_LEN = 16
+ BATCH_SIZE = 16
+ EPOCHS = 80
+ LEARNING_RATE = 1e-4
+ PATIENCE = 5
+
+ # --- Dataset ---
+ class StandardDataset(Dataset):
+     def __init__(self, video_paths, labels):
+         self.video_paths = video_paths
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.video_paths)
+
+     def __getitem__(self, idx):
+         path = self.video_paths[idx]
+         label = self.labels[idx]
+
+         cap = cv2.VideoCapture(path)
+         frames = []
+         try:
+             while True:
+                 ret, frame = cap.read()
+                 if not ret: break
+                 frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE))
+                 frames.append(frame)
+         finally:
+             cap.release()
+
+         if len(frames) == 0:
+             frames = np.zeros((SEQ_LEN, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
+         elif len(frames) < SEQ_LEN:
+             while len(frames) < SEQ_LEN: frames.append(frames[-1])
+         elif len(frames) > SEQ_LEN:
+             indices = np.linspace(0, len(frames) - 1, SEQ_LEN, dtype=int)
+             frames = [frames[i] for i in indices]
+
+         frames = np.array(frames, dtype=np.float32) / 255.0
+         # (T, H, W, C) -> (C, T, H, W)
+         frames = torch.tensor(frames).permute(3, 0, 1, 2)
+         return frames, label
+
+ # --- Data Preparation ---
+ def prepare_data():
+     violence_dir = os.path.join(DATASET_DIR, 'violence')
+     no_violence_dir = os.path.join(DATASET_DIR, 'no-violence')
+
+     if not os.path.exists(violence_dir) or not os.path.exists(no_violence_dir):
+         raise FileNotFoundError("Dataset directories not found.")
+
+     violence_files = [os.path.join(violence_dir, f) for f in os.listdir(violence_dir) if f.endswith('.avi') or f.endswith('.mp4')]
+     no_violence_files = [os.path.join(no_violence_dir, f) for f in os.listdir(no_violence_dir) if f.endswith('.avi') or f.endswith('.mp4')]
+
+     X = violence_files + no_violence_files
+     y = [1] * len(violence_files) + [0] * len(no_violence_files)
+
+     X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=42, stratify=y)
+     X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp)
+
+     return (X_train, y_train), (X_val, y_val), (X_test, y_test)
+
+ # --- Early Stopping ---
+ class EarlyStopping:
+     def __init__(self, patience=5, verbose=False, path='checkpoint.pth'):
+         self.patience = patience
+         self.verbose = verbose
+         self.counter = 0
+         self.best_score = None
+         self.early_stop = False
+         self.val_loss_min = np.inf
+         self.path = path
+
+     def __call__(self, val_loss, model):
+         score = -val_loss
+         if self.best_score is None:
+             self.best_score = score
+             self.save_checkpoint(val_loss, model)
+         elif score < self.best_score:
+             self.counter += 1
+             if self.verbose:
+                 print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.save_checkpoint(val_loss, model)
+             self.counter = 0
+
+     def save_checkpoint(self, val_loss, model):
+         if self.verbose:
+             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
+         torch.save(model, self.path)  # FULL MODEL SAVE
+         self.val_loss_min = val_loss
+
+ if __name__ == "__main__":
+     start_time = time.time()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     try:
+         (X_train, y_train), (X_val, y_val), (X_test, y_test) = prepare_data()
+         print(f"Dataset Split Stats:")
+         print(f"Train: {len(X_train)} samples")
+         print(f"Val: {len(X_val)} samples")
+         print(f"Test: {len(X_test)} samples")
+     except Exception as e:
+         print(f"Data preparation failed: {e}")
+         exit(1)
+
+     train_dataset = StandardDataset(X_train, y_train)
+     val_dataset = StandardDataset(X_val, y_val)
+     test_dataset = StandardDataset(X_test, y_test)
+
+     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
+     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
+     test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
+
+     model = FeatureFusionNetwork().to(device)
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
+
+     early_stopping = EarlyStopping(patience=PATIENCE, verbose=True, path=MODEL_SAVE_PATH)
+
+     print("\nStarting Feature Fusion Training...")
+
+     for epoch in range(EPOCHS):
+         model.train()
+         train_loss = 0.0
+         correct = 0
+         total = 0
+
+         for batch_idx, (inputs, labels) in enumerate(train_loader):
+             inputs, labels = inputs.to(device), labels.to(device)
+
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.item()
+             _, predicted = torch.max(outputs.data, 1)
+             total += labels.size(0)
+             correct += (predicted == labels).sum().item()
+
+             if batch_idx % 10 == 0:
+                 print(f"Epoch {epoch+1} Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.4f}", end='\r')
+
+         train_acc = 100 * correct / total
+         avg_train_loss = train_loss / len(train_loader)
+
+         model.eval()
+         val_loss = 0.0
+         correct_val = 0
+         total_val = 0
+
+         with torch.no_grad():
+             for inputs, labels in val_loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = model(inputs)
+                 loss = criterion(outputs, labels)
+                 val_loss += loss.item()
+                 _, predicted = torch.max(outputs.data, 1)
+                 total_val += labels.size(0)
+                 correct_val += (predicted == labels).sum().item()
+
+         val_acc = 100 * correct_val / total_val
+         avg_val_loss = val_loss / len(val_loader)
+
+         print(f'\nEpoch [{epoch+1}/{EPOCHS}] '
+               f'Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.2f}% '
+               f'Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.2f}%')
+
+         scheduler.step(avg_val_loss)
+
+         early_stopping(avg_val_loss, model)
+         if early_stopping.early_stop:
+             print("Early stopping triggered")
+             break
+
+     print("\nLoading best Fusion model for evaluation...")
+     if os.path.exists(MODEL_SAVE_PATH):
+         # Full-module checkpoint: map to the current device; weights_only=False is required on recent PyTorch.
+         model = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=False)
+     else:
+         print("Warning: Model file not found.")
+
+     model.eval()
+     all_preds = []
+     all_labels = []
+
+     print("Evaluating on Test set...")
+     with torch.no_grad():
+         for inputs, labels in test_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+             outputs = model(inputs)
+             _, predicted = torch.max(outputs.data, 1)
+             all_preds.extend(predicted.cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+
+     print("\n=== Feature Fusion Model Evaluation Report ===")
+     print(classification_report(all_labels, all_preds, target_names=['No Violence', 'Violence']))
+     print("Confusion Matrix:")
+     print(confusion_matrix(all_labels, all_preds))
+     acc = accuracy_score(all_labels, all_preds)
+     print(f"\nFinal Test Accuracy: {acc*100:.2f}%")
+
+     elapsed = time.time() - start_time
+     print(f"\nTotal execution time: {elapsed/60:.2f} minutes")