bawolf committed
Commit 7aa93af · Parent(s): 744ba87

add ignored dataset files

script/inference.py CHANGED
@@ -6,7 +6,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from src.utils.utils import get_latest_run_dir, get_latest_model_path, get_config
 from src.models.model import load_model
-from src.data.video_utils import create_transform, extract_frames
+from src.dataset.video_utils import create_transform, extract_frames
 
 def setup_model(run_dir=None):
     """Setup model and configuration"""
script/train.py CHANGED
@@ -12,9 +12,9 @@ import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from src.utils.utils import create_run_directory
-from src.data.dataset import VideoDataset
+from src.dataset.dataset import VideoDataset
 from src.models.model import create_model
-from src.data.video_utils import create_transform
+from src.dataset.video_utils import create_transform
 
 def train_and_evaluate(config):
     # Create a run directory if it doesn't exist
@@ -228,11 +228,11 @@ def main():
     config = {
         "class_labels": class_labels,
         "num_classes": len(class_labels),
-        "data_path": '../finetune/3moves_otherpeopletrain',
+        "data_path": '../finetune/3moves_otherpeopleval',
         "batch_size": 32,
         "learning_rate": 2e-6,
         "weight_decay": 0.007,
-        "num_epochs": 1,
+        "num_epochs": 50,
         "patience": 10,  # for early stopping
         "max_frames": 10,
         "sigma": 0.3,
script/visualization/visualize.py CHANGED
@@ -9,7 +9,7 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 
-from src.data.dataset import VideoDataset
+from src.dataset.dataset import VideoDataset
 from src.utils.utils import get_latest_model_path, get_latest_run_dir, get_config
 from src.models.model import load_model
 
src/dataset/dataset.py ADDED
@@ -0,0 +1,59 @@
+import torch
+from torch.utils.data import Dataset
+import csv
+from .video_utils import create_transform, extract_frames
+
+class VideoDataset(Dataset):
+    def __init__(self, file_path, config, transform=None):
+        self.data = []
+        self.label_map = {}
+        # Use create_transform if no custom transform is provided
+        self.transform = transform or create_transform(config)
+
+        # Validate required config keys
+        required_keys = {"max_frames", "sigma", "class_labels"}
+        missing_keys = required_keys - set(config.keys())
+        if missing_keys:
+            raise ValueError(f"Missing required config keys: {missing_keys}")
+
+        self.max_frames = config['max_frames']
+        self.sigma = config['sigma']
+
+        # Create label map from class_labels list
+        self.label_map = {i: label for i, label in enumerate(config['class_labels'])}
+
+        # Read the CSV file and parse the data
+        with open(file_path, 'r') as file:
+            csv_reader = csv.reader(file)
+            for row in csv_reader:
+                if len(row) != 2:
+                    print(f"Skipping invalid row: {row}")
+                    continue
+                video_path, label = row
+                try:
+                    label = int(label)
+                except ValueError:
+                    print(f"Skipping row with invalid label: {row}")
+                    continue
+                self.data.append((video_path, label))
+
+        if not self.data:
+            raise ValueError(f"No valid data found in the CSV file: {file_path}")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        video_path, label = self.data[idx]
+
+        frames, success = extract_frames(video_path,
+                                         {"max_frames": self.max_frames, "sigma": self.sigma},
+                                         self.transform)
+
+        if not success:
+            frames = self._get_error_tensor()
+
+        return frames, label, video_path
+
+    def _get_error_tensor(self):
+        return torch.zeros((self.max_frames, 3, 224, 224))
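
Each item comes back as a (frames, label, video_path) triple, with failed reads replaced by a zero tensor rather than raising. A minimal usage sketch, assuming a two-column video_path,label CSV as the reader above expects ("clips.csv" and the label names are placeholders):

# Sketch only: the CSV path, labels, and transform values are assumptions.
from src.dataset.dataset import VideoDataset

config = {
    "max_frames": 10,
    "sigma": 0.3,
    "class_labels": ["move_a", "move_b", "move_c"],  # placeholder names
    "image_size": 224,                               # needed by create_transform
    "normalization_mean": [0.485, 0.456, 0.406],
    "normalization_std": [0.229, 0.224, 0.225],
}

ds = VideoDataset("clips.csv", config)
frames, label, video_path = ds[0]
print(frames.shape)                    # torch.Size([10, 3, 224, 224])
print(ds.label_map[label], video_path)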
src/dataset/video_utils.py ADDED
@@ -0,0 +1,132 @@
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+from scipy.stats import norm
+import os
+
+def create_transform(config, training=False):
+    """Create transform pipeline based on config"""
+    # Validate base required keys
+    required_keys = {
+        "image_size",
+        "normalization_mean",
+        "normalization_std"
+    }
+
+    # Add training-specific required keys
+    if training:
+        required_keys.update({
+            "flip_probability",
+            "rotation_degrees",
+            "brightness_jitter",
+            "contrast_jitter",
+            "saturation_jitter",
+            "hue_jitter",
+            "crop_scale_min",
+            "crop_scale_max"
+        })
+
+    missing_keys = required_keys - set(config.keys())
+    if missing_keys:
+        raise ValueError(f"Missing required config keys: {missing_keys}")
+
+    # Build transform list
+    transform_list = [
+        transforms.ToPILImage(),
+        transforms.Resize((config["image_size"], config["image_size"]))
+    ]
+
+    # Add training augmentations if needed
+    if training:
+        transform_list.extend([
+            transforms.RandomHorizontalFlip(p=config["flip_probability"]),
+            transforms.RandomRotation(config["rotation_degrees"]),
+            transforms.ColorJitter(
+                brightness=config["brightness_jitter"],
+                contrast=config["contrast_jitter"],
+                saturation=config["saturation_jitter"],
+                hue=config["hue_jitter"]
+            ),
+            transforms.RandomResizedCrop(
+                config["image_size"],
+                scale=(config["crop_scale_min"], config["crop_scale_max"])
+            )
+        ])
+
+    # Add final transforms
+    transform_list.extend([
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=config["normalization_mean"],
+            std=config["normalization_std"]
+        )
+    ])
+
+    return transforms.Compose(transform_list)
+
+def extract_frames(video_path: str, config: dict, transform) -> tuple[torch.Tensor, bool]:
+    """Extract and process frames from video using Gaussian sampling
+    Returns:
+        tuple: (frames tensor, success boolean)
+    """
+    # Validate required config keys
+    required_keys = {"max_frames", "sigma"}
+    missing_keys = required_keys - set(config.keys())
+    if missing_keys:
+        raise ValueError(f"Missing required config keys for frame extraction: {missing_keys}")
+
+    frames = []
+    success = True
+
+    if not os.path.exists(video_path):
+        print(f"File not found: {video_path}")
+        return None, False
+
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        print(f"Failed to open video: {video_path}")
+        return None, False
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    if total_frames == 0:
+        print(f"Video has no frames: {video_path}")
+        cap.release()
+        return None, False
+
+    # Create a normal distribution centered at the middle of the video
+    x = np.linspace(0, 1, total_frames)
+    probabilities = norm.pdf(x, loc=0.5, scale=config["sigma"])
+    probabilities /= probabilities.sum()
+
+    # Sample frame indices based on this distribution
+    frame_indices = np.sort(np.random.choice(
+        total_frames,
+        size=min(config["max_frames"], total_frames),
+        replace=False,
+        p=probabilities
+    ))
+
+    for frame_idx in frame_indices:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+        ret, frame = cap.read()
+        if not ret:
+            print(f"Failed to read frame {frame_idx} from video: {video_path}")
+            success = False
+            break
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        if transform:
+            frame = transform(frame)
+        frames.append(frame)
+
+    cap.release()
+
+    if not frames:
+        print(f"No frames extracted from video: {video_path}")
+        return None, False
+
+    # Pad with zeros if we don't have enough frames
+    while len(frames) < config["max_frames"]:
+        frames.append(torch.zeros_like(frames[0]))
+
+    return torch.stack(frames), success
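
The sampling step in extract_frames biases frame selection toward the middle of the clip. A standalone sketch of just that step (the frame count here is illustrative; sigma=0.3 matches the train.py config):

# Demonstrates the Gaussian index sampling from extract_frames in isolation.
import numpy as np
from scipy.stats import norm

total_frames, max_frames, sigma = 100, 10, 0.3

x = np.linspace(0, 1, total_frames)      # normalized clip timeline
probabilities = norm.pdf(x, loc=0.5, scale=sigma)
probabilities /= probabilities.sum()     # renormalize over valid frames

frame_indices = np.sort(np.random.choice(
    total_frames, size=min(max_frames, total_frames),
    replace=False, p=probabilities))
print(frame_indices)                     # indices cluster around frame 50

With sigma=0.3 on the [0, 1] timeline, roughly two-thirds of the renormalized probability mass falls in the middle half of the clip, so the ends are down-weighted but never excluded.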