caliangandrew committed
Commit 2a84f8d · verified · 1 Parent(s): 087d6cf

Upload 3 files

Files changed (3)
  1. README.md +14 -0
  2. logger.py +36 -0
  3. train_detector.py +460 -0
README.md ADDED
@@ -0,0 +1,14 @@
+ ## UCF
+
+ This model has been adapted from [DeepfakeBench](https://github.com/SCLBD/DeepfakeBench).
+
+ ## Usage
+
+ - **Train UCF model**:
+   - Use `train_ucf.py`, which downloads the pretrained `xception` backbone weights from Hugging Face (if not already present locally) and starts a training job, writing logs to `.logs/`.
+   - Customize the training job by editing `config/ucf.yaml`.
+   - `pm2 start train_ucf.py --no-autorestart` trains a generalist detector on datasets from `DATASET_META`.
+   - `pm2 start train_ucf.py --no-autorestart -- --faces_only` trains a face expert detector on preprocessed, face-only datasets.
+
+ - **Miner Neurons**:
+   - The `UCF` class in `pretrained_ucf.py` is used by miner neurons to load pretrained UCF model weights and run inference (see the sketch below).
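The `UCF` interface itself is not part of this commit, so the following is only a hedged sketch of how a miner neuron might call it; the constructor arguments, file paths, and the `predict` method name are assumptions, not the actual API.

```python
# Hypothetical usage sketch -- the real UCF API in pretrained_ucf.py may differ.
from PIL import Image
from pretrained_ucf import UCF  # class named in the README; signature assumed

# Assumed constructor arguments: a detector config and a weights checkpoint.
detector = UCF(config_path="config/ucf.yaml", weights_path="weights/ucf_best.pth")

image = Image.open("sample.jpg")
prob_fake = detector.predict(image)  # assumed method returning a deepfake probability
print(f"P(fake) = {prob_fake:.3f}")
```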
logger.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ import logging
+
+ import torch.distributed as dist
+
+ class RankFilter(logging.Filter):
+     def __init__(self, rank):
+         super().__init__()
+         self.rank = rank
+
+     def filter(self, record):
+         return dist.get_rank() == self.rank
+
+ def create_logger(log_path):
+     # Create the log directory if it does not already exist
+     if not os.path.isdir(os.path.dirname(log_path)):
+         os.makedirs(os.path.dirname(log_path), exist_ok=True)
+
+     # Create logger object
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+     # Create file handler and set the formatter
+     fh = logging.FileHandler(log_path)
+     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+     fh.setFormatter(formatter)
+
+     # Add the file handler to the logger
+     logger.addHandler(fh)
+
+     # Add a stream handler to print to console
+     sh = logging.StreamHandler()
+     sh.setLevel(logging.INFO)  # Set logging level for stream handler
+     sh.setFormatter(formatter)
+     logger.addHandler(sh)
+
+     return logger
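As a quick orientation for the module above, here is a minimal usage sketch; the log path is a placeholder, and attaching `RankFilter` assumes `torch.distributed` has already been initialized.

```python
# Minimal sketch; the log path below is a placeholder, not part of the commit.
from logger import create_logger, RankFilter

logger = create_logger("logs/example_run/training.log")
logger.info("This record goes to both the log file and the console.")

# In a DDP run (after dist.init_process_group has been called), keeping only
# rank 0's records:
# logger.addFilter(RankFilter(0))
```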
train_detector.py ADDED
@@ -0,0 +1,460 @@
+ # This script was adapted from the DeepfakeBench training code,
+ # originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
+
+ # Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
+
+ # BitMind's modifications include adding a testing phase, changing the
+ # data load/split pipeline to work with subnet 34's image augmentations
+ # and datasets from BitMind HuggingFace repositories, quality of life CLI args,
+ # logging changes, etc.
+
+ import os
+ import sys
+ import argparse
+ from os.path import join
+ import random
+ import datetime
+ import time
+ import yaml
+ from tqdm import tqdm
+ import numpy as np
+ from datetime import timedelta
+ from copy import deepcopy
+ from PIL import Image as pil_image
+ from pathlib import Path
+ import gc
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.parallel
+ import torch.backends.cudnn as cudnn
+ import torch.utils.data
+ import torch.optim as optim
+ from torch.utils.data.distributed import DistributedSampler
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from optimizor.SAM import SAM
+ from optimizor.LinearLR import LinearDecayLR
+
+ from trainer.trainer import Trainer
+ from arena.detectors.UCF.detectors import DETECTOR
+ from metrics.utils import parse_metric_for_print
+ from logger import create_logger, RankFilter
+
+ from huggingface_hub import hf_hub_download
+
+ # BitMind imports (not from original Deepfake Bench repo)
+ from bitmind.dataset_processing.load_split_data import load_datasets, create_real_fake_datasets
+ from bitmind.image_transforms import base_transforms, random_aug_transforms
+ from bitmind.constants import DATASET_META, FACE_TRAINING_DATASET_META
+ from config.constants import (
+     CONFIG_PATH,
+     WEIGHTS_DIR,
+     HF_REPO,
+     BACKBONE_CKPT
+ )
+
+ parser = argparse.ArgumentParser(description='Process some paths.')
+ parser.add_argument('--detector_path', type=str, default=CONFIG_PATH, help='path to detector YAML file')
+ parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False)
+ parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True)
+ parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True)
+ parser.add_argument("--ddp", action='store_true', default=False)
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument('--workers', type=int, default=os.cpu_count() - 1,
+                     help='number of workers for data loading')
+ parser.add_argument('--epochs', type=int, default=None, help='number of training epochs')
+
+ args = parser.parse_args()
+ torch.cuda.set_device(args.local_rank)
+ print(f"torch.cuda.device(0): {torch.cuda.device(0)}")
+ print(f"torch.cuda.get_device_name(0): {torch.cuda.get_device_name(0)}")
+
+ def ensure_backbone_is_available(logger,
+                                  weights_dir=WEIGHTS_DIR,
+                                  model_filename=BACKBONE_CKPT,
+                                  hugging_face_repo_name=HF_REPO):
+
+     destination_path = Path(weights_dir) / Path(model_filename)
+     if not destination_path.parent.exists():
+         destination_path.parent.mkdir(parents=True, exist_ok=True)
+         logger.info(f"Created directory {destination_path.parent}.")
+     if not destination_path.exists():
+         model_path = hf_hub_download(hugging_face_repo_name, model_filename)
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         model = torch.load(model_path, map_location=device)
+         torch.save(model, destination_path)
+         del model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         logger.info(f"Downloaded backbone {model_filename} to {destination_path}.")
+     else:
+         logger.info(f"{model_filename} backbone already present at {destination_path}.")
+
+ def init_seed(config):
+     if config['manualSeed'] is None:
+         config['manualSeed'] = random.randint(1, 10000)
+     random.seed(config['manualSeed'])
+     if config['cuda']:
+         torch.manual_seed(config['manualSeed'])
+         torch.cuda.manual_seed_all(config['manualSeed'])
+
+ def custom_collate_fn(batch):
+     images, labels, source_labels = zip(*batch)
+
+     images = torch.stack(images, dim=0)  # Stack image tensors into a single tensor
+     labels = torch.LongTensor(labels)
+     source_labels = torch.LongTensor(source_labels)
+
+     data_dict = {
+         'image': images,
+         'label': labels,
+         'label_spe': source_labels,
+         'landmark': None,
+         'mask': None
+     }
+     return data_dict
+
+ def prepare_datasets(config, logger):
+     start_time = log_start_time(logger, "Loading and splitting individual datasets")
+
+     real_datasets, fake_datasets = load_datasets(dataset_meta=config['dataset_meta'],
+                                                  expert=config['faces_only'],
+                                                  split_transforms=config['split_transforms'])
+
+     log_finish_time(logger, "Loading and splitting individual datasets", start_time)
+
+     start_time = log_start_time(logger, "Creating real fake dataset splits")
+     train_dataset, val_dataset, test_dataset = \
+         create_real_fake_datasets(real_datasets,
+                                   fake_datasets,
+                                   config['split_transforms']['train']['transform'],
+                                   config['split_transforms']['validation']['transform'],
+                                   config['split_transforms']['test']['transform'],
+                                   source_labels=True)
+
+     log_finish_time(logger, "Creating real fake dataset splits", start_time)
+
+     train_loader = torch.utils.data.DataLoader(train_dataset,
+                                                batch_size=config['train_batchSize'],
+                                                shuffle=True,
+                                                num_workers=config['workers'],
+                                                drop_last=True,
+                                                collate_fn=custom_collate_fn)
+     val_loader = torch.utils.data.DataLoader(val_dataset,
+                                              batch_size=config['train_batchSize'],
+                                              shuffle=True,
+                                              num_workers=config['workers'],
+                                              drop_last=True,
+                                              collate_fn=custom_collate_fn)
+     test_loader = torch.utils.data.DataLoader(test_dataset,
+                                               batch_size=config['train_batchSize'],
+                                               shuffle=True,
+                                               num_workers=config['workers'],
+                                               drop_last=True,
+                                               collate_fn=custom_collate_fn)
+
+     print(f"Train size: {len(train_loader.dataset)}")
+     print(f"Validation size: {len(val_loader.dataset)}")
+     print(f"Test size: {len(test_loader.dataset)}")
+
+     return train_loader, val_loader, test_loader
+
+ def choose_optimizer(model, config):
+     opt_name = config['optimizer']['type']
+     if opt_name == 'sgd':
+         optimizer = optim.SGD(
+             params=model.parameters(),
+             lr=config['optimizer'][opt_name]['lr'],
+             momentum=config['optimizer'][opt_name]['momentum'],
+             weight_decay=config['optimizer'][opt_name]['weight_decay']
+         )
+         return optimizer
+     elif opt_name == 'adam':
+         optimizer = optim.Adam(
+             params=model.parameters(),
+             lr=config['optimizer'][opt_name]['lr'],
+             weight_decay=config['optimizer'][opt_name]['weight_decay'],
+             betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']),
+             eps=config['optimizer'][opt_name]['eps'],
+             amsgrad=config['optimizer'][opt_name]['amsgrad'],
+         )
+         return optimizer
+     elif opt_name == 'sam':
+         optimizer = SAM(
+             model.parameters(),
+             optim.SGD,
+             lr=config['optimizer'][opt_name]['lr'],
+             momentum=config['optimizer'][opt_name]['momentum'],
+         )
+     else:
+         raise NotImplementedError('Optimizer {} is not implemented'.format(opt_name))
+     return optimizer
+
+
+ def choose_scheduler(config, optimizer):
+     if config['lr_scheduler'] is None:
+         return None
+     elif config['lr_scheduler'] == 'step':
+         scheduler = optim.lr_scheduler.StepLR(
+             optimizer,
+             step_size=config['lr_step'],
+             gamma=config['lr_gamma'],
+         )
+         return scheduler
+     elif config['lr_scheduler'] == 'cosine':
+         scheduler = optim.lr_scheduler.CosineAnnealingLR(
+             optimizer,
+             T_max=config['lr_T_max'],
+             eta_min=config['lr_eta_min'],
+         )
+         return scheduler
+     elif config['lr_scheduler'] == 'linear':
+         scheduler = LinearDecayLR(
+             optimizer,
+             config['nEpochs'],
+             int(config['nEpochs']/4),
+         )
+         return scheduler
+     else:
+         raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler']))
+
+ def choose_metric(config):
+     metric_scoring = config['metric_scoring']
+     if metric_scoring not in ['eer', 'auc', 'acc', 'ap']:
+         raise NotImplementedError('metric {} is not implemented'.format(metric_scoring))
+     return metric_scoring
+
+ def log_start_time(logger, process_name):
+     """Log the start time of a process."""
+     start_time = time.time()
+     logger.info(f"{process_name} Start Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")
+     return start_time
+
+ def log_finish_time(logger, process_name, start_time):
+     """Log the finish time and elapsed time of a process."""
+     finish_time = time.time()
+     elapsed_time = finish_time - start_time
+
+     # Convert elapsed time into hours, minutes, and seconds
+     hours, rem = divmod(elapsed_time, 3600)
+     minutes, seconds = divmod(rem, 60)
+
+     # Log the finish time and elapsed time
+     logger.info(f"{process_name} Finish Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(finish_time))}")
+     logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
+
+ def save_config(config, outputs_dir):
+     """
+     Saves a config dictionary as a YAML file, ensuring only basic types are saved.
+     Also, lists like 'mean' and 'std' are saved in flow style (on a single line).
+
+     Args:
+         config (dict): The configuration dictionary to save.
+         outputs_dir (str): The directory path where the files will be saved.
+     """
+
+     def is_basic_type(value):
+         """
+         Check if a value is a basic data type that can be saved in YAML.
+         Basic types include int, float, str, bool, list, and dict.
+         """
+         return isinstance(value, (int, float, str, bool, list, dict, type(None)))
+
+     def filter_dict(data_dict):
+         """
+         Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects).
+         """
+         if not isinstance(data_dict, dict):
+             return data_dict
+
+         filtered_dict = {}
+         for key, value in data_dict.items():
+             if isinstance(value, dict):
+                 # Recursively filter nested dictionaries
+                 nested_dict = filter_dict(value)
+                 if nested_dict:  # Only add non-empty dictionaries
+                     filtered_dict[key] = nested_dict
+             elif is_basic_type(value):
+                 # Add if the value is a basic type
+                 filtered_dict[key] = value
+             else:
+                 # Skip the key if the value is not a basic type (e.g., an object)
+                 print(f"Skipping key '{key}' because its value is of type {type(value)}")
+
+         return filtered_dict
+
+     def save_dict_to_yaml(data_dict, file_path):
+         """
+         Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
+         Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.
+
+         Args:
+             data_dict (dict): The dictionary to save.
+             file_path (str): The local file path where the YAML file will be saved.
+         """
+
+         # Custom representer for lists to force flow style (compact lists)
+         class FlowStyleList(list):
+             pass
+
+         def flow_style_list_representer(dumper, data):
+             return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)
+
+         yaml.add_representer(FlowStyleList, flow_style_list_representer)
+
+         # Preprocess specific lists to be in flow style
+         if 'mean' in data_dict:
+             data_dict['mean'] = FlowStyleList(data_dict['mean'])
+         if 'std' in data_dict:
+             data_dict['std'] = FlowStyleList(data_dict['std'])
+
+         try:
+             # Filter the dictionary
+             filtered_dict = filter_dict(data_dict)
+
+             # Save the filtered dictionary as YAML
+             with open(file_path, 'w') as f:
+                 yaml.dump(filtered_dict, f, default_flow_style=False)  # Block style by default except for FlowStyleList
+             print(f"Filtered dictionary successfully saved to {file_path}")
+         except Exception as e:
+             print(f"Error saving dictionary to YAML: {e}")
+
+     # Save as YAML
+     save_dict_to_yaml(config, outputs_dir + '/config.yaml')
+
+ def main():
+     torch.cuda.empty_cache()
+     gc.collect()
+     # parse options and load config
+     with open(args.detector_path, 'r') as f:
+         config = yaml.safe_load(f)
+     with open(os.getcwd() + '/config/train_config.yaml', 'r') as f:
+         config2 = yaml.safe_load(f)
+     if 'label_dict' in config:
+         config2['label_dict'] = config['label_dict']
+     config.update(config2)
+
+     config['workers'] = args.workers
+
+     config['local_rank'] = args.local_rank
+     if config['dry_run']:
+         config['nEpochs'] = 0
+         config['save_feat'] = False
+
+     if args.epochs: config['nEpochs'] = args.epochs
+
+     config['split_transforms'] = {'train': {'name': 'base_transforms',
+                                             'transform': base_transforms},
+                                   'validation': {'name': 'base_transforms',
+                                                  'transform': base_transforms},
+                                   'test': {'name': 'base_transforms',
+                                            'transform': base_transforms}}
+     config['faces_only'] = args.faces_only
+     config['dataset_meta'] = FACE_TRAINING_DATASET_META if config['faces_only'] else DATASET_META
+     dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets]
+     config['train_dataset'] = dataset_names
+     config['save_ckpt'] = args.save_ckpt
+     config['save_feat'] = args.save_feat
+
+     config['specific_task_number'] = len(config['dataset_meta']["fake"]) + 1
+
+     if config['lmdb']:
+         config['dataset_json_folder'] = 'preprocessing/dataset_json_v3'
+
+     # create logger
+     timenow = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+
+     outputs_dir = os.path.join(
+         config['log_dir'],
+         config['model_name'] + '_' + timenow
+     )
+
+     os.makedirs(outputs_dir, exist_ok=True)
+     logger = create_logger(os.path.join(outputs_dir, 'training.log'))
+     config['log_dir'] = outputs_dir
+     logger.info('Save log to {}'.format(outputs_dir))
+
+     config['ddp'] = args.ddp
+
+     # init seed
+     init_seed(config)
+
+     # set cudnn benchmark if needed
+     if config['cudnn']:
+         cudnn.benchmark = True
+     if config['ddp']:
+         # dist.init_process_group(backend='gloo')
+         dist.init_process_group(
+             backend='nccl',
+             timeout=timedelta(minutes=30)
+         )
+         logger.addFilter(RankFilter(0))
+
+     ensure_backbone_is_available(logger=logger,
+                                  model_filename=config['pretrained'].split('/')[-1],
+                                  hugging_face_repo_name='bitmind/' + config['model_name'])
+
+     # prepare the model (detector)
+     model_class = DETECTOR[config['model_name']]
+     model = model_class(config)
+
+     # prepare the optimizer
+     optimizer = choose_optimizer(model, config)
+
+     # prepare the scheduler
+     scheduler = choose_scheduler(config, optimizer)
+
+     # prepare the metric
+     metric_scoring = choose_metric(config)
+
+     # prepare the trainer
+     trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring)
+
+     # prepare the data loaders
+     train_loader, val_loader, test_loader = prepare_datasets(config, logger)
+
+     # print configuration
+     logger.info("--------------- Configuration ---------------")
+     params_string = "Parameters: \n"
+     for key, value in config.items():
+         params_string += "{}: {}".format(key, value) + "\n"
+     logger.info(params_string)
+
+     # save training configs
+     save_config(config, outputs_dir)
+
+     # start training
+     start_time = log_start_time(logger, "Training")
+     for epoch in range(config['start_epoch'], config['nEpochs'] + 1):
+         trainer.model.epoch = epoch
+         best_metric = trainer.train_epoch(
+             epoch,
+             train_data_loader=train_loader,
+             validation_data_loaders={'val': val_loader}
+         )
+         if best_metric is not None:
+             logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!")
+     logger.info("Stop Training on best Validation metric {}".format(parse_metric_for_print(best_metric)))
+     log_finish_time(logger, "Training", start_time)
+
+     # test
+     start_time = log_start_time(logger, "Test")
+     trainer.eval(eval_data_loaders={'test': test_loader}, eval_stage="test")
+     log_finish_time(logger, "Test", start_time)
+
+     # update
+     if 'svdd' in config['model_name']:
+         model.update_R(epoch)
+     if scheduler is not None:
+         scheduler.step()
+
+     # close the tensorboard writers
+     for writer in trainer.writers.values():
+         writer.close()
+
+     torch.cuda.empty_cache()
+     gc.collect()
+
+ if __name__ == '__main__':
+     main()
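For readers reconstructing the expected configuration, the sketch below is a Python view of the keys that `choose_optimizer`, `choose_scheduler`, and `choose_metric` read. The values are placeholders only; the real settings live in `config/ucf.yaml` and `config/train_config.yaml`, which are not part of this commit.

```python
# Placeholder values only -- illustrative of the keys train_detector.py reads,
# not the actual project configuration.
config = {
    'optimizer': {
        'type': 'adam',
        'adam': {'lr': 2e-4, 'weight_decay': 5e-4, 'beta1': 0.9,
                 'beta2': 0.999, 'eps': 1e-8, 'amsgrad': False},
    },
    'lr_scheduler': 'step',   # None, 'step', 'cosine', or 'linear'
    'lr_step': 5,
    'lr_gamma': 0.5,
    'metric_scoring': 'auc',  # one of 'eer', 'auc', 'acc', 'ap'
    'nEpochs': 10,
    'train_batchSize': 32,
    'manualSeed': 42,
    'cuda': True,
}

# With a detector model instantiated, the helpers above would be wired up as:
# optimizer = choose_optimizer(model, config)
# scheduler = choose_scheduler(config, optimizer)
# metric_scoring = choose_metric(config)
```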