caliangandrew committed
Commit 087d6cf · verified · 1 Parent(s): d715fef

Delete test_deps

Files changed (42)
  1. test_deps/README.md +0 -14
  2. test_deps/config/__init__.py +0 -7
  3. test_deps/config/__pycache__/__init__.cpython-310.pyc +0 -0
  4. test_deps/config/__pycache__/constants.cpython-310.pyc +0 -0
  5. test_deps/config/constants.py +0 -15
  6. test_deps/config/pretrained_config.yaml +0 -94
  7. test_deps/config/pretrained_face_config.yaml +0 -94
  8. test_deps/config/train_config.yaml +0 -9
  9. test_deps/config/ucf.yaml +0 -73
  10. test_deps/config/xception.yaml +0 -86
  11. test_deps/detectors/__init__.py +0 -11
  12. test_deps/detectors/__pycache__/__init__.cpython-310.pyc +0 -0
  13. test_deps/detectors/__pycache__/base_detector.cpython-310.pyc +0 -0
  14. test_deps/detectors/__pycache__/ucf_detector.cpython-310.pyc +0 -0
  15. test_deps/detectors/base_detector.py +0 -71
  16. test_deps/detectors/ucf_detector.py +0 -472
  17. test_deps/logger.py +0 -36
  18. test_deps/loss/__init__.py +0 -13
  19. test_deps/loss/__pycache__/__init__.cpython-310.pyc +0 -0
  20. test_deps/loss/__pycache__/abstract_loss_func.cpython-310.pyc +0 -0
  21. test_deps/loss/__pycache__/contrastive_regularization.cpython-310.pyc +0 -0
  22. test_deps/loss/__pycache__/cross_entropy_loss.cpython-310.pyc +0 -0
  23. test_deps/loss/__pycache__/l1_loss.cpython-310.pyc +0 -0
  24. test_deps/loss/abstract_loss_func.py +0 -17
  25. test_deps/loss/contrastive_regularization.py +0 -78
  26. test_deps/loss/cross_entropy_loss.py +0 -26
  27. test_deps/loss/l1_loss.py +0 -19
  28. test_deps/metrics/__init__.py +0 -7
  29. test_deps/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
  30. test_deps/metrics/__pycache__/base_metrics_class.cpython-310.pyc +0 -0
  31. test_deps/metrics/__pycache__/registry.cpython-310.pyc +0 -0
  32. test_deps/metrics/base_metrics_class.py +0 -205
  33. test_deps/metrics/registry.py +0 -20
  34. test_deps/metrics/utils.py +0 -88
  35. test_deps/networks/__init__.py +0 -11
  36. test_deps/networks/__pycache__/__init__.cpython-310.pyc +0 -0
  37. test_deps/networks/__pycache__/xception.cpython-310.pyc +0 -0
  38. test_deps/networks/xception.py +0 -285
  39. test_deps/optimizor/LinearLR.py +0 -20
  40. test_deps/optimizor/SAM.py +0 -77
  41. test_deps/train_detector.py +0 -460
  42. test_deps/trainer/trainer.py +0 -441
test_deps/README.md DELETED
@@ -1,14 +0,0 @@
- ## UCF
-
- This model has been adapted from [DeepfakeBench](https://github.com/SCLBD/DeepfakeBench).
-
- ##
-
- - **Train UCF model**:
-   - Use `train_ucf.py`, which will download necessary pretrained `xception` backbone weights from HuggingFace (if not present locally) and start a training job with logging outputs in `.logs/`.
-   - Customize the training job by editing `config/ucf.yaml`
-   - `pm2 start train_ucf.py --no-autorestart` to train a generalist detector on datasets from `DATASET_META`
-   - `pm2 start train_ucf.py --no-autorestart -- --faces_only` to train a face expert detector on preprocessed-face only datasets
-
- - **Miner Neurons**:
-   - The `UCF` class in `pretrained_ucf.py` is used by miner neurons to load and perform inference with pretrained UCF model weights.
test_deps/config/__init__.py DELETED
@@ -1,7 +0,0 @@
- import os
- import sys
- current_file_path = os.path.abspath(__file__)
- parent_dir = os.path.dirname(os.path.dirname(current_file_path))
- project_root_dir = os.path.dirname(parent_dir)
- sys.path.append(parent_dir)
- sys.path.append(project_root_dir)
test_deps/config/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (350 Bytes)
 
test_deps/config/__pycache__/constants.cpython-310.pyc DELETED
Binary file (543 Bytes)
 
test_deps/config/constants.py DELETED
@@ -1,15 +0,0 @@
- import os
-
- # Path to the directory containing the constants.py file
- CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__))
-
- # The base directory for UCF-related files, i.e., UCF directory
- UCF_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, ".."))  # Points to bitmind-subnet/base_miner/UCF/
- # Absolute paths for the required files and directories
- CONFIG_PATH = os.path.join(CONFIGS_DIR, "ucf.yaml")  # Path to the ucf.yaml file
- WEIGHTS_DIR = os.path.join(UCF_BASE_PATH, "weights/")  # Path to pretrained weights directory
-
- HF_REPO = "bitmind/ucf"
- BACKBONE_CKPT = "xception_best.pth"
-
- DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(UCF_BASE_PATH, "../../utils/dlib_tools/shape_predictor_81_face_landmarks.dat"))
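The README above says `train_ucf.py` downloads the `xception` backbone from HuggingFace when it is missing locally. A minimal sketch of how `HF_REPO` and `BACKBONE_CKPT` could be used for that step, assuming `huggingface_hub` is the download mechanism (the actual training script is not part of this diff):

```python
# Illustrative only: fetch the backbone checkpoint named in constants.py.
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(repo_id="bitmind/ucf", filename="xception_best.pth")
print("backbone weights cached at:", local_path)  # could then be copied into WEIGHTS_DIR
```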
test_deps/config/pretrained_config.yaml DELETED
@@ -1,94 +0,0 @@
- SWA: false
- backbone_config:
-   dropout: false
-   inc: 3
-   mode: adjust_channel
-   num_classes: 2
- backbone_name: xception
- compression: c23
- cuda: true
- cudnn: true
- dataset_json_folder: preprocessing/dataset_json_v3
- dataset_meta:
-   fake:
-   - create_splits: false
-     path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
-   - create_splits: false
-     path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
-   real:
-   - create_splits: false
-     path: bitmind/celeb-a-hq_training_faces
-   - create_splits: false
-     path: bitmind/ffhq-256_training_faces
- ddp: false
- dry_run: false
- encoder_feat_dim: 512
- faces_only: true
- frame_num:
-   test: 32
-   train: 32
- lmdb: true
- lmdb_dir: ./datasets/lmdb
- local_rank: 0
- log_dir: ./logs/training/ucf_2024-09-17-16-44-50
- logdir: ./logs
- loss_func:
-   cls_loss: cross_entropy
-   con_loss: contrastive_regularization
-   rec_loss: l1loss
-   spe_loss: cross_entropy
- losstype: null
- lr_scheduler: null
- manualSeed: 1024
- mean:
- - 0.5
- - 0.5
- - 0.5
- metric_scoring: auc
- mode: train
- model_name: ucf
- nEpochs: 2
- optimizer:
-   adam:
-     amsgrad: false
-     beta1: 0.9
-     beta2: 0.999
-     eps: 1.0e-08
-     lr: 0.0002
-     weight_decay: 0.0005
-   sgd:
-     lr: 0.0002
-     momentum: 0.9
-     weight_decay: 0.0005
-   type: adam
- pretrained: ../weights/xception_best.pth
- rec_iter: 100
- resolution: 256
- rgb_dir: ./datasets/rgb
- save_avg: true
- save_ckpt: true
- save_epoch: 1
- save_feat: true
- specific_task_number: 2
- split_transforms:
-   test:
-     name: base_transforms
-   train:
-     name: random_aug_transforms
-   validation:
-     name: base_transforms
- start_epoch: 0
- std:
- - 0.5
- - 0.5
- - 0.5
- test_batchSize: 32
- train_batchSize: 32
- train_dataset:
- - bitmind/celeb-a-hq_training_faces
- - bitmind/ffhq-256_training_faces
- - bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
- - bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
- with_landmark: false
- with_mask: false
- workers: 7
test_deps/config/pretrained_face_config.yaml DELETED
@@ -1,94 +0,0 @@
- SWA: false
- backbone_config:
-   dropout: false
-   inc: 3
-   mode: adjust_channel
-   num_classes: 2
- backbone_name: xception
- compression: c23
- cuda: true
- cudnn: true
- dataset_json_folder: preprocessing/dataset_json_v3
- dataset_meta:
-   fake:
-   - create_splits: false
-     path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
-   - create_splits: false
-     path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
-   real:
-   - create_splits: false
-     path: bitmind/celeb-a-hq_training_faces
-   - create_splits: false
-     path: bitmind/ffhq-256_training_faces
- ddp: false
- dry_run: false
- encoder_feat_dim: 512
- faces_only: true
- frame_num:
-   test: 32
-   train: 32
- lmdb: true
- lmdb_dir: ./datasets/lmdb
- local_rank: 0
- log_dir: ./logs/training/ucf_2024-09-17-16-44-50
- logdir: ./logs
- loss_func:
-   cls_loss: cross_entropy
-   con_loss: contrastive_regularization
-   rec_loss: l1loss
-   spe_loss: cross_entropy
- losstype: null
- lr_scheduler: null
- manualSeed: 1024
- mean:
- - 0.5
- - 0.5
- - 0.5
- metric_scoring: auc
- mode: train
- model_name: ucf
- nEpochs: 2
- optimizer:
-   adam:
-     amsgrad: false
-     beta1: 0.9
-     beta2: 0.999
-     eps: 1.0e-08
-     lr: 0.0002
-     weight_decay: 0.0005
-   sgd:
-     lr: 0.0002
-     momentum: 0.9
-     weight_decay: 0.0005
-   type: adam
- pretrained: ../weights/xception_best.pth
- rec_iter: 100
- resolution: 256
- rgb_dir: ./datasets/rgb
- save_avg: true
- save_ckpt: true
- save_epoch: 1
- save_feat: true
- specific_task_number: 2
- split_transforms:
-   test:
-     name: base_transforms
-   train:
-     name: random_aug_transforms
-   validation:
-     name: base_transforms
- start_epoch: 0
- std:
- - 0.5
- - 0.5
- - 0.5
- test_batchSize: 32
- train_batchSize: 32
- train_dataset:
- - bitmind/celeb-a-hq_training_faces
- - bitmind/ffhq-256_training_faces
- - bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
- - bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
- with_landmark: false
- with_mask: false
- workers: 7
test_deps/config/train_config.yaml DELETED
@@ -1,9 +0,0 @@
- mode: train
- lmdb: True
- dry_run: false
- rgb_dir: './datasets/rgb'
- lmdb_dir: './datasets/lmdb'
- dataset_json_folder: './preprocessing/dataset_json'
- SWA: False
- save_avg: True
- log_dir: ./logs/training/
test_deps/config/ucf.yaml DELETED
@@ -1,73 +0,0 @@
- # log dir
- log_dir: ../debug_logs/ucf
-
- # model setting
- pretrained: ../weights/xception_best.pth  # path to a pre-trained model, if using one
- model_name: ucf  # model name
- backbone_name: xception  # backbone name
- encoder_feat_dim: 512  # feature dimension of the backbone
-
- #backbone setting
- backbone_config:
-   mode: adjust_channel
-   num_classes: 2
-   inc: 3
-   dropout: false
-
- compression: c23  # compression-level for videos
- train_batchSize: 32  # training batch size
- test_batchSize: 32  # test batch size
- workers: 8  # number of data loading workers
- frame_num: {'train': 32, 'test': 32}  # number of frames to use per video in training and testing
- resolution: 256  # resolution of output image to network
- with_mask: false  # whether to include mask information in the input
- with_landmark: false  # whether to include facial landmark information in the input
- save_ckpt: true  # whether to save checkpoint
- save_feat: true  # whether to save features
- specific_task_number: 2  # default num datasets in FF++ used by DFB, overwritten in training
-
- # mean and std for normalization
- mean: [0.5, 0.5, 0.5]
- std: [0.5, 0.5, 0.5]
-
- # optimizer config
- optimizer:
-   # choose between 'adam' and 'sgd'
-   type: adam
-   adam:
-     lr: 0.0002  # learning rate
-     beta1: 0.9  # beta1 for Adam optimizer
-     beta2: 0.999  # beta2 for Adam optimizer
-     eps: 0.00000001  # epsilon for Adam optimizer
-     weight_decay: 0.0005  # weight decay for regularization
-     amsgrad: false
-   sgd:
-     lr: 0.0002  # learning rate
-     momentum: 0.9  # momentum for SGD optimizer
-     weight_decay: 0.0005  # weight decay for regularization
-
- # training config
- lr_scheduler: null  # learning rate scheduler
- nEpochs: 20  # number of epochs to train for
- start_epoch: 0  # manual epoch number (useful for restarts)
- save_epoch: 1  # interval epochs for saving models
- rec_iter: 100  # interval iterations for recording
- logdir: ./logs  # folder to output images and logs
- manualSeed: 1024  # manual seed for random number generation
- save_ckpt: false  # whether to save checkpoint
-
- # loss function
- loss_func:
-   cls_loss: cross_entropy  # loss function to use
-   spe_loss: cross_entropy
-   con_loss: contrastive_regularization
-   rec_loss: l1loss
- losstype: null
-
- # metric
- metric_scoring: auc  # metric for evaluation (auc, acc, eer, ap)
-
- # cuda
-
- cuda: true  # whether to use CUDA acceleration
- cudnn: true  # whether to use CuDNN for convolution operations
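A minimal sketch of how a training script might consume this file, for example building the optimizer from the `optimizer` block. Variable names are illustrative only, since `train_ucf.py` itself is not shown in this diff:

```python
# Illustrative only: load ucf.yaml and build an Adam optimizer from its settings.
import yaml
import torch

with open("config/ucf.yaml", "r") as f:
    config = yaml.safe_load(f)

opt_cfg = config["optimizer"][config["optimizer"]["type"]]  # e.g. the 'adam' block
model = torch.nn.Linear(512, config["backbone_config"]["num_classes"])  # stand-in module
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=opt_cfg["lr"],
    betas=(opt_cfg["beta1"], opt_cfg["beta2"]),
    eps=opt_cfg["eps"],
    weight_decay=opt_cfg["weight_decay"],
    amsgrad=opt_cfg["amsgrad"],
)
print(config["model_name"], config["nEpochs"], optimizer)
```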
test_deps/config/xception.yaml DELETED
@@ -1,86 +0,0 @@
- # log dir
- log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs/testing_bench
-
- # model setting
- pretrained: /data/home/zhiyuanyan/DeepfakeBench/training/pretrained/xception-b5690688.pth  # path to a pre-trained model, if using one
- model_name: xception  # model name
- backbone_name: xception  # backbone name
-
- #backbone setting
- backbone_config:
-   mode: original
-   num_classes: 2
-   inc: 3
-   dropout: false
-
- # dataset
- all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
- train_dataset: [FaceForensics++]
- test_dataset: [FaceForensics++, DeepFakeDetection]
-
- compression: c23  # compression-level for videos
- train_batchSize: 32  # training batch size
- test_batchSize: 32  # test batch size
- workers: 8  # number of data loading workers
- frame_num: {'train': 32, 'test': 32}  # number of frames to use per video in training and testing
- resolution: 256  # resolution of output image to network
- with_mask: false  # whether to include mask information in the input
- with_landmark: false  # whether to include facial landmark information in the input
-
-
- # data augmentation
- use_data_augmentation: true  # Add this flag to enable/disable data augmentation
- data_aug:
-   flip_prob: 0.5
-   rotate_prob: 0.0
-   rotate_limit: [-10, 10]
-   blur_prob: 0.5
-   blur_limit: [3, 7]
-   brightness_prob: 0.5
-   brightness_limit: [-0.1, 0.1]
-   contrast_limit: [-0.1, 0.1]
-   quality_lower: 40
-   quality_upper: 100
-
- # mean and std for normalization
- mean: [0.5, 0.5, 0.5]
- std: [0.5, 0.5, 0.5]
-
- # optimizer config
- optimizer:
-   # choose between 'adam' and 'sgd'
-   type: adam
-   adam:
-     lr: 0.0002  # learning rate
-     beta1: 0.9  # beta1 for Adam optimizer
-     beta2: 0.999  # beta2 for Adam optimizer
-     eps: 0.00000001  # epsilon for Adam optimizer
-     weight_decay: 0.0005  # weight decay for regularization
-     amsgrad: false
-   sgd:
-     lr: 0.0002  # learning rate
-     momentum: 0.9  # momentum for SGD optimizer
-     weight_decay: 0.0005  # weight decay for regularization
-
- # training config
- lr_scheduler: null  # learning rate scheduler
- nEpochs: 10  # number of epochs to train for
- start_epoch: 0  # manual epoch number (useful for restarts)
- save_epoch: 1  # interval epochs for saving models
- rec_iter: 100  # interval iterations for recording
- logdir: ./logs  # folder to output images and logs
- manualSeed: 1024  # manual seed for random number generation
- save_ckpt: true  # whether to save checkpoint
- save_feat: true  # whether to save features
-
- # loss function
- loss_func: cross_entropy  # loss function to use
- losstype: null
-
- # metric
- metric_scoring: auc  # metric for evaluation (auc, acc, eer, ap)
-
- # cuda
-
- cuda: true  # whether to use CUDA acceleration
- cudnn: true  # whether to use CuDNN for convolution operations
test_deps/detectors/__init__.py DELETED
@@ -1,11 +0,0 @@
- import os
- import sys
- current_file_path = os.path.abspath(__file__)
- parent_dir = os.path.dirname(os.path.dirname(current_file_path))
- project_root_dir = os.path.dirname(parent_dir)
- sys.path.append(parent_dir)
- sys.path.append(project_root_dir)
-
- from metrics.registry import DETECTOR
-
- from .ucf_detector import UCFDetector
test_deps/detectors/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (455 Bytes)
 
test_deps/detectors/__pycache__/base_detector.cpython-310.pyc DELETED
Binary file (2.57 kB)
 
test_deps/detectors/__pycache__/ucf_detector.cpython-310.pyc DELETED
Binary file (12.9 kB)
 
test_deps/detectors/base_detector.py DELETED
@@ -1,71 +0,0 @@
- # author: Zhiyuan Yan
- # email: zhiyuanyan@link.cuhk.edu.cn
- # date: 2023-0706
- # description: Abstract Class for the Deepfake Detector
-
- import abc
- import torch
- import torch.nn as nn
- from typing import Union
-
- class AbstractDetector(nn.Module, metaclass=abc.ABCMeta):
-     """
-     All deepfake detectors should subclass this class.
-     """
-     def __init__(self, config=None, load_param: Union[bool, str] = False):
-         """
-         config: (dict)
-             configurations for the model
-         load_param: (False | True | Path(str))
-             False Do not read; True Read the default path; Path Read the required path
-         """
-         super().__init__()
-
-     @abc.abstractmethod
-     def features(self, data_dict: dict) -> torch.tensor:
-         """
-         Returns the features from the backbone given the input data.
-         """
-         pass
-
-     @abc.abstractmethod
-     def forward(self, data_dict: dict, inference=False) -> dict:
-         """
-         Forward pass through the model, returning the prediction dictionary.
-         """
-         pass
-
-     @abc.abstractmethod
-     def classifier(self, features: torch.tensor) -> torch.tensor:
-         """
-         Classifies the features into classes.
-         """
-         pass
-
-     @abc.abstractmethod
-     def build_backbone(self, config):
-         """
-         Builds the backbone of the model.
-         """
-         pass
-
-     @abc.abstractmethod
-     def build_loss(self, config):
-         """
-         Builds the loss function for the model.
-         """
-         pass
-
-     @abc.abstractmethod
-     def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
-         """
-         Returns the losses for the model.
-         """
-         pass
-
-     @abc.abstractmethod
-     def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
-         """
-         Returns the training metrics for the model.
-         """
-         pass
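For orientation, a hypothetical minimal subclass of `AbstractDetector`; it is not part of the deleted code and exists only to show which methods a concrete detector must implement:

```python
# Illustrative skeleton only: the smallest detector satisfying AbstractDetector.
import torch
import torch.nn as nn

class ToyDetector(AbstractDetector):
    def __init__(self, config=None):
        super().__init__(config)
        self.backbone = self.build_backbone(config)
        self.fc = nn.Linear(16, 2)
        self.loss_fn = self.build_loss(config)

    def build_backbone(self, config):
        return nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                             nn.AdaptiveAvgPool2d(1), nn.Flatten())

    def build_loss(self, config):
        return nn.CrossEntropyLoss()

    def features(self, data_dict):
        return self.backbone(data_dict["image"])

    def classifier(self, features):
        return self.fc(features)

    def forward(self, data_dict, inference=False):
        logits = self.classifier(self.features(data_dict))
        return {"cls": logits, "prob": torch.softmax(logits, dim=1)[:, 1]}

    def get_losses(self, data_dict, pred_dict):
        return {"overall": self.loss_fn(pred_dict["cls"], data_dict["label"])}

    def get_train_metrics(self, data_dict, pred_dict):
        acc = (pred_dict["cls"].argmax(1) == data_dict["label"]).float().mean().item()
        return {"acc": acc}
```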
test_deps/detectors/ucf_detector.py DELETED
@@ -1,472 +0,0 @@
1
- '''
2
- # Source: https://github.com/SCLBD/DeepfakeBench/blob/main/training/detectors/ucf_detector.py
3
- # author: Zhiyuan Yan
4
- # email: zhiyuanyan@link.cuhk.edu.cn
5
- # date: 2023-0706
6
- # description: Class for the UCFDetector
7
-
8
- Functions in the Class are summarized as:
9
- 1. __init__: Initialization
10
- 2. build_backbone: Backbone-building
11
- 3. build_loss: Loss-function-building
12
- 4. features: Feature-extraction
13
- 5. classifier: Classification
14
- 6. get_losses: Loss-computation
15
- 7. get_train_metrics: Training-metrics-computation
16
- 8. get_test_metrics: Testing-metrics-computation
17
- 9. forward: Forward-propagation
18
-
19
- Reference:
20
- @article{yan2023ucf,
21
- title={UCF: Uncovering Common Features for Generalizable Deepfake Detection},
22
- author={Yan, Zhiyuan and Zhang, Yong and Fan, Yanbo and Wu, Baoyuan},
23
- journal={arXiv preprint arXiv:2304.13949},
24
- year={2023}
25
- }
26
- '''
27
-
28
- import os
29
- import datetime
30
- import logging
31
- import random
32
- import numpy as np
33
- from sklearn import metrics
34
- from typing import Union
35
- from collections import defaultdict
36
-
37
- import torch
38
- import torch.nn as nn
39
- import torch.nn.functional as F
40
- import torch.optim as optim
41
- from torch.nn import DataParallel
42
- from torch.utils.tensorboard import SummaryWriter
43
-
44
- from metrics.base_metrics_class import calculate_metrics_for_train
45
-
46
- from .base_detector import AbstractDetector
47
- from arena.detectors.UCF.detectors import DETECTOR
48
- from networks import BACKBONE
49
- from loss import LOSSFUNC
50
-
51
- logger = logging.getLogger(__name__)
52
-
53
- @DETECTOR.register_module(module_name='ucf')
54
- class UCFDetector(AbstractDetector):
55
- def __init__(self, config):
56
- super().__init__()
57
- self.config = config
58
- self.num_classes = config['backbone_config']['num_classes']
59
- self.encoder_feat_dim = config['encoder_feat_dim']
60
- self.half_fingerprint_dim = self.encoder_feat_dim//2
61
-
62
- self.encoder_f = self.build_backbone(config)
63
- self.encoder_c = self.build_backbone(config)
64
-
65
- self.loss_func = self.build_loss(config)
66
- self.prob, self.label = [], []
67
- self.correct, self.total = 0, 0
68
-
69
- # basic function
70
- self.lr = nn.LeakyReLU(inplace=True)
71
- self.do = nn.Dropout(0.2)
72
- self.pool = nn.AdaptiveAvgPool2d(1)
73
-
74
- # conditional gan
75
- self.con_gan = Conditional_UNet()
76
-
77
- # head
78
- specific_task_number = config['specific_task_number']
79
-
80
- self.head_spe = Head(
81
- in_f=self.half_fingerprint_dim,
82
- hidden_dim=self.encoder_feat_dim,
83
- out_f=specific_task_number
84
- )
85
- self.head_sha = Head(
86
- in_f=self.half_fingerprint_dim,
87
- hidden_dim=self.encoder_feat_dim,
88
- out_f=self.num_classes
89
- )
90
- self.block_spe = Conv2d1x1(
91
- in_f=self.encoder_feat_dim,
92
- hidden_dim=self.half_fingerprint_dim,
93
- out_f=self.half_fingerprint_dim
94
- )
95
- self.block_sha = Conv2d1x1(
96
- in_f=self.encoder_feat_dim,
97
- hidden_dim=self.half_fingerprint_dim,
98
- out_f=self.half_fingerprint_dim
99
- )
100
-
101
- def build_backbone(self, config):
102
- current_dir = os.path.dirname(os.path.abspath(__file__))
103
- pretrained_path = os.path.join(current_dir, config['pretrained'])
104
- # prepare the backbone
105
- backbone_class = BACKBONE[config['backbone_name']]
106
- model_config = config['backbone_config']
107
- backbone = backbone_class(model_config)
108
- # if donot load the pretrained weights, fail to get good results
109
- state_dict = torch.load(pretrained_path)
110
- for name, weights in state_dict.items():
111
- if 'pointwise' in name:
112
- state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
113
- state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k}
114
- backbone.load_state_dict(state_dict, False)
115
- logger.info('Load pretrained model successfully!')
116
- return backbone
117
-
118
- def build_loss(self, config):
119
- cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']]
120
- spe_loss_class = LOSSFUNC[config['loss_func']['spe_loss']]
121
- con_loss_class = LOSSFUNC[config['loss_func']['con_loss']]
122
- rec_loss_class = LOSSFUNC[config['loss_func']['rec_loss']]
123
- cls_loss_func = cls_loss_class()
124
- spe_loss_func = spe_loss_class()
125
- con_loss_func = con_loss_class(margin=3.0)
126
- rec_loss_func = rec_loss_class()
127
- loss_func = {
128
- 'cls': cls_loss_func,
129
- 'spe': spe_loss_func,
130
- 'con': con_loss_func,
131
- 'rec': rec_loss_func,
132
- }
133
- return loss_func
134
-
135
- def features(self, data_dict: dict) -> torch.tensor:
136
- cat_data = data_dict['image']
137
- # encoder
138
- f_all = self.encoder_f.features(cat_data)
139
- c_all = self.encoder_c.features(cat_data)
140
- feat_dict = {'forgery': f_all, 'content': c_all}
141
- return feat_dict
142
-
143
- def classifier(self, features: torch.tensor) -> torch.tensor:
144
- # classification, multi-task
145
- # split the features into the specific and common forgery
146
- f_spe = self.block_spe(features)
147
- f_share = self.block_sha(features)
148
- return f_spe, f_share
149
-
150
- def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
151
- if 'label_spe' in data_dict and 'recontruction_imgs' in pred_dict:
152
- return self.get_train_losses(data_dict, pred_dict)
153
- else: # test mode
154
- return self.get_test_losses(data_dict, pred_dict)
155
-
156
- def get_train_losses(self, data_dict: dict, pred_dict: dict) -> dict:
157
- # get combined, real, fake imgs
158
- cat_data = data_dict['image']
159
- real_img, fake_img = cat_data.chunk(2, dim=0)
160
- # get the reconstruction imgs
161
- reconstruction_image_1, \
162
- reconstruction_image_2, \
163
- self_reconstruction_image_1, \
164
- self_reconstruction_image_2 \
165
- = pred_dict['recontruction_imgs']
166
- # get label
167
- label = data_dict['label']
168
- label_spe = data_dict['label_spe']
169
- # get pred
170
- pred = pred_dict['cls']
171
- pred_spe = pred_dict['cls_spe']
172
-
173
- # 1. classification loss for common features
174
- loss_sha = self.loss_func['cls'](pred, label)
175
-
176
- # 2. classification loss for specific features
177
- loss_spe = self.loss_func['spe'](pred_spe, label_spe)
178
-
179
- # 3. reconstruction loss
180
- self_loss_reconstruction_1 = self.loss_func['rec'](fake_img, self_reconstruction_image_1)
181
- self_loss_reconstruction_2 = self.loss_func['rec'](real_img, self_reconstruction_image_2)
182
- cross_loss_reconstruction_1 = self.loss_func['rec'](fake_img, reconstruction_image_2)
183
- cross_loss_reconstruction_2 = self.loss_func['rec'](real_img, reconstruction_image_1)
184
- loss_reconstruction = \
185
- self_loss_reconstruction_1 + self_loss_reconstruction_2 + \
186
- cross_loss_reconstruction_1 + cross_loss_reconstruction_2
187
-
188
- # 4. constrative loss
189
- common_features = pred_dict['feat']
190
- specific_features = pred_dict['feat_spe']
191
- loss_con = self.loss_func['con'](common_features, specific_features, label_spe)
192
-
193
- # 5. total loss
194
- loss = loss_sha + 0.1*loss_spe + 0.3*loss_reconstruction + 0.05*loss_con
195
- loss_dict = {
196
- 'overall': loss,
197
- 'common': loss_sha,
198
- 'specific': loss_spe,
199
- 'reconstruction': loss_reconstruction,
200
- 'contrastive': loss_con,
201
- }
202
- return loss_dict
203
-
204
- def get_test_losses(self, data_dict: dict, pred_dict: dict) -> dict:
205
- # get label
206
- label = data_dict['label']
207
- # get pred
208
- pred = pred_dict['cls']
209
- # for test mode, only classification loss for common features
210
- loss = self.loss_func['cls'](pred, label)
211
- loss_dict = {'common': loss}
212
- return loss_dict
213
-
214
- def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
215
- def get_accracy(label, output):
216
- _, prediction = torch.max(output, 1) # argmax
217
- correct = (prediction == label).sum().item()
218
- accuracy = correct / prediction.size(0)
219
- return accuracy
220
-
221
- # get pred and label
222
- label = data_dict['label']
223
- pred = pred_dict['cls']
224
- label_spe = data_dict['label_spe']
225
- pred_spe = pred_dict['cls_spe']
226
-
227
- # compute metrics for batch data
228
- auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
229
- acc_spe = get_accracy(label_spe.detach(), pred_spe.detach())
230
- metric_batch_dict = {'acc': acc, 'acc_spe': acc_spe, 'auc': auc, 'eer': eer, 'ap': ap}
231
- # we dont compute the video-level metrics for training
232
- return metric_batch_dict
233
-
234
- def forward(self, data_dict: dict, inference=False) -> dict:
235
- # split the features into the content and forgery
236
- features = self.features(data_dict)
237
- forgery_features, content_features = features['forgery'], features['content']
238
- # get the prediction by classifier (split the common and specific forgery)
239
- f_spe, f_share = self.classifier(forgery_features)
240
-
241
- if inference:
242
- # inference only consider share loss
243
- out_sha, sha_feat = self.head_sha(f_share)
244
- out_spe, spe_feat = self.head_spe(f_spe)
245
- prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
246
- self.prob.append(
247
- prob_sha
248
- .detach()
249
- .squeeze()
250
- .cpu()
251
- .numpy()
252
- )
253
- _, prediction_class = torch.max(out_sha, 1)
254
- if 'label' in data_dict:
255
- self.label.append(
256
- data_dict['label']
257
- .detach()
258
- .squeeze()
259
- .cpu()
260
- .numpy()
261
- )
262
- # deal with acc
263
- common_label = (data_dict['label'] >= 1)
264
- correct = (prediction_class == common_label).sum().item()
265
- self.correct += correct
266
- self.total += data_dict['label'].size(0)
267
-
268
- pred_dict = {'cls': out_sha, 'feat': sha_feat}
269
- return pred_dict
270
-
271
- bs = f_share.size(0)
272
- # using idx aug in the training mode
273
- aug_idx = random.random()
274
- if aug_idx < 0.7:
275
- # real
276
- idx_list = list(range(0, bs//2))
277
- random.shuffle(idx_list)
278
- f_share[0: bs//2] = f_share[idx_list]
279
- # fake
280
- idx_list = list(range(bs//2, bs))
281
- random.shuffle(idx_list)
282
- f_share[bs//2: bs] = f_share[idx_list]
283
-
284
- # concat spe and share to obtain new_f_all
285
- f_all = torch.cat((f_spe, f_share), dim=1)
286
-
287
- # reconstruction loss
288
- f2, f1 = f_all.chunk(2, dim=0)
289
- c2, c1 = content_features.chunk(2, dim=0)
290
-
291
- # ==== self reconstruction ==== #
292
- # f1 + c1 -> f11, f11 + c1 -> near~I1
293
- self_reconstruction_image_1 = self.con_gan(f1, c1)
294
-
295
- # f2 + c2 -> f2, f2 + c2 -> near~I2
296
- self_reconstruction_image_2 = self.con_gan(f2, c2)
297
-
298
- # ==== cross combine ==== #
299
- reconstruction_image_1 = self.con_gan(f1, c2)
300
- reconstruction_image_2 = self.con_gan(f2, c1)
301
-
302
- # head for spe and sha
303
- out_spe, spe_feat = self.head_spe(f_spe)
304
- out_sha, sha_feat = self.head_sha(f_share)
305
-
306
- # get the probability of the pred
307
- prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
308
- prob_spe = torch.softmax(out_spe, dim=1)[:, 1]
309
-
310
- # build the prediction dict for each output
311
- pred_dict = {
312
- 'cls': out_sha,
313
- 'prob': prob_sha,
314
- 'feat': sha_feat,
315
- 'cls_spe': out_spe,
316
- 'prob_spe': prob_spe,
317
- 'feat_spe': spe_feat,
318
- 'feat_content': content_features,
319
- 'recontruction_imgs': (
320
- reconstruction_image_1,
321
- reconstruction_image_2,
322
- self_reconstruction_image_1,
323
- self_reconstruction_image_2
324
- )
325
- }
326
- return pred_dict
327
-
328
- def sn_double_conv(in_channels, out_channels):
329
- return nn.Sequential(
330
- nn.utils.spectral_norm(
331
- nn.Conv2d(in_channels, in_channels, 3, padding=1)),
332
- nn.utils.spectral_norm(
333
- nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=2)),
334
- nn.LeakyReLU(0.2, inplace=True)
335
- )
336
-
337
- def r_double_conv(in_channels, out_channels):
338
- return nn.Sequential(
339
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
340
- nn.ReLU(inplace=True),
341
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
342
- nn.ReLU(inplace=True)
343
- )
344
-
345
- class AdaIN(nn.Module):
346
- def __init__(self, eps=1e-5):
347
- super().__init__()
348
- self.eps = eps
349
- # self.l1 = nn.Linear(num_classes, in_channel*4, bias=True) #bias is good :)
350
-
351
- def c_norm(self, x, bs, ch, eps=1e-7):
352
- # assert isinstance(x, torch.cuda.FloatTensor)
353
- x_var = x.var(dim=-1) + eps
354
- x_std = x_var.sqrt().view(bs, ch, 1, 1)
355
- x_mean = x.mean(dim=-1).view(bs, ch, 1, 1)
356
- return x_std, x_mean
357
-
358
- def forward(self, x, y):
359
- assert x.size(0)==y.size(0)
360
- size = x.size()
361
- bs, ch = size[:2]
362
- x_ = x.view(bs, ch, -1)
363
- y_ = y.reshape(bs, ch, -1)
364
- x_std, x_mean = self.c_norm(x_, bs, ch, eps=self.eps)
365
- y_std, y_mean = self.c_norm(y_, bs, ch, eps=self.eps)
366
- out = ((x - x_mean.expand(size)) / x_std.expand(size)) \
367
- * y_std.expand(size) + y_mean.expand(size)
368
- return out
369
-
370
- class Conditional_UNet(nn.Module):
371
-
372
- def init_weight(self, std=0.2):
373
- for m in self.modules():
374
- cn = m.__class__.__name__
375
- if cn.find('Conv') != -1:
376
- m.weight.data.normal_(0., std)
377
- elif cn.find('Linear') != -1:
378
- m.weight.data.normal_(1., std)
379
- m.bias.data.fill_(0)
380
-
381
- def __init__(self):
382
- super(Conditional_UNet, self).__init__()
383
-
384
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
385
- self.maxpool = nn.MaxPool2d(2)
386
- self.dropout = nn.Dropout(p=0.3)
387
- #self.dropout_half = HalfDropout(p=0.3)
388
-
389
- self.adain3 = AdaIN()
390
- self.adain2 = AdaIN()
391
- self.adain1 = AdaIN()
392
-
393
- self.dconv_up3 = r_double_conv(512, 256)
394
- self.dconv_up2 = r_double_conv(256, 128)
395
- self.dconv_up1 = r_double_conv(128, 64)
396
-
397
- self.conv_last = nn.Conv2d(64, 3, 1)
398
- self.up_last = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
399
- self.activation = nn.Tanh()
400
- #self.init_weight()
401
-
402
- def forward(self, c, x): # c is the style and x is the content
403
- x = self.adain3(x, c)
404
- x = self.upsample(x)
405
- x = self.dropout(x)
406
- x = self.dconv_up3(x)
407
- c = self.upsample(c)
408
- c = self.dropout(c)
409
- c = self.dconv_up3(c)
410
-
411
- x = self.adain2(x, c)
412
- x = self.upsample(x)
413
- x = self.dropout(x)
414
- x = self.dconv_up2(x)
415
- c = self.upsample(c)
416
- c = self.dropout(c)
417
- c = self.dconv_up2(c)
418
-
419
- x = self.adain1(x, c)
420
- x = self.upsample(x)
421
- x = self.dropout(x)
422
- x = self.dconv_up1(x)
423
-
424
- x = self.conv_last(x)
425
- out = self.up_last(x)
426
-
427
- return self.activation(out)
428
-
429
- class MLP(nn.Module):
430
- def __init__(self, in_f, hidden_dim, out_f):
431
- super(MLP, self).__init__()
432
- self.pool = nn.AdaptiveAvgPool2d(1)
433
- self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
434
- nn.LeakyReLU(inplace=True),
435
- nn.Linear(hidden_dim, hidden_dim),
436
- nn.LeakyReLU(inplace=True),
437
- nn.Linear(hidden_dim, out_f),)
438
-
439
- def forward(self, x):
440
- x = self.pool(x)
441
- x = self.mlp(x)
442
- return x
443
-
444
- class Conv2d1x1(nn.Module):
445
- def __init__(self, in_f, hidden_dim, out_f):
446
- super(Conv2d1x1, self).__init__()
447
- self.conv2d = nn.Sequential(nn.Conv2d(in_f, hidden_dim, 1, 1),
448
- nn.LeakyReLU(inplace=True),
449
- nn.Conv2d(hidden_dim, hidden_dim, 1, 1),
450
- nn.LeakyReLU(inplace=True),
451
- nn.Conv2d(hidden_dim, out_f, 1, 1),)
452
-
453
- def forward(self, x):
454
- x = self.conv2d(x)
455
- return x
456
-
457
- class Head(nn.Module):
458
- def __init__(self, in_f, hidden_dim, out_f):
459
- super(Head, self).__init__()
460
- self.do = nn.Dropout(0.2)
461
- self.pool = nn.AdaptiveAvgPool2d(1)
462
- self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
463
- nn.LeakyReLU(inplace=True),
464
- nn.Linear(hidden_dim, out_f),)
465
-
466
- def forward(self, x):
467
- bs = x.size()[0]
468
- x_feat = self.pool(x).view(bs, -1)
469
- x = self.mlp(x_feat)
470
- x = self.do(x)
471
- return x, x_feat
472
-
test_deps/logger.py DELETED
@@ -1,36 +0,0 @@
- import os
- import logging
-
- import torch.distributed as dist
-
- class RankFilter(logging.Filter):
-     def __init__(self, rank):
-         super().__init__()
-         self.rank = rank
-
-     def filter(self, record):
-         return dist.get_rank() == self.rank
-
- def create_logger(log_path):
-     # Create log path
-     if os.path.isdir(os.path.dirname(log_path)):
-         os.makedirs(os.path.dirname(log_path), exist_ok=True)
-
-     # Create logger object
-     logger = logging.getLogger()
-     logger.setLevel(logging.INFO)
-     # Create file handler and set the formatter
-     fh = logging.FileHandler(log_path)
-     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-     fh.setFormatter(formatter)
-
-     # Add the file handler to the logger
-     logger.addHandler(fh)
-
-     # Add a stream handler to print to console
-     sh = logging.StreamHandler()
-     sh.setLevel(logging.INFO)  # Set logging level for stream handler
-     sh.setFormatter(formatter)
-     logger.addHandler(sh)
-
-     return logger
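A brief usage sketch of `create_logger` (the log path is made up). Note that the `os.path.isdir` guard above only calls `makedirs` when the directory already exists, so in practice the parent folder should be created beforehand or `FileHandler` will raise:

```python
# Illustrative only: ensure the log directory exists, then log to file and console.
import os

os.makedirs("./logs/training", exist_ok=True)
logger = create_logger("./logs/training/run.log")
logger.info("training started")  # written to the file and echoed to stderr
```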
test_deps/loss/__init__.py DELETED
@@ -1,13 +0,0 @@
- import os
- import sys
- current_file_path = os.path.abspath(__file__)
- parent_dir = os.path.dirname(os.path.dirname(current_file_path))
- project_root_dir = os.path.dirname(parent_dir)
- sys.path.append(parent_dir)
- sys.path.append(project_root_dir)
-
- from metrics.registry import LOSSFUNC
-
- from .cross_entropy_loss import CrossEntropyLoss
- from .contrastive_regularization import ContrastiveLoss
- from .l1_loss import L1Loss
test_deps/loss/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (565 Bytes)
 
test_deps/loss/__pycache__/abstract_loss_func.cpython-310.pyc DELETED
Binary file (977 Bytes)
 
test_deps/loss/__pycache__/contrastive_regularization.cpython-310.pyc DELETED
Binary file (2.38 kB)
 
test_deps/loss/__pycache__/cross_entropy_loss.cpython-310.pyc DELETED
Binary file (1.26 kB)
 
test_deps/loss/__pycache__/l1_loss.cpython-310.pyc DELETED
Binary file (892 Bytes)
 
test_deps/loss/abstract_loss_func.py DELETED
@@ -1,17 +0,0 @@
- import torch.nn as nn
-
- class AbstractLossClass(nn.Module):
-     """Abstract class for loss functions."""
-     def __init__(self):
-         super(AbstractLossClass, self).__init__()
-
-     def forward(self, pred, label):
-         """
-         Args:
-             pred: prediction of the model
-             label: ground truth label
-
-         Return:
-             loss: loss value
-         """
-         raise NotImplementedError('Each subclass should implement the forward method.')
test_deps/loss/contrastive_regularization.py DELETED
@@ -1,78 +0,0 @@
- import random
- from collections import defaultdict
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .abstract_loss_func import AbstractLossClass
- from metrics.registry import LOSSFUNC
-
-
- def swap_spe_features(type_list, value_list):
-     type_list = type_list.cpu().numpy().tolist()
-     # get index
-     index_list = list(range(len(type_list)))
-
-     # init a dict, where its key is the type and value is the index
-     spe_dict = defaultdict(list)
-
-     # do for-loop to get spe dict
-     for i, one_type in enumerate(type_list):
-         spe_dict[one_type].append(index_list[i])
-
-     # shuffle the value list of each key
-     for keys in spe_dict.keys():
-         random.shuffle(spe_dict[keys])
-
-     # generate a new index list for the value list
-     new_index_list = []
-     for one_type in type_list:
-         value = spe_dict[one_type].pop()
-         new_index_list.append(value)
-
-     # swap the value_list by new_index_list
-     value_list_new = value_list[new_index_list]
-
-     return value_list_new
-
-
- @LOSSFUNC.register_module(module_name="contrastive_regularization")
- class ContrastiveLoss(AbstractLossClass):
-     def __init__(self, margin=1.0):
-         super().__init__()
-         self.margin = margin
-
-     def contrastive_loss(self, anchor, positive, negative):
-         dist_pos = F.pairwise_distance(anchor, positive)
-         dist_neg = F.pairwise_distance(anchor, negative)
-         # Compute loss as the distance between anchor and negative minus the distance between anchor and positive
-         loss = torch.mean(torch.clamp(dist_pos - dist_neg + self.margin, min=0.0))
-         return loss
-
-     def forward(self, common, specific, spe_label):
-         # prepare
-         bs = common.shape[0]
-         real_common, fake_common = common.chunk(2)
-         ### common real
-         idx_list = list(range(0, bs//2))
-         random.shuffle(idx_list)
-         real_common_anchor = common[idx_list]
-         ### common fake
-         idx_list = list(range(bs//2, bs))
-         random.shuffle(idx_list)
-         fake_common_anchor = common[idx_list]
-         ### specific
-         specific_anchor = swap_spe_features(spe_label, specific)
-         real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2)
-         real_specific, fake_specific = specific.chunk(2)
-
-         # Compute the contrastive loss of common between real and fake
-         loss_realcommon = self.contrastive_loss(real_common, real_common_anchor, fake_common_anchor)
-         loss_fakecommon = self.contrastive_loss(fake_common, fake_common_anchor, real_common_anchor)
-
-         # Comupte the constrastive loss of specific between real and fake
-         loss_realspecific = self.contrastive_loss(real_specific, real_specific_anchor, fake_specific_anchor)
-         loss_fakespecific = self.contrastive_loss(fake_specific, fake_specific_anchor, real_specific_anchor)
-
-         # Compute the final loss as the sum of all contrastive losses
-         loss = loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific
-         return loss
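A self-contained sketch of exercising `ContrastiveLoss` on random features. Shapes are assumed (first half of the batch real, second half fake, mirroring how the detector builds its batches), and `margin=3.0` matches the value passed by `build_loss` in `ucf_detector.py`:

```python
# Illustrative only: contrastive regularization on dummy feature batches.
import torch

bs, feat_dim = 8, 512
common = torch.randn(bs, feat_dim)    # shared ("common") forgery features
specific = torch.randn(bs, feat_dim)  # method-specific forgery features
spe_label = torch.tensor([0, 0, 0, 0, 1, 1, 2, 2])  # 0 = real, >= 1 = individual fake methods

criterion = ContrastiveLoss(margin=3.0)
print(criterion(common, specific, spe_label))
```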
test_deps/loss/cross_entropy_loss.py DELETED
@@ -1,26 +0,0 @@
- import torch.nn as nn
- from .abstract_loss_func import AbstractLossClass
- from metrics.registry import LOSSFUNC
-
-
- @LOSSFUNC.register_module(module_name="cross_entropy")
- class CrossEntropyLoss(AbstractLossClass):
-     def __init__(self):
-         super().__init__()
-         self.loss_fn = nn.CrossEntropyLoss()
-
-     def forward(self, inputs, targets):
-         """
-         Computes the cross-entropy loss.
-
-         Args:
-             inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores.
-             targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices.
-
-         Returns:
-             A scalar tensor representing the cross-entropy loss.
-         """
-         # Compute the cross-entropy loss
-         loss = self.loss_fn(inputs, targets)
-
-         return loss
test_deps/loss/l1_loss.py DELETED
@@ -1,19 +0,0 @@
- import torch.nn as nn
- from .abstract_loss_func import AbstractLossClass
- from metrics.registry import LOSSFUNC
-
-
- @LOSSFUNC.register_module(module_name="l1loss")
- class L1Loss(AbstractLossClass):
-     def __init__(self):
-         super().__init__()
-         self.loss_fn = nn.L1Loss()
-
-     def forward(self, inputs, targets):
-         """
-         Computes the l1 loss.
-         """
-         # Compute the l1 loss
-         loss = self.loss_fn(inputs, targets)
-
-         return loss
test_deps/metrics/__init__.py DELETED
@@ -1,7 +0,0 @@
- import os
- import sys
- current_file_path = os.path.abspath(__file__)
- parent_dir = os.path.dirname(os.path.dirname(current_file_path))
- project_root_dir = os.path.dirname(parent_dir)
- sys.path.append(parent_dir)
- sys.path.append(project_root_dir)
test_deps/metrics/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (351 Bytes)
 
test_deps/metrics/__pycache__/base_metrics_class.cpython-310.pyc DELETED
Binary file (6.21 kB)
 
test_deps/metrics/__pycache__/registry.cpython-310.pyc DELETED
Binary file (1.01 kB)
 
test_deps/metrics/base_metrics_class.py DELETED
@@ -1,205 +0,0 @@
1
- import numpy as np
2
- from sklearn import metrics
3
- from collections import defaultdict
4
- import torch
5
- import torch.nn as nn
6
-
7
-
8
- def get_accracy(output, label):
9
- _, prediction = torch.max(output, 1) # argmax
10
- correct = (prediction == label).sum().item()
11
- accuracy = correct / prediction.size(0)
12
- return accuracy
13
-
14
-
15
- def get_prediction(output, label):
16
- prob = nn.functional.softmax(output, dim=1)[:, 1]
17
- prob = prob.view(prob.size(0), 1)
18
- label = label.view(label.size(0), 1)
19
- #print(prob.size(), label.size())
20
- datas = torch.cat((prob, label.float()), dim=1)
21
- return datas
22
-
23
-
24
- def calculate_metrics_for_train(label, output):
25
- if output.size(1) == 2:
26
- prob = torch.softmax(output, dim=1)[:, 1]
27
- else:
28
- prob = output
29
-
30
- # Accuracy
31
- _, prediction = torch.max(output, 1)
32
- correct = (prediction == label).sum().item()
33
- accuracy = correct / prediction.size(0)
34
-
35
- # Average Precision
36
- y_true = label.cpu().detach().numpy()
37
- y_pred = prob.cpu().detach().numpy()
38
- ap = metrics.average_precision_score(y_true, y_pred)
39
-
40
- # AUC and EER
41
- try:
42
- fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(),
43
- prob.squeeze().cpu().numpy(),
44
- pos_label=1)
45
- except:
46
- # for the case when we only have one sample
47
- return None, None, accuracy, ap
48
-
49
- if np.isnan(fpr[0]) or np.isnan(tpr[0]):
50
- # for the case when all the samples within a batch is fake/real
51
- auc, eer = None, None
52
- else:
53
- auc = metrics.auc(fpr, tpr)
54
- fnr = 1 - tpr
55
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
56
-
57
- return auc, eer, accuracy, ap
58
-
59
-
60
- # ------------ compute average metrics of batches---------------------
61
- class Metrics_batch():
62
- def __init__(self):
63
- self.tprs = []
64
- self.mean_fpr = np.linspace(0, 1, 100)
65
- self.aucs = []
66
- self.eers = []
67
- self.aps = []
68
-
69
- self.correct = 0
70
- self.total = 0
71
- self.losses = []
72
-
73
- def update(self, label, output):
74
- acc = self._update_acc(label, output)
75
- if output.size(1) == 2:
76
- prob = torch.softmax(output, dim=1)[:, 1]
77
- else:
78
- prob = output
79
- #label = 1-label
80
- #prob = torch.softmax(output, dim=1)[:, 1]
81
- auc, eer = self._update_auc(label, prob)
82
- ap = self._update_ap(label, prob)
83
-
84
- return acc, auc, eer, ap
85
-
86
- def _update_auc(self, lab, prob):
87
- fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(),
88
- prob.squeeze().cpu().numpy(),
89
- pos_label=1)
90
- if np.isnan(fpr[0]) or np.isnan(tpr[0]):
91
- return -1, -1
92
-
93
- auc = metrics.auc(fpr, tpr)
94
- interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
95
- interp_tpr[0] = 0.0
96
- self.tprs.append(interp_tpr)
97
- self.aucs.append(auc)
98
-
99
- # return auc
100
-
101
- # EER
102
- fnr = 1 - tpr
103
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
104
- self.eers.append(eer)
105
-
106
- return auc, eer
107
-
108
- def _update_acc(self, lab, output):
109
- _, prediction = torch.max(output, 1) # argmax
110
- correct = (prediction == lab).sum().item()
111
- accuracy = correct / prediction.size(0)
112
- # self.accs.append(accuracy)
113
- self.correct = self.correct+correct
114
- self.total = self.total+lab.size(0)
115
- return accuracy
116
-
117
- def _update_ap(self, label, prob):
118
- y_true = label.cpu().detach().numpy()
119
- y_pred = prob.cpu().detach().numpy()
120
- ap = metrics.average_precision_score(y_true,y_pred)
121
- self.aps.append(ap)
122
-
123
- return np.mean(ap)
124
-
125
- def get_mean_metrics(self):
126
- mean_acc, std_acc = self.correct/self.total, 0
127
- mean_auc, std_auc = self._mean_auc()
128
- mean_err, std_err = np.mean(self.eers), np.std(self.eers)
129
- mean_ap, std_ap = np.mean(self.aps), np.std(self.aps)
130
-
131
- return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap}
132
-
133
- def _mean_auc(self):
134
- mean_tpr = np.mean(self.tprs, axis=0)
135
- mean_tpr[-1] = 1.0
136
- mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
137
- std_auc = np.std(self.aucs)
138
- return mean_auc, std_auc
139
-
140
- def clear(self):
141
- self.tprs.clear()
142
- self.aucs.clear()
143
- # self.accs.clear()
144
- self.correct=0
145
- self.total=0
146
- self.eers.clear()
147
- self.aps.clear()
148
- self.losses.clear()
149
-
150
-
151
- # ------------ compute average metrics of all data ---------------------
152
- class Metrics_all():
153
- def __init__(self):
154
- self.probs = []
155
- self.labels = []
156
- self.correct = 0
157
- self.total = 0
158
-
159
- def store(self, label, output):
160
- prob = torch.softmax(output, dim=1)[:, 1]
161
- _, prediction = torch.max(output, 1) # argmax
162
- correct = (prediction == label).sum().item()
163
- self.correct += correct
164
- self.total += label.size(0)
165
- self.labels.append(label.squeeze().cpu().numpy())
166
- self.probs.append(prob.squeeze().cpu().numpy())
167
-
168
- def get_metrics(self):
169
- y_pred = np.concatenate(self.probs)
170
- y_true = np.concatenate(self.labels)
171
- # auc
172
- fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1)
173
- auc = metrics.auc(fpr, tpr)
174
- # eer
175
- fnr = 1 - tpr
176
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
177
- # ap
178
- ap = metrics.average_precision_score(y_true,y_pred)
179
- # acc
180
- acc = self.correct / self.total
181
- return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap}
182
-
183
- def clear(self):
184
- self.probs.clear()
185
- self.labels.clear()
186
- self.correct = 0
187
- self.total = 0
188
-
189
-
190
- # only used to record a series of scalar value
191
- class Recorder:
192
- def __init__(self):
193
- self.sum = 0
194
- self.num = 0
195
- def update(self, item, num=1):
196
- if item is not None:
197
- self.sum += item * num
198
- self.num += num
199
- def average(self):
200
- if self.num == 0:
201
- return None
202
- return self.sum/self.num
203
- def clear(self):
204
- self.sum = 0
205
- self.num = 0
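A small sketch of calling `calculate_metrics_for_train` from this module on a dummy batch (shapes assumed: binary labels and raw two-class logits):

```python
# Illustrative only: per-batch AUC / EER / accuracy / AP from logits.
import torch

labels = torch.tensor([0, 0, 1, 1])
logits = torch.randn(4, 2)  # (batch_size, num_classes) raw scores
auc, eer, acc, ap = calculate_metrics_for_train(labels, logits)
print(auc, eer, acc, ap)
```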
test_deps/metrics/registry.py DELETED
@@ -1,20 +0,0 @@
- class Registry(object):
-     def __init__(self):
-         self.data = {}
-
-     def register_module(self, module_name=None):
-         def _register(cls):
-             name = module_name
-             if module_name is None:
-                 name = cls.__name__
-             self.data[name] = cls
-             return cls
-         return _register
-
-     def __getitem__(self, key):
-         return self.data[key]
-
- BACKBONE = Registry()
- DETECTOR = Registry()
- TRAINER = Registry()
- LOSSFUNC = Registry()
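This registry is what the `@DETECTOR.register_module(module_name='ucf')`, `@BACKBONE.register_module(...)`, and `@LOSSFUNC.register_module(...)` decorators in the files above hook into. A tiny self-contained usage sketch:

```python
# Illustrative only: register a class under a name, then look it up by key.
DEMO = Registry()

@DEMO.register_module(module_name="toy")
class ToyModel:
    pass

model_cls = DEMO["toy"]            # resolved via Registry.__getitem__
print(model_cls().__class__.__name__)  # -> "ToyModel"
```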
test_deps/metrics/utils.py DELETED
@@ -1,88 +0,0 @@
- from sklearn import metrics
- import numpy as np
-
-
- def parse_metric_for_print(metric_dict):
-     if metric_dict is None:
-         return "\n"
-     str = "\n"
-     str += "================================ Each dataset best metric ================================ \n"
-     for key, value in metric_dict.items():
-         if key != 'avg':
-             str = str + f"| {key}: "
-             for k, v in value.items():
-                 str = str + f" {k}={v} "
-             str = str + "| \n"
-         else:
-             str += "============================================================================================= \n"
-             str += "================================== Average best metric ====================================== \n"
-             avg_dict = value
-             for avg_key, avg_value in avg_dict.items():
-                 if avg_key == 'dataset_dict':
-                     for key, value in avg_value.items():
-                         str = str + f"| {key}: {value} | \n"
-                 else:
-                     str = str + f"| avg {avg_key}: {avg_value} | \n"
-     str += "============================================================================================="
-     return str
-
-
- def get_test_metrics(y_pred, y_true, img_names=None, logger=None):
-     def get_video_metrics(image, pred, label):
-         result_dict = {}
-         new_label = []
-         new_pred = []
-         # print(image[0])
-         # print(pred.shape)
-         # print(label.shape)
-         for item in np.transpose(np.stack((image, pred, label)), (1, 0)):
-
-             s = item[0]
-             if '\\' in s:
-                 parts = s.split('\\')
-             else:
-                 parts = s.split('/')
-             a = parts[-2]
-             b = parts[-1]
-
-             if a not in result_dict:
-                 result_dict[a] = []
-
-             result_dict[a].append(item)
-         image_arr = list(result_dict.values())
-
-         for video in image_arr:
-             pred_sum = 0
-             label_sum = 0
-             leng = 0
-             for frame in video:
-                 pred_sum += float(frame[1])
-                 label_sum += int(frame[2])
-                 leng += 1
-             new_pred.append(pred_sum / leng)
-             new_label.append(int(label_sum / leng))
-         fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred)
-         v_auc = metrics.auc(fpr, tpr)
-         fnr = 1 - tpr
-         v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
-         return v_auc, v_eer
-
-
-     y_pred = y_pred.squeeze()
-
-     # For UCF, where labels for different manipulations are not consistent.
-     y_true[y_true >= 1] = 1
-     # auc
-     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
-     auc = metrics.auc(fpr, tpr)
-     # eer
-     fnr = 1 - tpr
-     eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
-     # ap
-     ap = metrics.average_precision_score(y_true, y_pred)
-     # acc
-     prediction_class = (y_pred > 0.5).astype(int)
-     correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
-     acc = correct / len(prediction_class)
-
-     return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true}
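A short sketch of calling `get_test_metrics` on NumPy arrays (values made up; note the function binarizes `y_true` in place, treating any label >= 1 as fake):

```python
# Illustrative only: image-level test metrics from predicted fake-probabilities.
import numpy as np

y_pred = np.array([0.1, 0.4, 0.8, 0.9])  # probability of "fake"
y_true = np.array([0, 0, 1, 2])          # manipulation labels
results = get_test_metrics(y_pred, y_true)
print({k: results[k] for k in ("acc", "auc", "eer", "ap")})
```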
test_deps/networks/__init__.py DELETED
@@ -1,11 +0,0 @@
- import os
- import sys
- current_file_path = os.path.abspath(__file__)
- parent_dir = os.path.dirname(os.path.dirname(current_file_path))
- project_root_dir = os.path.dirname(parent_dir)
- sys.path.append(parent_dir)
- sys.path.append(project_root_dir)
-
- from metrics.registry import BACKBONE
-
- from .xception import Xception
test_deps/networks/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (447 Bytes)
 
test_deps/networks/__pycache__/xception.cpython-310.pyc DELETED
Binary file (6.7 kB)
 
test_deps/networks/xception.py DELETED
@@ -1,285 +0,0 @@
1
- '''
2
- # author: Zhiyuan Yan
3
- # email: zhiyuanyan@link.cuhk.edu.cn
4
- # date: 2023-0706
5
-
6
- The code is mainly modified from GitHub link below:
7
- https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py
8
- '''
9
-
10
- import os
11
- import argparse
12
- import logging
13
-
14
- import math
15
- import torch
16
- # import pretrainedmodels
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
-
20
- import torch.utils.model_zoo as model_zoo
21
- from torch.nn import init
22
- from typing import Union
23
- from metrics.registry import BACKBONE
24
-
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
-
29
- class SeparableConv2d(nn.Module):
30
- def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
31
- super(SeparableConv2d, self).__init__()
32
-
33
- self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
34
- stride, padding, dilation, groups=in_channels, bias=bias)
35
- self.pointwise = nn.Conv2d(
36
- in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
37
-
38
- def forward(self, x):
39
- x = self.conv1(x)
40
- x = self.pointwise(x)
41
- return x
42
-
43
-
44
- class Block(nn.Module):
45
- def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
46
- super(Block, self).__init__()
47
-
48
- if out_filters != in_filters or strides != 1:
49
- self.skip = nn.Conv2d(in_filters, out_filters,
50
- 1, stride=strides, bias=False)
51
- self.skipbn = nn.BatchNorm2d(out_filters)
52
- else:
53
- self.skip = None
54
-
55
- self.relu = nn.ReLU(inplace=True)
56
- rep = []
57
-
58
- filters = in_filters
59
- if grow_first: # whether the number of filters grows first
60
- rep.append(self.relu)
61
- rep.append(SeparableConv2d(in_filters, out_filters,
62
- 3, stride=1, padding=1, bias=False))
63
- rep.append(nn.BatchNorm2d(out_filters))
64
- filters = out_filters
65
-
66
- for i in range(reps-1):
67
- rep.append(self.relu)
68
- rep.append(SeparableConv2d(filters, filters,
69
- 3, stride=1, padding=1, bias=False))
70
- rep.append(nn.BatchNorm2d(filters))
71
-
72
- if not grow_first:
73
- rep.append(self.relu)
74
- rep.append(SeparableConv2d(in_filters, out_filters,
75
- 3, stride=1, padding=1, bias=False))
76
- rep.append(nn.BatchNorm2d(out_filters))
77
-
78
- if not start_with_relu:
79
- rep = rep[1:]
80
- else:
81
- rep[0] = nn.ReLU(inplace=False)
82
-
83
- if strides != 1:
84
- rep.append(nn.MaxPool2d(3, strides, 1))
85
- self.rep = nn.Sequential(*rep)
86
-
87
- def forward(self, inp):
88
- x = self.rep(inp)
89
-
90
- if self.skip is not None:
91
- skip = self.skip(inp)
92
- skip = self.skipbn(skip)
93
- else:
94
- skip = inp
95
-
96
- x += skip
97
- return x
98
-
99
- def add_gaussian_noise(ins, mean=0, stddev=0.2):
100
- noise = ins.data.new(ins.size()).normal_(mean, stddev)
101
- return ins + noise
102
-
103
-
104
- @BACKBONE.register_module(module_name="xception")
105
- class Xception(nn.Module):
106
- """
107
- Xception optimized for the ImageNet dataset, as specified in
108
- https://arxiv.org/pdf/1610.02357.pdf
109
- """
110
-
111
- def __init__(self, xception_config):
112
- """ Constructor
113
- Args:
114
- xception_config: configuration file with the dict format
115
- """
116
- super(Xception, self).__init__()
117
- self.num_classes = xception_config["num_classes"]
118
- self.mode = xception_config["mode"]
119
- inc = xception_config["inc"]
120
- dropout = xception_config["dropout"]
121
-
122
- # Entry flow
123
- self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
124
-
125
- self.bn1 = nn.BatchNorm2d(32)
126
- self.relu = nn.ReLU(inplace=True)
127
-
128
- self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
129
- self.bn2 = nn.BatchNorm2d(64)
130
- # do relu here
131
-
132
- self.block1 = Block(
133
- 64, 128, 2, 2, start_with_relu=False, grow_first=True)
134
- self.block2 = Block(
135
- 128, 256, 2, 2, start_with_relu=True, grow_first=True)
136
- self.block3 = Block(
137
- 256, 728, 2, 2, start_with_relu=True, grow_first=True)
138
-
139
- # middle flow
140
- self.block4 = Block(
141
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
142
- self.block5 = Block(
143
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
144
- self.block6 = Block(
145
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
146
- self.block7 = Block(
147
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
148
-
149
- self.block8 = Block(
150
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
151
- self.block9 = Block(
152
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
153
- self.block10 = Block(
154
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
155
- self.block11 = Block(
156
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
157
-
158
- # Exit flow
159
- self.block12 = Block(
160
- 728, 1024, 2, 2, start_with_relu=True, grow_first=False)
161
-
162
- self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
163
- self.bn3 = nn.BatchNorm2d(1536)
164
-
165
- # do relu here
166
- self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
167
- self.bn4 = nn.BatchNorm2d(2048)
168
- # used for iid
169
- final_channel = 2048
170
- if self.mode == 'adjust_channel_iid':
171
- final_channel = 512
172
- self.mode = 'adjust_channel'
173
- self.last_linear = nn.Linear(final_channel, self.num_classes)
174
- if dropout:
175
- self.last_linear = nn.Sequential(
176
- nn.Dropout(p=dropout),
177
- nn.Linear(final_channel, self.num_classes)
178
- )
179
-
180
- self.adjust_channel = nn.Sequential(
181
- nn.Conv2d(2048, 512, 1, 1),
182
- nn.BatchNorm2d(512),
183
- nn.ReLU(inplace=False),
184
- )
185
-
186
- def fea_part1_0(self, x):
187
- x = self.conv1(x)
188
- x = self.bn1(x)
189
- x = self.relu(x)
190
-
191
- return x
192
-
193
- def fea_part1_1(self, x):
194
-
195
- x = self.conv2(x)
196
- x = self.bn2(x)
197
- x = self.relu(x)
198
-
199
- return x
200
-
201
- def fea_part1(self, x):
202
- x = self.conv1(x)
203
- x = self.bn1(x)
204
- x = self.relu(x)
205
-
206
- x = self.conv2(x)
207
- x = self.bn2(x)
208
- x = self.relu(x)
209
-
210
- return x
211
-
212
- def fea_part2(self, x):
213
- x = self.block1(x)
214
- x = self.block2(x)
215
- x = self.block3(x)
216
-
217
- return x
218
-
219
- def fea_part3(self, x):
220
- if self.mode == "shallow_xception":
221
- return x
222
- else:
223
- x = self.block4(x)
224
- x = self.block5(x)
225
- x = self.block6(x)
226
- x = self.block7(x)
227
- return x
228
-
229
- def fea_part4(self, x):
230
- if self.mode == "shallow_xception":
231
- x = self.block12(x)
232
- else:
233
- x = self.block8(x)
234
- x = self.block9(x)
235
- x = self.block10(x)
236
- x = self.block11(x)
237
- x = self.block12(x)
238
- return x
239
-
240
- def fea_part5(self, x):
241
- x = self.conv3(x)
242
- x = self.bn3(x)
243
- x = self.relu(x)
244
-
245
- x = self.conv4(x)
246
- x = self.bn4(x)
247
-
248
- return x
249
-
250
- def features(self, input):
251
- x = self.fea_part1(input)
252
-
253
- x = self.fea_part2(x)
254
- x = self.fea_part3(x)
255
- x = self.fea_part4(x)
256
-
257
- x = self.fea_part5(x)
258
-
259
- if self.mode == 'adjust_channel':
260
- x = self.adjust_channel(x)
261
-
262
- return x
263
-
264
- def classifier(self, features,id_feat=None):
265
- # for iid
266
- if self.mode == 'adjust_channel':
267
- x = features
268
- else:
269
- x = self.relu(features)
270
-
271
- if len(x.shape) == 4:
272
- x = F.adaptive_avg_pool2d(x, (1, 1))
273
- x = x.view(x.size(0), -1)
274
- self.last_emb = x
275
- # for iid
276
- if id_feat is not None:
277
- out = self.last_linear(x-id_feat)
278
- else:
279
- out = self.last_linear(x)
280
- return out
281
-
282
- def forward(self, input):
283
- x = self.features(input)
284
- out = self.classifier(x)
285
- return out, x
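As a usage sketch (not part of the original sources): the `Xception` backbone above is built from a plain config dict, and its `forward` returns both class logits and the pre-pooling feature map. Assuming the module is importable (e.g. from a revision before this deletion), a forward pass looks roughly like this; the config values are illustrative:

```python
import torch
from networks.xception import Xception  # path as it existed before this deletion

xception_config = {
    "num_classes": 2,    # real vs. fake
    "mode": "original",  # anything other than the special modes handled above keeps the 2048-d head
    "inc": 3,            # RGB input channels
    "dropout": 0.5,      # wraps the final linear layer in Dropout
}

model = Xception(xception_config).eval()
dummy = torch.randn(2, 3, 256, 256)      # NCHW batch of two images
with torch.no_grad():
    logits, features = model(dummy)      # forward() returns (class logits, feature map)
print(logits.shape, features.shape)      # roughly: torch.Size([2, 2]) torch.Size([2, 2048, 8, 8])
```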
 
test_deps/optimizor/LinearLR.py DELETED
@@ -1,20 +0,0 @@
1
- import torch
2
- from torch.optim import SGD
3
- from torch.optim.lr_scheduler import _LRScheduler
4
-
5
- class LinearDecayLR(_LRScheduler):
6
- def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
7
- self.start_decay=start_decay
8
- self.n_epoch=n_epoch
9
- super(LinearDecayLR, self).__init__(optimizer, last_epoch)
10
-
11
- def get_lr(self):
12
- last_epoch = self.last_epoch
13
- n_epoch=self.n_epoch
14
- b_lr=self.base_lrs[0]
15
- start_decay=self.start_decay
16
- if last_epoch>start_decay:
17
- lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay)
18
- else:
19
- lr=b_lr
20
- return [lr]
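A small usage sketch of the scheduler above: the learning rate holds at its base value until `start_decay`, then decays linearly toward zero at `n_epoch`. The values below are illustrative and assume the module is importable as it was before this deletion:

```python
import torch
from torch.optim import SGD
from optimizor.LinearLR import LinearDecayLR  # path as it existed before this deletion

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
# hold lr at 0.1 for the first 5 epochs, then decay linearly toward 0 by epoch 10
scheduler = LinearDecayLR(optimizer, n_epoch=10, start_decay=5)

for epoch in range(10):
    optimizer.step()   # stand-in for one real training epoch
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])
```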
 
test_deps/optimizor/SAM.py DELETED
@@ -1,77 +0,0 @@
1
- # borrowed from
2
-
3
- import torch
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- def disable_running_stats(model):
9
- def _disable(module):
10
- if isinstance(module, nn.BatchNorm2d):
11
- module.backup_momentum = module.momentum
12
- module.momentum = 0
13
-
14
- model.apply(_disable)
15
-
16
- def enable_running_stats(model):
17
- def _enable(module):
18
- if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
19
- module.momentum = module.backup_momentum
20
-
21
- model.apply(_enable)
22
-
23
- class SAM(torch.optim.Optimizer):
24
- def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
25
- assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
26
-
27
- defaults = dict(rho=rho, **kwargs)
28
- super(SAM, self).__init__(params, defaults)
29
-
30
- self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
31
- self.param_groups = self.base_optimizer.param_groups
32
-
33
- @torch.no_grad()
34
- def first_step(self, zero_grad=False):
35
- grad_norm = self._grad_norm()
36
- for group in self.param_groups:
37
- scale = group["rho"] / (grad_norm + 1e-12)
38
-
39
- for p in group["params"]:
40
- if p.grad is None: continue
41
- e_w = p.grad * scale.to(p)
42
- p.add_(e_w) # climb to the local maximum "w + e(w)"
43
- self.state[p]["e_w"] = e_w
44
-
45
- if zero_grad: self.zero_grad()
46
-
47
- @torch.no_grad()
48
- def second_step(self, zero_grad=False):
49
- for group in self.param_groups:
50
- for p in group["params"]:
51
- if p.grad is None: continue
52
- p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
53
-
54
- self.base_optimizer.step() # do the actual "sharpness-aware" update
55
-
56
- if zero_grad: self.zero_grad()
57
-
58
- @torch.no_grad()
59
- def step(self, closure=None):
60
- assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
61
- closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
62
-
63
- self.first_step(zero_grad=True)
64
- closure()
65
- self.second_step()
66
-
67
- def _grad_norm(self):
68
- shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
69
- norm = torch.norm(
70
- torch.stack([
71
- p.grad.norm(p=2).to(shared_device)
72
- for group in self.param_groups for p in group["params"]
73
- if p.grad is not None
74
- ]),
75
- p=2
76
- )
77
- return norm
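The two-phase update is easiest to see in a toy loop mirroring how `Trainer.train_step` drives it: `first_step` perturbs the weights toward a local loss maximum, a second forward/backward runs at the perturbed point, and `second_step` restores the weights and applies the base-optimizer update. A minimal sketch with an illustrative model and random data, assuming the module is importable as before this deletion:

```python
import torch
import torch.nn as nn
from optimizor.SAM import SAM  # path as it existed before this deletion

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
criterion = nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.01, momentum=0.9)

x = torch.randn(32, 8)
y = torch.randint(0, 2, (32,))

for step in range(5):
    # first forward/backward at the current weights w
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.first_step(zero_grad=True)   # perturb to w + e(w)

    # second forward/backward at the perturbed weights
    criterion(model(x), y).backward()
    optimizer.second_step(zero_grad=True)  # restore w and apply the SGD update
    print(step, loss.item())
```

When the model contains BatchNorm layers, the `enable_running_stats`/`disable_running_stats` helpers defined above are typically wrapped around the first and second passes so that running statistics are only accumulated once per batch.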
 
test_deps/train_detector.py DELETED
@@ -1,460 +0,0 @@
1
- # This script was adapted from the DeepfakeBench training code,
2
- # originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
3
-
4
- # Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
5
-
6
- # BitMind's modifications include adding a testing phase, changing the
7
- # data load/split pipeline to work with subnet 34's image augmentations
8
- # and datasets from BitMind HuggingFace repositories, quality-of-life CLI arguments,
9
- # and logging changes.
10
-
11
- import os
12
- import sys
13
- import argparse
14
- from os.path import join
15
- import random
16
- import datetime
17
- import time
18
- import yaml
19
- from tqdm import tqdm
20
- import numpy as np
21
- from datetime import timedelta
22
- from copy import deepcopy
23
- from PIL import Image as pil_image
24
- from pathlib import Path
25
- import gc
26
-
27
- import torch
28
- import torch.nn as nn
29
- import torch.nn.parallel
30
- import torch.backends.cudnn as cudnn
31
- import torch.utils.data
32
- import torch.optim as optim
33
- from torch.utils.data.distributed import DistributedSampler
34
- import torch.distributed as dist
35
- from torch.utils.data import DataLoader
36
-
37
- from optimizor.SAM import SAM
38
- from optimizor.LinearLR import LinearDecayLR
39
-
40
- from trainer.trainer import Trainer
41
- from arena.detectors.UCF.detectors import DETECTOR
42
- from metrics.utils import parse_metric_for_print
43
- from logger import create_logger, RankFilter
44
-
45
- from huggingface_hub import hf_hub_download
46
-
47
- # BitMind imports (not from original Deepfake Bench repo)
48
- from bitmind.dataset_processing.load_split_data import load_datasets, create_real_fake_datasets
49
- from bitmind.image_transforms import base_transforms, random_aug_transforms
50
- from bitmind.constants import DATASET_META, FACE_TRAINING_DATASET_META
51
- from config.constants import (
52
- CONFIG_PATH,
53
- WEIGHTS_DIR,
54
- HF_REPO,
55
- BACKBONE_CKPT
56
- )
57
-
58
- parser = argparse.ArgumentParser(description='Train a UCF deepfake detector.')
59
- parser.add_argument('--detector_path', type=str, default=CONFIG_PATH, help='path to detector YAML file')
60
- parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False)
61
- parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True)
62
- parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True)
63
- parser.add_argument("--ddp", action='store_true', default=False)
64
- parser.add_argument('--local_rank', type=int, default=0)
65
- parser.add_argument('--workers', type=int, default=os.cpu_count() - 1,
66
- help='number of workers for data loading')
67
- parser.add_argument('--epochs', type=int, default=None, help='number of training epochs')
68
-
69
- args = parser.parse_args()
70
- torch.cuda.set_device(args.local_rank)
71
- print(f"torch.cuda.current_device(): {torch.cuda.current_device()}")
72
- print(f"torch.cuda.get_device_name(0): {torch.cuda.get_device_name(0)}")
73
-
74
- def ensure_backbone_is_available(logger,
75
- weights_dir=WEIGHTS_DIR,
76
- model_filename=BACKBONE_CKPT,
77
- hugging_face_repo_name=HF_REPO):
78
-
79
- destination_path = Path(weights_dir) / Path(model_filename)
80
- if not destination_path.parent.exists():
81
- destination_path.parent.mkdir(parents=True, exist_ok=True)
82
- logger.info(f"Created directory {destination_path.parent}.")
83
- if not destination_path.exists():
84
- model_path = hf_hub_download(hugging_face_repo_name, model_filename)
85
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
- model = torch.load(model_path, map_location=device)
87
- torch.save(model, destination_path)
88
- del model
89
- if torch.cuda.is_available():
90
- torch.cuda.empty_cache()
91
- logger.info(f"Downloaded backbone {model_filename} to {destination_path}.")
92
- else:
93
- logger.info(f"{model_filename} backbone already present at {destination_path}.")
94
-
95
- def init_seed(config):
96
- if config['manualSeed'] is None:
97
- config['manualSeed'] = random.randint(1, 10000)
98
- random.seed(config['manualSeed'])
99
- if config['cuda']:
100
- torch.manual_seed(config['manualSeed'])
101
- torch.cuda.manual_seed_all(config['manualSeed'])
102
-
103
- def custom_collate_fn(batch):
104
- images, labels, source_labels = zip(*batch)
105
-
106
- images = torch.stack(images, dim=0) # Stack image tensors into a single tensor
107
- labels = torch.LongTensor(labels)
108
- source_labels = torch.LongTensor(source_labels)
109
-
110
- data_dict = {
111
- 'image': images,
112
- 'label': labels,
113
- 'label_spe': source_labels,
114
- 'landmark': None,
115
- 'mask': None
116
- }
117
- return data_dict
118
-
119
- def prepare_datasets(config, logger):
120
- start_time = log_start_time(logger, "Loading and splitting individual datasets")
121
-
122
- real_datasets, fake_datasets = load_datasets(dataset_meta=config['dataset_meta'],
123
- expert=config['faces_only'],
124
- split_transforms=config['split_transforms'])
125
-
126
- log_finish_time(logger, "Loading and splitting individual datasets", start_time)
127
-
128
- start_time = log_start_time(logger, "Creating real fake dataset splits")
129
- train_dataset, val_dataset, test_dataset = \
130
- create_real_fake_datasets(real_datasets,
131
- fake_datasets,
132
- config['split_transforms']['train']['transform'],
133
- config['split_transforms']['validation']['transform'],
134
- config['split_transforms']['test']['transform'],
135
- source_labels=True)
136
-
137
- log_finish_time(logger, "Creating real fake dataset splits", start_time)
138
-
139
- train_loader = torch.utils.data.DataLoader(train_dataset,
140
- batch_size=config['train_batchSize'],
141
- shuffle=True,
142
- num_workers=config['workers'],
143
- drop_last=True,
144
- collate_fn=custom_collate_fn)
145
- val_loader = torch.utils.data.DataLoader(val_dataset,
146
- batch_size=config['train_batchSize'],
147
- shuffle=True,
148
- num_workers=config['workers'],
149
- drop_last=True,
150
- collate_fn=custom_collate_fn)
151
- test_loader = torch.utils.data.DataLoader(test_dataset,
152
- batch_size=config['train_batchSize'],
153
- shuffle=True,
154
- num_workers=config['workers'],
155
- drop_last=True,
156
- collate_fn=custom_collate_fn)
157
-
158
- print(f"Train size: {len(train_loader.dataset)}")
159
- print(f"Validation size: {len(val_loader.dataset)}")
160
- print(f"Test size: {len(test_loader.dataset)}")
161
-
162
- return train_loader, val_loader, test_loader
163
-
164
- def choose_optimizer(model, config):
165
- opt_name = config['optimizer']['type']
166
- if opt_name == 'sgd':
167
- optimizer = optim.SGD(
168
- params=model.parameters(),
169
- lr=config['optimizer'][opt_name]['lr'],
170
- momentum=config['optimizer'][opt_name]['momentum'],
171
- weight_decay=config['optimizer'][opt_name]['weight_decay']
172
- )
173
- return optimizer
174
- elif opt_name == 'adam':
175
- optimizer = optim.Adam(
176
- params=model.parameters(),
177
- lr=config['optimizer'][opt_name]['lr'],
178
- weight_decay=config['optimizer'][opt_name]['weight_decay'],
179
- betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']),
180
- eps=config['optimizer'][opt_name]['eps'],
181
- amsgrad=config['optimizer'][opt_name]['amsgrad'],
182
- )
183
- return optimizer
184
- elif opt_name == 'sam':
185
- optimizer = SAM(
186
- model.parameters(),
187
- optim.SGD,
188
- lr=config['optimizer'][opt_name]['lr'],
189
- momentum=config['optimizer'][opt_name]['momentum'],
190
- )
191
- else:
192
- raise NotImplementedError('Optimizer {} is not implemented'.format(opt_name))
193
- return optimizer
194
-
195
-
196
- def choose_scheduler(config, optimizer):
197
- if config['lr_scheduler'] is None:
198
- return None
199
- elif config['lr_scheduler'] == 'step':
200
- scheduler = optim.lr_scheduler.StepLR(
201
- optimizer,
202
- step_size=config['lr_step'],
203
- gamma=config['lr_gamma'],
204
- )
205
- return scheduler
206
- elif config['lr_scheduler'] == 'cosine':
207
- scheduler = optim.lr_scheduler.CosineAnnealingLR(
208
- optimizer,
209
- T_max=config['lr_T_max'],
210
- eta_min=config['lr_eta_min'],
211
- )
212
- return scheduler
213
- elif config['lr_scheduler'] == 'linear':
214
- scheduler = LinearDecayLR(
215
- optimizer,
216
- config['nEpochs'],
217
- int(config['nEpochs']/4),
218
- )
- return scheduler
219
- else:
220
- raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler']))
221
-
222
- def choose_metric(config):
223
- metric_scoring = config['metric_scoring']
224
- if metric_scoring not in ['eer', 'auc', 'acc', 'ap']:
225
- raise NotImplementedError('metric {} is not implemented'.format(metric_scoring))
226
- return metric_scoring
227
-
228
- def log_start_time(logger, process_name):
229
- """Log the start time of a process."""
230
- start_time = time.time()
231
- logger.info(f"{process_name} Start Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")
232
- return start_time
233
-
234
- def log_finish_time(logger, process_name, start_time):
235
- """Log the finish time and elapsed time of a process."""
236
- finish_time = time.time()
237
- elapsed_time = finish_time - start_time
238
-
239
- # Convert elapsed time into hours, minutes, and seconds
240
- hours, rem = divmod(elapsed_time, 3600)
241
- minutes, seconds = divmod(rem, 60)
242
-
243
- # Log the finish time and elapsed time
244
- logger.info(f"{process_name} Finish Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(finish_time))}")
245
- logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
246
-
247
- def save_config(config, outputs_dir):
248
- """
249
- Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved.
250
- Also, lists like 'mean' and 'std' are saved in flow style (on a single line).
251
-
252
- Args:
253
- config (dict): The configuration dictionary to save.
254
- outputs_dir (str): The directory path where the files will be saved.
255
- """
256
-
257
- def is_basic_type(value):
258
- """
259
- Check if a value is a basic data type that can be saved in YAML.
260
- Basic types include int, float, str, bool, list, and dict.
261
- """
262
- return isinstance(value, (int, float, str, bool, list, dict, type(None)))
263
-
264
- def filter_dict(data_dict):
265
- """
266
- Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects).
267
- """
268
- if not isinstance(data_dict, dict):
269
- return data_dict
270
-
271
- filtered_dict = {}
272
- for key, value in data_dict.items():
273
- if isinstance(value, dict):
274
- # Recursively filter nested dictionaries
275
- nested_dict = filter_dict(value)
276
- if nested_dict: # Only add non-empty dictionaries
277
- filtered_dict[key] = nested_dict
278
- elif is_basic_type(value):
279
- # Add if the value is a basic type
280
- filtered_dict[key] = value
281
- else:
282
- # Skip the key if the value is not a basic type (e.g., an object)
283
- print(f"Skipping key '{key}' because its value is of type {type(value)}")
284
-
285
- return filtered_dict
286
-
287
- def save_dict_to_yaml(data_dict, file_path):
288
- """
289
- Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
290
- Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.
291
-
292
- Args:
293
- data_dict (dict): The dictionary to save.
294
- file_path (str): The local file path where the YAML file will be saved.
295
- """
296
-
297
- # Custom representer for lists to force flow style (compact lists)
298
- class FlowStyleList(list):
299
- pass
300
-
301
- def flow_style_list_representer(dumper, data):
302
- return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)
303
-
304
- yaml.add_representer(FlowStyleList, flow_style_list_representer)
305
-
306
- # Preprocess specific lists to be in flow style
307
- if 'mean' in data_dict:
308
- data_dict['mean'] = FlowStyleList(data_dict['mean'])
309
- if 'std' in data_dict:
310
- data_dict['std'] = FlowStyleList(data_dict['std'])
311
-
312
- try:
313
- # Filter the dictionary
314
- filtered_dict = filter_dict(data_dict)
315
-
316
- # Save the filtered dictionary as YAML
317
- with open(file_path, 'w') as f:
318
- yaml.dump(filtered_dict, f, default_flow_style=False) # Save with default block style except for FlowStyleList
319
- print(f"Filtered dictionary successfully saved to {file_path}")
320
- except Exception as e:
321
- print(f"Error saving dictionary to YAML: {e}")
322
-
323
- # Save as YAML
324
- save_dict_to_yaml(config, outputs_dir + '/config.yaml')
325
-
326
- def main():
327
- torch.cuda.empty_cache()
328
- gc.collect()
329
- # parse options and load config
330
- with open(args.detector_path, 'r') as f:
331
- config = yaml.safe_load(f)
332
- with open(os.getcwd() + '/config/train_config.yaml', 'r') as f:
333
- config2 = yaml.safe_load(f)
334
- if 'label_dict' in config:
335
- config2['label_dict']=config['label_dict']
336
- config.update(config2)
337
-
338
- config['workers'] = args.workers
339
-
340
- config['local_rank']=args.local_rank
341
- if config['dry_run']:
342
- config['nEpochs'] = 0
343
- config['save_feat']=False
344
-
345
- if args.epochs: config['nEpochs'] = args.epochs
346
-
347
- config['split_transforms'] = {'train': {'name': 'base_transforms',
348
- 'transform': base_transforms},
349
- 'validation': {'name': 'base_transforms',
350
- 'transform': base_transforms},
351
- 'test': {'name': 'base_transforms',
352
- 'transform': base_transforms}}
353
- config['faces_only'] = args.faces_only
354
- config['dataset_meta'] = FACE_TRAINING_DATASET_META if config['faces_only'] else DATASET_META
355
- dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets]
356
- config['train_dataset'] = dataset_names
357
- config['save_ckpt'] = args.save_ckpt
358
- config['save_feat'] = args.save_feat
359
-
360
- config['specific_task_number'] = len(config['dataset_meta']["fake"]) + 1
361
-
362
- if config['lmdb']:
363
- config['dataset_json_folder'] = 'preprocessing/dataset_json_v3'
364
-
365
- # create logger
366
- timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
367
-
368
- outputs_dir = os.path.join(
369
- config['log_dir'],
370
- config['model_name'] + '_' + timenow
371
- )
372
-
373
- os.makedirs(outputs_dir, exist_ok=True)
374
- logger = create_logger(os.path.join(outputs_dir, 'training.log'))
375
- config['log_dir'] = outputs_dir
376
- logger.info('Save log to {}'.format(outputs_dir))
377
-
378
- config['ddp']= args.ddp
379
-
380
- # init seed
381
- init_seed(config)
382
-
383
- # set cudnn benchmark if needed
384
- if config['cudnn']:
385
- cudnn.benchmark = True
386
- if config['ddp']:
387
- # dist.init_process_group(backend='gloo')
388
- dist.init_process_group(
389
- backend='nccl',
390
- timeout=timedelta(minutes=30)
391
- )
392
- logger.addFilter(RankFilter(0))
393
-
394
- ensure_backbone_is_available(logger=logger,
395
- model_filename=config['pretrained'].split('/')[-1],
396
- hugging_face_repo_name='bitmind/' + config['model_name'])
397
-
398
- # prepare the model (detector)
399
- model_class = DETECTOR[config['model_name']]
400
- model = model_class(config)
401
-
402
- # prepare the optimizer
403
- optimizer = choose_optimizer(model, config)
404
-
405
- # prepare the scheduler
406
- scheduler = choose_scheduler(config, optimizer)
407
-
408
- # prepare the metric
409
- metric_scoring = choose_metric(config)
410
-
411
- # prepare the trainer
412
- trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring)
413
-
414
- # prepare the data loaders
415
- train_loader, val_loader, test_loader = prepare_datasets(config, logger)
416
-
417
- # print configuration
418
- logger.info("--------------- Configuration ---------------")
419
- params_string = "Parameters: \n"
420
- for key, value in config.items():
421
- params_string += "{}: {}".format(key, value) + "\n"
422
- logger.info(params_string)
423
-
424
- # save training configs
425
- save_config(config, outputs_dir)
426
-
427
- # start training
428
- start_time = log_start_time(logger, "Training")
429
- for epoch in range(config['start_epoch'], config['nEpochs'] + 1):
430
- trainer.model.epoch = epoch
431
- best_metric = trainer.train_epoch(
432
- epoch,
433
- train_data_loader=train_loader,
434
- validation_data_loaders={'val':val_loader}
435
- )
436
- if best_metric is not None:
437
- logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!")
438
- logger.info("Stop Training on best Validation metric {}".format(parse_metric_for_print(best_metric)))
439
- log_finish_time(logger, "Training", start_time)
440
-
441
- # test
442
- start_time = log_start_time(logger, "Test")
443
- trainer.eval(eval_data_loaders={'test':test_loader}, eval_stage="test")
444
- log_finish_time(logger, "Test", start_time)
445
-
446
- # update
447
- if 'svdd' in config['model_name']:
448
- model.update_R(epoch)
449
- if scheduler is not None:
450
- scheduler.step()
451
-
452
- # close the tensorboard writers
453
- for writer in trainer.writers.values():
454
- writer.close()
455
-
456
- torch.cuda.empty_cache()
457
- gc.collect()
458
-
459
- if __name__ == '__main__':
460
- main()
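One piece of the script above that can be exercised in isolation is `custom_collate_fn`, which turns `(image, label, source_label)` tuples into the dict the UCF detector consumes, with `label_spe` carrying the manipulation-specific label. The sketch below copies the function inline (the script itself parses CLI arguments at import time) and feeds it a dummy dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset

def custom_collate_fn(batch):
    # same logic as above: stack images, wrap labels in the dict the detector expects
    images, labels, source_labels = zip(*batch)
    return {
        'image': torch.stack(images, dim=0),
        'label': torch.LongTensor(labels),
        'label_spe': torch.LongTensor(source_labels),  # manipulation-specific label used by UCF
        'landmark': None,
        'mask': None,
    }

class DummyDataset(Dataset):
    # eight random RGB "images" with a binary label and one of three fake-source ids
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return torch.randn(3, 256, 256), idx % 2, idx % 3

loader = DataLoader(DummyDataset(), batch_size=4, collate_fn=custom_collate_fn)
batch = next(iter(loader))
print(batch['image'].shape, batch['label'], batch['label_spe'])
```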
 
test_deps/trainer/trainer.py DELETED
@@ -1,441 +0,0 @@
1
- # This script was adapted from the DeepfakeBench training code,
2
- # originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
3
-
4
- # Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
5
-
6
- import os
7
- import sys
8
- current_file_path = os.path.abspath(__file__)
9
- parent_dir = os.path.dirname(os.path.dirname(current_file_path))
10
- project_root_dir = os.path.dirname(parent_dir)
11
- sys.path.append(parent_dir)
12
- sys.path.append(project_root_dir)
13
-
14
- import pickle
15
- import datetime
16
- import logging
17
- import numpy as np
18
- from copy import deepcopy
19
- from collections import defaultdict
20
- from tqdm import tqdm
21
- import time
22
- import torch
23
- import torch.nn as nn
24
- import torch.nn.functional as F
25
- import torch.optim as optim
26
- from torch.nn import DataParallel
27
- from torch.utils.tensorboard import SummaryWriter
28
- from metrics.base_metrics_class import Recorder
29
- from torch.optim.swa_utils import AveragedModel, SWALR
30
- from torch import distributed as dist
31
- from torch.nn.parallel import DistributedDataParallel as DDP
32
- from sklearn import metrics
33
- from metrics.utils import get_test_metrics
34
-
35
- FFpp_pool = ['FaceForensics++', 'FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT']
36
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
-
38
-
39
- class Trainer(object):
40
- def __init__(
41
- self,
42
- config,
43
- model,
44
- optimizer,
45
- scheduler,
46
- logger,
47
- metric_scoring='auc',
48
- swa_model=None
49
- ):
50
- # check if all the necessary components are implemented
51
- if config is None or model is None or optimizer is None or logger is None:
52
- raise ValueError("config, model, optimizer, and logger must be provided")
53
-
54
- self.config = config
55
- self.model = model
56
- self.optimizer = optimizer
57
- self.scheduler = scheduler
58
- self.swa_model = swa_model
59
- self.writers = {} # dict to maintain different tensorboard writers for each dataset and metric
60
- self.logger = logger
61
- self.metric_scoring = metric_scoring
62
- # maintain the best metric of all epochs
63
- self.best_metrics_all_time = defaultdict(
64
- lambda: defaultdict(lambda: float('-inf')
65
- if self.metric_scoring != 'eer' else float('inf'))
66
- )
67
- self.speed_up() # move model to GPU
68
-
69
- # create directory path
70
- self.log_dir = self.config['log_dir']
71
- print("Making dir ", self.log_dir)
72
- os.makedirs(self.log_dir, exist_ok=True)
73
-
74
- def get_writer(self, phase, dataset_key, metric_key):
75
- phase = phase.split('/')[-1]
76
- dataset_key = dataset_key.split('/')[-1]
77
- metric_key = metric_key.split('/')[-1]
78
- writer_key = f"{phase}-{dataset_key}-{metric_key}"
79
- if writer_key not in self.writers:
80
- # update directory path
81
- writer_path = os.path.join(
82
- self.log_dir,
83
- phase,
84
- dataset_key,
85
- metric_key,
86
- "metric_board"
87
- )
88
- os.makedirs(writer_path, exist_ok=True)
89
- # update writers dictionary
90
- self.writers[writer_key] = SummaryWriter(writer_path)
91
- return self.writers[writer_key]
92
-
93
- def speed_up(self):
94
- self.model.to(device)
95
- self.model.device = device
96
- if self.config['ddp'] == True:
97
- num_gpus = torch.cuda.device_count()
98
- print(f'avai gpus: {num_gpus}')
99
- # local_rank=[i for i in range(0,num_gpus)]
100
- self.model = DDP(self.model, device_ids=[self.config['local_rank']],find_unused_parameters=True, output_device=self.config['local_rank'])
101
- #self.optimizer = nn.DataParallel(self.optimizer, device_ids=[int(os.environ['LOCAL_RANK'])])
102
-
103
- def setTrain(self):
104
- self.model.train()
105
- self.train = True
106
-
107
- def setEval(self):
108
- self.model.eval()
109
- self.train = False
110
-
111
- def load_ckpt(self, model_path):
112
- if os.path.isfile(model_path):
113
- saved = torch.load(model_path, map_location='cpu')
114
- suffix = model_path.split('.')[-1]
115
- if suffix == 'p':
116
- self.model.load_state_dict(saved.state_dict())
117
- else:
118
- self.model.load_state_dict(saved)
119
- self.logger.info('Model found in {}'.format(model_path))
120
- else:
121
- raise NotImplementedError(
122
- "=> no model found at '{}'".format(model_path))
123
-
124
- def save_ckpt(self, phase, dataset_key,ckpt_info=None):
125
- save_dir = self.log_dir
126
- os.makedirs(save_dir, exist_ok=True)
127
- ckpt_name = f"ckpt_best.pth"
128
- save_path = os.path.join(save_dir, ckpt_name)
129
- if self.config['ddp'] == True:
130
- torch.save(self.model.state_dict(), save_path)
131
- else:
132
- if 'svdd' in self.config['model_name']:
133
- torch.save({'R': self.model.R,
134
- 'c': self.model.c,
135
- 'state_dict': self.model.state_dict(),}, save_path)
136
- else:
137
- torch.save(self.model.state_dict(), save_path)
138
- self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}")
139
-
140
- def save_swa_ckpt(self):
141
- save_dir = self.log_dir
142
- os.makedirs(save_dir, exist_ok=True)
143
- ckpt_name = f"swa.pth"
144
- save_path = os.path.join(save_dir, ckpt_name)
145
- torch.save(self.swa_model.state_dict(), save_path)
146
- self.logger.info(f"SWA Checkpoint saved to {save_path}")
147
-
148
- def save_feat(self, phase, fea, dataset_key):
149
- save_dir = os.path.join(self.log_dir, phase, dataset_key)
150
- os.makedirs(save_dir, exist_ok=True)
151
- features = fea
152
- feat_name = f"feat_best.npy"
153
- save_path = os.path.join(save_dir, feat_name)
154
- np.save(save_path, features)
155
- self.logger.info(f"Feature saved to {save_path}")
156
-
157
- def save_data_dict(self, phase, data_dict, dataset_key):
158
- save_dir = os.path.join(self.log_dir, phase, dataset_key)
159
- os.makedirs(save_dir, exist_ok=True)
160
- file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle')
161
- with open(file_path, 'wb') as file:
162
- pickle.dump(data_dict, file)
163
- self.logger.info(f"data_dict saved to {file_path}")
164
-
165
- def save_metrics(self, phase, metric_one_dataset, dataset_key):
166
- save_dir = os.path.join(self.log_dir, phase, dataset_key)
167
- os.makedirs(save_dir, exist_ok=True)
168
- file_path = os.path.join(save_dir, 'metric_dict_best.pickle')
169
- with open(file_path, 'wb') as file:
170
- pickle.dump(metric_one_dataset, file)
171
- self.logger.info(f"Metrics saved to {file_path}")
172
-
173
- def train_step(self,data_dict):
174
- if self.config['optimizer']['type']=='sam':
175
- for i in range(2):
176
- predictions = self.model(data_dict)
177
- losses = self.model.get_losses(data_dict, predictions)
178
- if i == 0:
179
- pred_first = predictions
180
- losses_first = losses
181
- self.optimizer.zero_grad()
182
- losses['overall'].backward()
183
- if i == 0:
184
- self.optimizer.first_step(zero_grad=True)
185
- else:
186
- self.optimizer.second_step(zero_grad=True)
187
- return losses_first, pred_first
188
- else:
189
- predictions = self.model(data_dict)
190
- if type(self.model) is DDP:
191
- losses = self.model.module.get_losses(data_dict, predictions)
192
- else:
193
- losses = self.model.get_losses(data_dict, predictions)
194
- self.optimizer.zero_grad()
195
- losses['overall'].backward()
196
- self.optimizer.step()
197
-
198
- return losses,predictions
199
-
200
- def train_epoch(
201
- self,
202
- epoch,
203
- train_data_loader,
204
- validation_data_loaders=None
205
- ):
206
-
207
- self.logger.info("===> Epoch[{}] start!".format(epoch))
208
- if epoch>=1:
209
- times_per_epoch = 2
210
- else:
211
- times_per_epoch = 1
212
-
213
-
214
- #times_per_epoch=4
215
- validation_step = len(train_data_loader) // times_per_epoch  # run validation times_per_epoch times per epoch
216
- step_cnt = epoch * len(train_data_loader)
217
-
218
- # define training recorder
219
- train_recorder_loss = defaultdict(Recorder)
220
- train_recorder_metric = defaultdict(Recorder)
221
-
222
- for iteration, data_dict in tqdm(enumerate(train_data_loader),total=len(train_data_loader)):
223
- self.setTrain()
224
- # move all batch tensors (everything except 'name') to the GPU
225
- for key in data_dict.keys():
226
- if data_dict[key]!=None and key!='name':
227
- data_dict[key]=data_dict[key].cuda()
228
-
229
- losses, predictions=self.train_step(data_dict)
230
- # update learning rate
231
-
232
- if 'SWA' in self.config and self.config['SWA'] and epoch>self.config['swa_start']:
233
- self.swa_model.update_parameters(self.model)
234
-
235
- # compute training metric for each batch data
236
- if type(self.model) is DDP:
237
- batch_metrics = self.model.module.get_train_metrics(data_dict, predictions)
238
- else:
239
- batch_metrics = self.model.get_train_metrics(data_dict, predictions)
240
-
241
- # store data by recorder
242
- ## store metric
243
- for name, value in batch_metrics.items():
244
- train_recorder_metric[name].update(value)
245
- ## store loss
246
- for name, value in losses.items():
247
- train_recorder_loss[name].update(value)
248
-
249
- # run tensorboard to visualize the training process
250
- if iteration % 300 == 0 and self.config['local_rank']==0:
251
- if self.config['SWA'] and (epoch>self.config['swa_start'] or self.config['dry_run']):
252
- self.scheduler.step()
253
- # info for loss
254
- loss_str = f"Iter: {step_cnt} "
255
- for k, v in train_recorder_loss.items():
256
- v_avg = v.average()
257
- if v_avg == None:
258
- loss_str += f"training-loss, {k}: not calculated"
259
- continue
260
- loss_str += f"training-loss, {k}: {v_avg} "
261
- # tensorboard-1. loss
262
- processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
263
- processed_train_dataset = ','.join(processed_train_dataset)
264
- writer = self.get_writer('train', processed_train_dataset, k)
265
- writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt)
266
- self.logger.info(loss_str)
267
- # info for metric
268
- metric_str = f"Iter: {step_cnt} "
269
- for k, v in train_recorder_metric.items():
270
- v_avg = v.average()
271
- if v_avg == None:
272
- metric_str += f"training-metric, {k}: not calculated "
273
- continue
274
- metric_str += f"training-metric, {k}: {v_avg} "
275
- # tensorboard-2. metric
276
- processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
277
- processed_train_dataset = ','.join(processed_train_dataset)
278
- writer = self.get_writer('train', processed_train_dataset, k)
279
- writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt)
280
- self.logger.info(metric_str)
281
-
282
- # clear recorder.
283
- # Note: only the most recent 300 iterations contribute to each logged batch-level loss/metric
284
- for name, recorder in train_recorder_loss.items(): # clear loss recorder
285
- recorder.clear()
286
- for name, recorder in train_recorder_metric.items(): # clear metric recorder
287
- recorder.clear()
288
-
289
- # run validation
290
- if (step_cnt+1) % validation_step == 0:
291
- if validation_data_loaders is not None and ((not self.config['ddp']) or (self.config['ddp'] and dist.get_rank() == 0)):
292
- self.logger.info("===> Validation start!")
293
- validation_best_metric = self.eval(
294
- eval_data_loaders=validation_data_loaders,
295
- eval_stage="validation",
296
- step=step_cnt,
297
- epoch=epoch,
298
- iteration=iteration
299
- )
300
- else:
301
- validation_best_metric = None
302
-
303
- step_cnt += 1
304
-
305
- for key in data_dict.keys():
306
- if data_dict[key]!=None and key!='name':
307
- data_dict[key]=data_dict[key].cpu()
308
- return validation_best_metric
309
-
310
- def get_respect_acc(self,prob,label):
311
- pred = np.where(prob > 0.5, 1, 0)
312
- judge = (pred == label)
313
- zero_num = len(label) - np.count_nonzero(label)
314
- acc_fake = np.count_nonzero(judge[zero_num:]) / len(judge[zero_num:])
315
- acc_real = np.count_nonzero(judge[:zero_num]) / len(judge[:zero_num])
316
- return acc_real,acc_fake
317
-
318
- def eval_one_dataset(self, data_loader):
319
- # define eval recorder
320
- eval_recorder_loss = defaultdict(Recorder)
321
- prediction_lists = []
322
- feature_lists=[]
323
- label_lists = []
324
- for i, data_dict in tqdm(enumerate(data_loader),total=len(data_loader)):
325
- # get data
326
- if 'label_spe' in data_dict:
327
- data_dict.pop('label_spe') # remove the specific label
328
- data_dict['label'] = torch.where(data_dict['label']!=0, 1, 0) # fix the label to 0 and 1 only
329
- # move data to GPU elegantly
330
- for key in data_dict.keys():
331
- if data_dict[key]!=None:
332
- data_dict[key]=data_dict[key].cuda()
333
- # model forward without considering gradient computation
334
- predictions = self.inference(data_dict) #dict with keys cls, feat
335
-
336
- label_lists += list(data_dict['label'].cpu().detach().numpy())
337
- # Get the predicted class for each sample in the batch
338
- _, predicted_classes = torch.max(predictions['cls'], dim=1)
339
- # Convert the predicted class indices to a list and add to prediction_lists
340
- prediction_lists += predicted_classes.cpu().detach().numpy().tolist()
341
- feature_lists += list(predictions['feat'].cpu().detach().numpy())
342
- if type(self.model) is not AveragedModel:
343
- # compute all losses for each batch data
344
- if type(self.model) is DDP:
345
- losses = self.model.module.get_losses(data_dict, predictions)
346
- else:
347
- losses = self.model.get_losses(data_dict, predictions)
348
-
349
- # store data by recorder
350
- for name, value in losses.items():
351
- eval_recorder_loss[name].update(value)
352
- return eval_recorder_loss, np.array(prediction_lists), np.array(label_lists),np.array(feature_lists)
353
-
354
- def save_best(self,epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage):
355
- best_metric = self.best_metrics_all_time[key].get(self.metric_scoring,
356
- float('-inf') if self.metric_scoring != 'eer' else float(
357
- 'inf'))
358
- # Check if the current score is an improvement
359
- improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else (
360
- metric_one_dataset[self.metric_scoring] < best_metric)
361
- if improved:
362
- # Update the best metric
363
- self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring]
364
- if key == 'avg':
365
- self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
366
- # Save checkpoint, feature, and metrics if specified in config
367
- if eval_stage=='validation' and self.config['save_ckpt'] and key not in FFpp_pool:
368
- self.save_ckpt(eval_stage, key, f"{epoch}+{iteration}")
369
- self.save_metrics(eval_stage, metric_one_dataset, key)
370
- if losses_one_dataset_recorder is not None:
371
- # info for each dataset
372
- loss_str = f"dataset: {key} step: {step} "
373
- for k, v in losses_one_dataset_recorder.items():
374
- writer = self.get_writer(eval_stage, key, k)
375
- v_avg = v.average()
376
- if v_avg == None:
377
- print(f'{k} is not calculated')
378
- continue
379
- # tensorboard-1. loss
380
- writer.add_scalar(f'{eval_stage}_losses/{k}', v_avg, global_step=step)
381
- loss_str += f"{eval_stage}-loss, {k}: {v_avg} "
382
- self.logger.info(loss_str)
383
- # tqdm.write(loss_str)
384
- metric_str = f"dataset: {key} step: {step} "
385
- for k, v in metric_one_dataset.items():
386
- if k == 'pred' or k == 'label' or k=='dataset_dict':
387
- continue
388
- metric_str += f"{eval_stage}-metric, {k}: {v} "
389
- # tensorboard-2. metric
390
- writer = self.get_writer(eval_stage, key, k)
391
- writer.add_scalar(f'{eval_stage}_metrics/{k}', v, global_step=step)
392
- if 'pred' in metric_one_dataset:
393
- acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label'])
394
- metric_str += f'{eval_stage}-metric, acc_real:{acc_real}; acc_fake:{acc_fake}'
395
- writer.add_scalar(f'{eval_stage}_metrics/acc_real', acc_real, global_step=step)
396
- writer.add_scalar(f'{eval_stage}_metrics/acc_fake', acc_fake, global_step=step)
397
- self.logger.info(metric_str)
398
-
399
- def eval(self, eval_data_loaders, eval_stage, step=None, epoch=None, iteration=None):
400
- # set model to eval mode
401
- self.setEval()
402
-
403
- # define eval recorder
404
- losses_all_datasets = {}
405
- metrics_all_datasets = {}
406
- best_metrics_per_dataset = defaultdict(dict) # best metric for each dataset, for each metric
407
- avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0,'dataset_dict':{}} #'video_auc': 0
408
- keys = eval_data_loaders.keys()
409
- for key in keys:
410
- # compute loss for each dataset
411
- losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.eval_one_dataset(eval_data_loaders[key])
412
- losses_all_datasets[key] = losses_one_dataset_recorder
413
- metric_one_dataset=get_test_metrics(y_pred=predictions_nps,y_true=label_nps, logger=self.logger)
414
-
415
- for metric_name, value in metric_one_dataset.items():
416
- if metric_name in avg_metric:
417
- avg_metric[metric_name]+=value
418
- avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring]
419
- if type(self.model) is AveragedModel:
420
- metric_str = f"Iter Final for SWA: "
421
- for k, v in metric_one_dataset.items():
422
- metric_str += f"{eval_stage}-metric, {k}: {v} "
423
- self.logger.info(metric_str)
424
- continue
425
- self.save_best(epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage)
426
-
427
- if len(keys)>0 and self.config.get('save_avg',False):
428
- # calculate avg value
429
- for key in avg_metric:
430
- if key != 'dataset_dict':
431
- avg_metric[key] /= len(keys)
432
- self.save_best(epoch, iteration, step, None, 'avg', avg_metric, eval_stage)
433
-
434
- self.logger.info(f'===> {eval_stage} Done!')
435
- return self.best_metrics_all_time # return all types of mean metrics for determining the best ckpt
436
-
437
-
438
- @torch.no_grad()
439
- def inference(self, data_dict):
440
- predictions = self.model(data_dict, inference=True)
441
- return predictions
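One detail of the trainer worth calling out: `best_metrics_all_time` flips its default between `-inf` and `+inf` because EER improves downward while AUC/ACC/AP improve upward. A minimal sketch of that bookkeeping with illustrative values:

```python
from collections import defaultdict

metric_scoring = 'auc'  # or 'eer'

# lower-is-better metrics (eer) start from +inf, higher-is-better metrics from -inf
best_metrics_all_time = defaultdict(
    lambda: defaultdict(lambda: float('-inf') if metric_scoring != 'eer' else float('inf'))
)

def maybe_update(dataset_key: str, value: float) -> bool:
    best = best_metrics_all_time[dataset_key][metric_scoring]
    improved = value > best if metric_scoring != 'eer' else value < best
    if improved:
        best_metrics_all_time[dataset_key][metric_scoring] = value
    return improved

print(maybe_update('val', 0.91))  # True  -> first result always improves on -inf
print(maybe_update('val', 0.88))  # False -> worse AUC, no checkpoint would be saved
print(maybe_update('val', 0.95))  # True  -> new best, checkpoint would be saved
```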