RuiTerrty InPeerReview commited on
Commit
5425ee6
·
0 Parent(s):

Duplicate from InPeerReview/RemoteSensingChangeDetection-RSCD.HA2F

Browse files

Co-authored-by: Agentic Motion Group (AMG) <InPeerReview@users.noreply.huggingface.co>

Files changed (47) hide show
  1. .gitattributes +35 -0
  2. README.md +67 -0
  3. dataset/.gitkeep +0 -0
  4. dataset/Transforms.py +217 -0
  5. dataset/__pycache__/.gitkeep +0 -0
  6. dataset/__pycache__/Transforms.cpython-39.pyc +0 -0
  7. dataset/__pycache__/dataset.cpython-39.pyc +0 -0
  8. dataset/dataset.py +39 -0
  9. eval.py +222 -0
  10. main.py +319 -0
  11. model/.gitkeep +0 -0
  12. model/__pycache__/.gitkeep +0 -0
  13. model/__pycache__/decoder.cpython-39.pyc +0 -0
  14. model/__pycache__/dem.cpython-39.pyc +0 -0
  15. model/__pycache__/encoder.cpython-39.pyc +0 -0
  16. model/__pycache__/freqfusion.cpython-39.pyc +0 -0
  17. model/__pycache__/metric_tool.cpython-39.pyc +0 -0
  18. model/__pycache__/resnet.cpython-39.pyc +0 -0
  19. model/__pycache__/trainer.cpython-39.pyc +0 -0
  20. model/__pycache__/utils.cpython-39.pyc +0 -0
  21. model/decoder.py +301 -0
  22. model/encoder.py +391 -0
  23. model/layers/.gitkeep +0 -0
  24. model/layers/__init__.py +11 -0
  25. model/layers/__pycache__/.gitkeep +0 -0
  26. model/layers/__pycache__/__init__.cpython-39.pyc +0 -0
  27. model/layers/__pycache__/attention.cpython-39.pyc +0 -0
  28. model/layers/__pycache__/block.cpython-39.pyc +0 -0
  29. model/layers/__pycache__/dino_head.cpython-39.pyc +0 -0
  30. model/layers/__pycache__/drop_path.cpython-39.pyc +0 -0
  31. model/layers/__pycache__/layer_scale.cpython-39.pyc +0 -0
  32. model/layers/__pycache__/mlp.cpython-39.pyc +0 -0
  33. model/layers/__pycache__/patch_embed.cpython-39.pyc +0 -0
  34. model/layers/__pycache__/swiglu_ffn.cpython-39.pyc +0 -0
  35. model/layers/attention.py +89 -0
  36. model/layers/block.py +260 -0
  37. model/layers/dino_head.py +58 -0
  38. model/layers/drop_path.py +34 -0
  39. model/layers/layer_scale.py +27 -0
  40. model/layers/mlp.py +40 -0
  41. model/layers/patch_embed.py +88 -0
  42. model/layers/swiglu_ffn.py +72 -0
  43. model/metric_tool.py +131 -0
  44. model/resnet.py +213 -0
  45. model/trainer.py +30 -0
  46. model/utils.py +81 -0
  47. requirements.txt +10 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 🛠️ Requirements
2
+
3
+ ### Environment
4
+ - **Linux system**
5
+ - **Python** 3.8+, recommended 3.10
6
+ - **PyTorch** 2.0 or higher, recommended 2.1.0
7
+ - **CUDA** 11.7 or higher, recommended 12.1
8
+
9
+ ### Environment Installation
10
+
11
+ It is recommended to use Miniconda for installation. The following commands will create a virtual environment named `stnr` and install PyTorch. In the following installation steps, the default installed CUDA version is 12.1. If your CUDA version is not 12.1, please modify it according to the actual situation.
12
+
13
+ ```bash
14
+ # Create conda environment
15
+ conda create -n stnr python=3.8 -y
16
+ conda activate stnr
17
+
18
+ # Install PyTorch
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ ## 📁 Dataset Preparation
23
+
24
+ We evaluate our method on three remote sensing change detection datasets: **WHU-CD**, **LEVIR-CD**, **SYSU-CD**.
25
+
26
+ | Dataset | Link |
27
+ |---------|------|
28
+ | WHU-CD | [Download](https://aistudio.baidu.com/datasetdetail/251669) |
29
+ | LEVIR-CD | [Download](https://opendatalab.org.cn/OpenDataLab/LEVIR-CD) |
30
+ | SYSU-CD | [Download](https://mail2sysueducn-my.sharepoint.com/personal/liumx23_mail2_sysu_edu_cn/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fliumx23%5Fmail2%5Fsysu%5Fedu%5Fcn%2FDocuments%2FSYSU%2DCD&ga=1) |
31
+
32
+
33
+
34
+ ### Example of Training on LEVIR-CD Dataset
35
+
36
+ ```bash
37
+ python main.py --file_root LEVIR --max_steps 80000 --model_type small --batch_size 16 --lr 2e-4 --gpu_id 0
38
+ ```
39
+
40
+ ### Example of Evaluation on LEVIR-CD Dataset
41
+
42
+ ```bash
43
+ python eval.py --file_root LEVIR --max_steps 80000 --model_type small --batch_size 16 --lr 2e-4 --gpu_id 0
44
+ ```
45
+
46
+ ## 📂 DATA Structure
47
+
48
+ ```
49
+ ├─Train
50
+ ├─A jpg/png
51
+ ├─B jpg/png
52
+ └─label jpg/png
53
+ ├─Val
54
+ ├─A
55
+ ├─B
56
+ └─label
57
+ ├─Test
58
+ ├─A
59
+ ├─B
60
+ └─label
61
+ ```
62
+
63
+ ## 🙏 Acknowledgement
64
+
65
+ We sincerely thank the following works for their contributions:
66
+
67
+ - [ChangeViT](https://arxiv.org/pdf/2406.12847) – A state-of-the-art method for remote sensing change detection that inspired and influenced parts of this work.
dataset/.gitkeep ADDED
File without changes
dataset/Transforms.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ import cv2
6
+
7
+
8
class Scale(object):
    """Resize an image/label pair to a fixed (width, height)."""

    def __init__(self, wi, he):
        '''
        :param wi: target width in pixels
        :param he: target height in pixels
        '''
        self.w = wi
        self.h = he

    def __call__(self, img, label):
        '''
        Resize both inputs to (self.w, self.h).

        The RGB image uses bilinear interpolation (cv2 default); the label
        map uses nearest-neighbour so class ids are preserved.

        :param img: RGB image
        :param label: semantic label image
        :return: resized [image, label] pair
        '''
        target_size = (self.w, self.h)
        resized_img = cv2.resize(img, target_size)
        resized_label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
        return [resized_img, resized_label]
34
+
35
+
36
class Resize(object):
    """
    Resize an image/label pair so the shorter side matches a value randomly
    chosen from ``min_size``, capping the longer side at ``max_size`` while
    preserving aspect ratio. With ``strict=True`` a fixed
    (min_size[0], max_size) shape is used instead.
    """

    def __init__(self, min_size, max_size, strict=False):
        # Normalise min_size to a tuple so random.choice works uniformly.
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.strict = strict

    # modified from torchvision to add support for max size
    def get_size(self, image_size):
        # NOTE(review): __call__ passes image.shape[:2], which is
        # (height, width), but this unpacks as (w, h). For the square inputs
        # used in this repo the distinction is moot; confirm the intended
        # orientation before feeding non-square images.
        w, h = image_size
        if not self.strict:
            size = random.choice(self.min_size)
            max_size = self.max_size
            if max_size is not None:
                min_original_size = float(min((w, h)))
                max_original_size = float(max((w, h)))
                # Shrink the target if scaling the short side to `size`
                # would push the long side beyond `max_size`.
                if max_original_size / min_original_size * size > max_size:
                    size = int(round(max_size * min_original_size / max_original_size))

            # Already at the requested scale: keep the current dimensions.
            if (w <= h and w == size) or (h <= w and h == size):
                return (h, w)

            if w < h:
                ow = size
                oh = int(size * h / w)
            else:
                oh = size
                ow = int(size * w / h)

            return (oh, ow)
        else:
            if w < h:
                return (self.max_size, self.min_size[0])
            else:
                return (self.min_size[0], self.max_size)

    def __call__(self, image, label):
        size = self.get_size(image.shape[:2])
        # NOTE(review): cv2.resize expects (width, height) while `size` is
        # constructed as (oh, ow); the original author asserts below that the
        # resulting order is correct — verify if behaviour looks transposed.
        image = cv2.resize(image, size)
        # I confirm that the output size is right, not reversed
        label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
        return (image, label)
79
+
80
+
81
class RandomCropResize(object):
    """
    With probability 0.5, crop a random symmetric border off the image/label
    pair and resize the crop back to the original resolution.
    """

    def __init__(self, crop_area):
        '''
        :param crop_area: maximum border (in pixels) that may be cropped;
            the actual offset is drawn uniformly from [0, crop_area]
        '''
        self.cw = crop_area
        self.ch = crop_area

    def __call__(self, img, label):
        # Half the time the pair passes through untouched.
        if random.random() >= 0.5:
            return [img, label]

        h, w = img.shape[:2]
        x1 = random.randint(0, self.ch)
        y1 = random.randint(0, self.cw)

        # Crop the same border from image and label ...
        cropped_img = img[y1:h - y1, x1:w - x1]
        cropped_label = label[y1:h - y1, x1:w - x1]

        # ... then scale both back to the original size (nearest-neighbour
        # for the label so class ids survive).
        out_img = cv2.resize(cropped_img, (w, h))
        out_label = cv2.resize(cropped_label, (w, h), interpolation=cv2.INTER_NEAREST)

        return out_img, out_label
108
+
109
+
110
class RandomFlip(object):
    """
    Randomly flip the image/label pair: each of the two flips below is
    applied independently with probability 0.5.
    """

    def __call__(self, image, label):
        if random.random() < 0.5:
            image = cv2.flip(image, 0)  # flipCode 0: flip around x-axis (vertical / up-down)
            label = cv2.flip(label, 0)
        if random.random() < 0.5:
            image = cv2.flip(image, 1)  # flipCode 1: flip around y-axis (horizontal / left-right)
            label = cv2.flip(label, 1)
        return [image, label]
123
+
124
+
125
class RandomExchange(object):
    """
    With probability 0.5, swap the temporal order of a 6-channel input:
    channels 0-2 (pre-change image) trade places with channels 3-5
    (post-change image). The label is returned unchanged.
    """

    def __call__(self, image, label):
        if random.random() < 0.5:
            first_view, second_view = image[:, :, 0:3], image[:, :, 3:6]
            image = numpy.concatenate((second_view, first_view), axis=2)
        return [image, label]
136
+
137
+
138
class Normalize(object):
    """
    Scale a 6-channel image to [0, 1] and standardise each channel with the
    dataset mean/std: channel = (channel/255 - mean) / std. The label is
    binarised (any non-zero pixel becomes 1).
    """

    def __init__(self, mean, std):
        '''
        :param mean: per-channel mean (6 values) computed from the dataset
        :param std: per-channel std (6 values) computed from the dataset
        '''
        self.mean = mean
        self.std = std
        # Kept for interface compatibility; __call__ does not use these.
        self.depth_mean = [0.5]
        self.depth_std = [0.5]

    def __call__(self, image, label):
        image = image.astype(np.float32)
        image = image / 255
        # Any positive label value maps to 1 after the ceil.
        label = np.ceil(label / 255)
        for channel in range(6):
            image[:, :, channel] -= self.mean[channel]
            image[:, :, channel] /= self.std[channel]
        return [image, label]
165
+
166
+
167
class GaussianNoise(object):
    """Add zero-mean Gaussian noise to the image; the label is untouched."""

    def __init__(self, std=0.05):
        '''
        :param std: standard deviation of the additive noise
        '''
        self.std = std

    def __call__(self, image, label):
        noise = np.random.normal(loc=0, scale=self.std, size=image.shape)
        noisy_image = image + noise.astype(np.float32)
        return [noisy_image, label]
179
+
180
+
181
class ToTensor(object):
    '''
    Convert an image/label pair (numpy arrays) to PyTorch tensors.

    The image is flipped from BGR to RGB channel order and transposed to
    (C, H, W); the label becomes a LongTensor with a leading channel axis.
    '''

    def __init__(self, scale=1):
        '''
        :param scale: downscale factor applied to the label only, so the
            label can match an output produced at a coarser resolution
        '''
        self.scale = scale

    def __call__(self, image, label):
        if self.scale != 1:
            h, w = label.shape[:2]
            image = cv2.resize(image, (int(w), int(h)))
            label = cv2.resize(label, (int(w / self.scale), int(h / self.scale)), \
                               interpolation=cv2.INTER_NEAREST)
        # .copy() avoids "torch does not support negative strides" on the
        # BGR -> RGB channel reversal.
        image = image[:, :, ::-1].copy()
        image = image.transpose((2, 0, 1))
        image_tensor = torch.from_numpy(image)
        # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # use the explicit fixed-width dtype matching torch.LongTensor.
        label_tensor = torch.LongTensor(np.array(label, dtype=np.int64)).unsqueeze(dim=0)

        return [image_tensor, label_tensor]
204
+
205
+
206
class Compose(object):
    """Chain several paired transforms; each one feeds its output to the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        for transform in self.transforms:
            args = transform(*args)
        return args
dataset/__pycache__/.gitkeep ADDED
File without changes
dataset/__pycache__/Transforms.cpython-39.pyc ADDED
Binary file (6.8 kB). View file
 
dataset/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (1.9 kB). View file
 
dataset/dataset.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ from os.path import join as osp
4
+ import numpy
5
+ import torch.utils.data
6
+
7
+
8
class Dataset(torch.utils.data.Dataset):
    """
    Bi-temporal change-detection dataset.

    Expects the layout ``<file_root>/<mode>/{A,B,label}`` where A holds the
    pre-change images, B the post-change images, and label the ground-truth
    masks, with matching file names across the three folders.
    """

    def __init__(self, file_root='data/', mode='train', transform=None):
        # NOTE(review): os.listdir order is filesystem-dependent, so sample
        # order is not guaranteed to be stable across machines.
        self.file_list = os.listdir(osp(file_root, mode, 'A'))

        # Full paths; pairing relies on identical file names in A, B, label.
        self.pre_images = [osp(file_root, mode, 'A', x) for x in self.file_list]
        self.post_images = [osp(file_root, mode, 'B', x) for x in self.file_list]
        self.gts = [osp(file_root, mode, 'label', x) for x in self.file_list]

        self.transform = transform

    def __len__(self):
        # Number of image pairs.
        return len(self.pre_images)

    def __getitem__(self, idx):
        """Return (image, label); image is pre+post stacked to 6 channels."""
        pre_image_name = self.pre_images[idx]
        label_name = self.gts[idx]
        post_image_name = self.post_images[idx]

        pre_image = cv2.imread(pre_image_name)
        # Flag 0: read the mask as a single-channel grayscale image.
        label = cv2.imread(label_name, 0)
        post_image = cv2.imread(post_image_name)

        # Stack along the channel axis: channels 0-2 pre, 3-5 post.
        img = numpy.concatenate((pre_image, post_image), axis=2)

        if self.transform:
            [img, label] = self.transform(img, label)

        return img, label

    def get_img_info(self, idx):
        """Return the height/width of the pre-change image at `idx`."""
        img = cv2.imread(self.pre_images[idx])
        return {"height": img.shape[0], "width": img.shape[1]}
eval.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from model.trainer import Trainer
3
+
4
+ sys.path.insert(0, '.')
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.backends.cudnn as cudnn
9
+ from torch.nn.parallel import gather
10
+ import torch.optim.lr_scheduler
11
+
12
+ import dataset.dataset as myDataLoader
13
+ import dataset.Transforms as myTransforms
14
+ from model.metric_tool import ConfuseMatrixMeter
15
+ from model.utils import BCEDiceLoss, init_seed
16
+ from PIL import Image
17
+ import os
18
+ import time
19
+ import numpy as np
20
+ from argparse import ArgumentParser
21
+ from tqdm import tqdm
22
+
23
+
24
@torch.no_grad()
def validate(args, val_loader, model, save_masks=False):
    """
    Run one evaluation pass over `val_loader`.

    Returns (average_loss, scores) where `scores` is the metric dict from
    ConfuseMatrixMeter.get_scores(). When `save_masks` is True, binary
    prediction masks are written as PNGs under `<args.savedir>/pred_masks`.
    """
    model.eval()

    # Make sure every BatchNorm layer uses its global running statistics.
    for m in model.modules():
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            m.track_running_stats = True
            m.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    if save_masks:
        mask_dir = f"{args.savedir}/pred_masks"
        os.makedirs(mask_dir, exist_ok=True)
        print(f"Saving prediction masks to: {mask_dir}")

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating")

    for batch_idx, batched_inputs in pbar:
        img, target = batched_inputs
        # Recover the file names belonging to the current batch.
        # NOTE(review): this slicing assumes the loader iterates the dataset
        # sequentially (shuffle=False, default sampler) — confirm if the
        # loader configuration ever changes.
        batch_file_names = val_loader.sampler.data_source.file_list[
            batch_idx * args.batch_size : (batch_idx + 1) * args.batch_size
        ]

        # Channels 0-2: pre-change image; channels 3-5: post-change image.
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        if args.onGPU:
            pre_img = pre_img.cuda()
            post_img = post_img.cuda()
            target = target.cuda()

        target = target.float()
        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)
        # Binarise the change probability map at 0.5.
        pred = (output > 0.5).long()

        if save_masks:
            pred_np = pred.cpu().numpy().astype(np.uint8)

            print(f"\nDebug - Batch {batch_idx}: {len(batch_file_names)} files, Mask shape: {pred_np.shape}")

            try:
                for i in range(pred_np.shape[0]):
                    if i >= len(batch_file_names):  # guard: fewer names than masks
                        print(f"Warning: Missing filename for mask {i}, using default")
                        base_name = f"batch_{batch_idx}_mask_{i}"
                    else:
                        base_name = os.path.splitext(os.path.basename(batch_file_names[i]))[0]

                    single_mask = pred_np[i, 0]  # take the (H, W) plane out of (1, H, W)

                    if single_mask.ndim != 2:
                        raise ValueError(f"Invalid mask shape: {single_mask.shape}")

                    mask_path = f"{mask_dir}/{base_name}_pred.png"
                    # Masks are 0/1, so *255 maps them to black/white.
                    Image.fromarray(single_mask * 255).save(mask_path)
                    print(f"Saved: {mask_path}")

            except Exception as e:
                # Best-effort mask dumping: report and keep evaluating.
                print(f"\nError saving batch {batch_idx}: {str(e)}")
                print(f"Current mask shape: {single_mask.shape if 'single_mask' in locals() else 'N/A'}")
                print(f"Current file: {base_name if 'base_name' in locals() else 'N/A'}")

        # Gather predictions when running data-parallel on several GPUs.
        if args.onGPU and torch.cuda.device_count() > 1:
            pred = gather(pred, 0, dim=0)

        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())
        epoch_loss.append(loss.item())

        pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'F1': f"{f1:.4f}"})

    average_loss = sum(epoch_loss) / len(epoch_loss)
    scores = salEvalVal.get_scores()
    return average_loss, scores
102
+
103
def ValidateSegmentation(args):
    """
    Full validation pipeline: build the test loader, restore the best
    checkpoint from the training results directory, evaluate, and write the
    metrics to stdout and the log file.
    """
    # Environment setup.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    torch.backends.cudnn.benchmark = True
    init_seed(args.seed)  # fix random seeds for reproducibility

    # Results directory: must match the naming scheme used during training.
    args.savedir = os.path.join(args.savedir,
                                f"{args.file_root}_iter_{args.max_steps}_lr_{args.lr}")
    os.makedirs(args.savedir, exist_ok=True)

    # Dataset alias -> on-disk path (unknown aliases pass through unchanged).
    dataset_mapping = {
        'LEVIR': './levir_cd_256',
        'WHU': './whu_cd_256',
        'CLCD': './clcd_256',
        'SYSU': './sysu_256',
        'OSCD': './oscd_256'
    }
    args.file_root = dataset_mapping.get(args.file_root, args.file_root)

    # Build the model.
    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Preprocessing: identical to the validation transforms used in training.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    # Data loading (shuffle=False so mask file names line up with batches).
    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )

    # Logging. BUGFIX: the existence check must happen BEFORE open() creates
    # the file, otherwise the header row was never written.
    logFileLoc = os.path.join(args.savedir, args.logFile)
    log_is_new = not os.path.exists(logFileLoc)
    logger = open(logFileLoc, 'w' if log_is_new else 'a')
    if log_is_new:
        logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                     ('Epoch', 'Kappa', 'IoU', 'F1', 'Recall', 'Precision', 'OA'))
    logger.flush()

    # Load the best checkpoint. BUGFIX: map_location keeps CPU-only runs
    # working even when the checkpoint was saved from a CUDA model.
    model_file_name = os.path.join(args.savedir, 'best_model.pth')
    if not os.path.exists(model_file_name):
        raise FileNotFoundError(f"Model file not found: {model_file_name}")

    state_dict = torch.load(model_file_name,
                            map_location=None if args.onGPU else 'cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_file_name}")

    # Run validation.
    loss_test, score_test = validate(args, testLoader, model, save_masks=args.save_masks)

    # Report results.
    print("\nTest Results:")
    print(f"Loss: {loss_test:.4f}")
    print(f"Kappa: {score_test['Kappa']:.4f}")
    print(f"IoU: {score_test['IoU']:.4f}")
    print(f"F1: {score_test['F1']:.4f}")
    print(f"Recall: {score_test['recall']:.4f}")
    print(f"Precision: {score_test['precision']:.4f}")
    print(f"OA: {score_test['OA']:.4f}")

    # Append one row to the log.
    logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                 ('Test', score_test['Kappa'], score_test['IoU'], score_test['F1'],
                  score_test['recall'], score_test['precision'], score_test['OA']))
    logger.close()
185
+
186
+
187
if __name__ == '__main__':
    # Command-line interface for standalone evaluation. The path-related
    # options (file_root, max_steps, lr, savedir) must match the values used
    # during training so the checkpoint directory resolves correctly.
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR",
                        help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of input image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of input image')
    parser.add_argument('--max_steps', type=int, default=80000,
                        help='Max. number of iterations (for path naming)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading workers')
    parser.add_argument('--model_type', type=str, default='small',
                        help='Model type | tiny | small')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size for validation')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate (for path naming)')
    parser.add_argument('--seed', type=int, default=16,
                        help='Random seed for reproducibility')
    parser.add_argument('--savedir', default='./results',
                        help='Base directory to save results')
    parser.add_argument('--logFile', default='testLog.txt',
                        help='File to save validation logs')
    # Accepts "true"/"false" strings from the shell and maps them to bool.
    parser.add_argument('--onGPU', default=True,
                        type=lambda x: (str(x).lower() == 'true'),
                        help='Run on GPU if True')
    parser.add_argument('--gpu_id', type=int, default=0,
                        help='GPU device id')
    parser.add_argument('--save_masks', action='store_true',
                        help='Save predicted masks to disk')

    args = parser.parse_args()
    print('Validation with args:')
    print(args)

    ValidateSegmentation(args)
222
+
main.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from model.trainer import Trainer
4
+
5
+ sys.path.insert(0, '.')
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.backends.cudnn as cudnn
10
+ from torch.nn.parallel import gather
11
+ import torch.optim.lr_scheduler
12
+
13
+ import dataset.dataset as myDataLoader
14
+ import dataset.Transforms as myTransforms
15
+ from model.metric_tool import ConfuseMatrixMeter
16
+ from model.utils import BCEDiceLoss, init_seed, adjust_learning_rate
17
+
18
+ import os, time
19
+ import numpy as np
20
+ from argparse import ArgumentParser
21
+
22
+
23
+
24
@torch.no_grad()
def val(args, val_loader, model):
    """
    Evaluate `model` on `val_loader`.

    Returns (average_loss, scores) where `scores` is the metric dict from
    ConfuseMatrixMeter.get_scores().
    """
    model.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)

    epoch_loss = []

    total_batches = len(val_loader)
    print(len(val_loader))
    for batch_idx, batched_inputs in enumerate(val_loader):

        img, target = batched_inputs
        # Channels 0-2: pre-change image; channels 3-5: post-change image.
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        start_time = time.time()

        if args.onGPU:
            pre_img = pre_img.cuda()
            target = target.cuda()
            post_img = post_img.cuda()

        # torch.autograd.Variable has been a deprecated no-op since
        # PyTorch 0.4; plain tensors carry autograd state themselves.
        pre_img_var = pre_img.float()
        post_img_var = post_img.float()
        target_var = target.float()

        # run the model
        output = model(pre_img_var, post_img_var)
        loss = BCEDiceLoss(output, target_var)

        # Binarise the change probability map at 0.5.
        pred = torch.where(output > 0.5, torch.ones_like(output), torch.zeros_like(output)).long()

        # torch.cuda.synchronize()
        time_taken = time.time() - start_time

        epoch_loss.append(loss.data.item())

        # Gather predictions when running data-parallel on several GPUs.
        if args.onGPU and torch.cuda.device_count() > 1:
            output = gather(pred, 0, dim=0)
        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target_var.cpu().numpy())
        if batch_idx % 5 == 0:
            print('\r[%d/%d] F1: %3f loss: %.3f time: %.3f' % (batch_idx, total_batches, f1, loss.data.item(), time_taken),
                  end='')

    average_epoch_loss_val = sum(epoch_loss) / len(epoch_loss)
    scores = salEvalVal.get_scores()

    return average_epoch_loss_val, scores
75
+
76
+
77
def train(args, train_loader, model, optimizer, epoch, max_batches, cur_iter=0, lr_factor=1.):
    """
    Train `model` for one epoch over `train_loader`.

    Returns (average_loss, scores, last_lr): the mean loss for the epoch,
    the metric dict accumulated over the epoch, and the last learning rate
    produced by adjust_learning_rate.
    """
    # switch to train mode
    model.train()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    for batch_idx, batched_inputs in enumerate(train_loader):

        img, target = batched_inputs
        # Channels 0-2: pre-change image; channels 3-5: post-change image.
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        start_time = time.time()

        # adjust the learning rate
        lr = adjust_learning_rate(args, optimizer, epoch, batch_idx + cur_iter, max_batches, lr_factor=lr_factor)

        if args.onGPU:
            pre_img = pre_img.cuda()
            target = target.cuda()
            post_img = post_img.cuda()

        # torch.autograd.Variable has been a deprecated no-op since
        # PyTorch 0.4; plain tensors carry autograd state themselves.
        pre_img_var = pre_img.float()
        post_img_var = post_img.float()
        target_var = target.float()

        # run the model
        output = model(pre_img_var, post_img_var)
        loss = BCEDiceLoss(output, target_var)

        # Binarise the change probability map at 0.5.
        pred = torch.where(output > 0.5, torch.ones_like(output), torch.zeros_like(output)).long()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.data.item())
        time_taken = time.time() - start_time
        # Rough estimate of remaining training time, in hours.
        res_time = (max_batches * args.max_epochs - batch_idx - cur_iter) * time_taken / 3600

        if args.onGPU and torch.cuda.device_count() > 1:
            output = gather(pred, 0, dim=0)

        # Computing F-measure and IoU on GPU
        with torch.no_grad():
            f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target_var.cpu().numpy())

        if batch_idx % 5 == 0:
            print('\riteration: [%d/%d] f1: %.3f lr: %.7f loss: %.3f time:%.3f h' % (
                batch_idx + cur_iter, max_batches * args.max_epochs, f1, lr, loss.data.item(),
                res_time),
                end='')

    average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
    scores = salEvalVal.get_scores()

    return average_epoch_loss_train, scores, lr
135
+
136
+
137
def trainValidateSegmentation(args):
    """
    Full training pipeline: build loaders, optionally resume from a
    checkpoint, train for args.max_steps iterations while validating each
    epoch, keep the best-F1 weights, and finally evaluate them on the test
    split.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    torch.backends.cudnn.benchmark = True

    init_seed(args.seed)

    # Results directory name encodes dataset / iterations / learning rate.
    args.savedir = args.savedir + '_' + args.file_root + '_iter_' + str(args.max_steps) + '_lr_' + str(args.lr) + '/'

    # Dataset alias -> on-disk path.
    if args.file_root == 'LEVIR':
        args.file_root = './levir_cd_256'
    elif args.file_root == 'WHU':
        args.file_root = './whu_cd_256'
    elif args.file_root == 'CLCD':
        args.file_root = './clcd_256'
    elif args.file_root == 'SYSU':
        args.file_root = './sysu_256'
    elif args.file_root == 'OSCD':
        args.file_root = 'oscd_256'
    else:
        raise TypeError('%s has not defined' % args.file_root)

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Per-channel statistics for the 6-channel (pre+post) input.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    # compose the data with transforms
    trainDataset_main = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.RandomCropResize(int(7. / 224. * args.inWidth)),
        myTransforms.RandomFlip(),
        myTransforms.RandomExchange(),
        myTransforms.ToTensor()
    ])

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    train_data = myDataLoader.Dataset(file_root=args.file_root, mode="train", transform=trainDataset_main)

    trainLoader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=False
    )

    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data, shuffle=False,
        batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)

    max_batches = len(trainLoader)
    print('For each epoch, we have {} batches'.format(max_batches))

    if args.onGPU:
        cudnn.benchmark = True

    # Convert the iteration budget to a whole number of epochs.
    args.max_epochs = int(np.ceil(args.max_steps / max_batches))
    start_epoch = 0
    cur_iter = 0
    max_F1_val = 0
    # BUGFIX: defined before the loop so the final reload below cannot hit a
    # NameError when no epoch ever saved a best model.
    model_file_name = args.savedir + 'best_model.pth'

    # Resume from the checkpoint in the results directory when requested.
    if args.resume is not None:
        args.resume = args.savedir + 'checkpoint.pth.tar'
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            cur_iter = start_epoch * len(trainLoader)
            # args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Append to an existing log, otherwise start one with a header row.
    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write(
            "\n%s\t%s\t%s\t%s\t%s\t%s\t%s" % ('Epoch', 'Kappa (val)', 'IoU (val)', 'F1 (val)', 'R (val)', 'P (val)', 'OA (val)'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.99), eps=1e-08, weight_decay=1e-4)

    for epoch in range(start_epoch, args.max_epochs):
        lossTr, score_tr, lr = \
            train(args, trainLoader, model, optimizer, epoch, max_batches, cur_iter)
        cur_iter += len(trainLoader)

        torch.cuda.empty_cache()

        # evaluate on validation set (skipped for the very first epoch)
        if epoch == 0:
            continue

        lossVal, score_val = val(args, testLoader, model)
        torch.cuda.empty_cache()
        logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" % (epoch, score_val['Kappa'], score_val['IoU'],
                                                                               score_val['F1'], score_val['recall'],
                                                                               score_val['precision'], score_val['OA']))
        logger.flush()

        # Checkpoint for resuming.
        torch.save({
            'epoch': epoch + 1,
            'arch': str(model),
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossTr': lossTr,
            'lossVal': lossVal,
            'F_Tr': score_tr['F1'],
            'F_val': score_val['F1'],
            'lr': lr
        }, args.savedir + 'checkpoint.pth.tar')

        # Keep the weights with the best validation F1 so far.
        if epoch % 1 == 0 and max_F1_val <= score_val['F1']:
            max_F1_val = score_val['F1']
            torch.save(model.state_dict(), model_file_name)

        print("Epoch " + str(epoch) + ': Details')
        print("\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t F1(tr) = %.4f\t F1(val) = %.4f" \
              % (epoch, lossTr, lossVal, score_tr['F1'], score_val['F1']))
        torch.cuda.empty_cache()

    # Final evaluation with the best weights.
    state_dict = torch.load(model_file_name)
    model.load_state_dict(state_dict)

    loss_test, score_test = val(args, testLoader, model)
    print("\nTest :\t Kappa (te) = %.4f\t IoU (te) = %.4f\t F1 (te) = %.4f\t R (te) = %.4f\t P (te) = %.4f" \
          % (score_test['Kappa'], score_test['IoU'], score_test['F1'], score_test['recall'], score_test['precision']))
    # BUGFIX: the original format string had 6 placeholders for 7 arguments,
    # which raised TypeError when the run completed; one %.4f added for 'OA'.
    logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" % ('Test', score_test['Kappa'], score_test['IoU'],
                                                                           score_test['F1'], score_test['recall'],
                                                                           score_test['precision'], score_test['OA']))
    logger.flush()
    logger.close()
291
+
292
+
293
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR", help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD ')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of RGB image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of RGB image')
    parser.add_argument('--max_steps', type=int, default=80000, help='Max. number of iterations')
    parser.add_argument('--num_workers', type=int, default=4, help='No. of parallel threads')
    parser.add_argument('--model_type', type=str, default='small', help='select vit model type | tiny | small')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--step_loss', type=int, default=100, help='Decrease learning rate after how many epochs')
    parser.add_argument('--lr', type=float, default=2e-4, help='Initial learning rate')
    parser.add_argument('--lr_mode', default='poly', help='Learning rate policy, step or poly')
    # type=int so a value passed on the command line arrives as an int,
    # not as the string argparse would otherwise produce.
    parser.add_argument('--seed', type=int, default=16, help='initialization seed number')
    # Trailing slash is required: downstream paths are built by plain string
    # concatenation (e.g. savedir + 'checkpoint.pth.tar'), so a default of
    # './results' would silently write './resultscheckpoint.pth.tar'.
    parser.add_argument('--savedir', default='./results/', help='Directory to save the results')
    parser.add_argument('--resume', default=None, help='Use this checkpoint to continue training | '
                                                      './results_ep100/checkpoint.pth.tar')
    parser.add_argument('--logFile', default='trainValLog.txt',
                        help='File that stores the training and validation logs')
    parser.add_argument('--onGPU', default=True, type=lambda x: (str(x).lower() == 'true'),
                        help='Run on CPU or GPU. If TRUE, then GPU.')
    parser.add_argument('--gpu_id', default=0, type=int, help='GPU id number')

    args = parser.parse_args()
    print('Called with args:')
    print(args)

    trainValidateSegmentation(args)
model/.gitkeep ADDED
File without changes
model/__pycache__/.gitkeep ADDED
File without changes
model/__pycache__/decoder.cpython-39.pyc ADDED
Binary file (10.4 kB). View file
 
model/__pycache__/dem.cpython-39.pyc ADDED
Binary file (2.23 kB). View file
 
model/__pycache__/encoder.cpython-39.pyc ADDED
Binary file (12.4 kB). View file
 
model/__pycache__/freqfusion.cpython-39.pyc ADDED
Binary file (12.4 kB). View file
 
model/__pycache__/metric_tool.cpython-39.pyc ADDED
Binary file (4.66 kB). View file
 
model/__pycache__/resnet.cpython-39.pyc ADDED
Binary file (6.12 kB). View file
 
model/__pycache__/trainer.cpython-39.pyc ADDED
Binary file (1.08 kB). View file
 
model/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.38 kB). View file
 
model/decoder.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from model.utils import weight_init
6
+
7
+
8
+
9
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: randomly zero whole samples, rescaling the survivors.

    Identity when ``drop_prob`` is 0 or the module is in eval mode.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    # Scale kept samples by 1/keep_prob so the expectation is unchanged.
    return x / keep_prob * mask
18
+
19
+
20
class DropPath(nn.Module):
    """Module wrapper for :func:`drop_path` (per-sample stochastic depth)."""

    def __init__(self, drop_prob=None):
        # NOTE(review): a None drop_prob would fail inside drop_path when
        # training — callers always pass a float; confirm before relying on it.
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Active only in training mode; identity during evaluation.
        return drop_path(x, self.drop_prob, self.training)
27
+
28
+
29
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc -> act -> drop -> fc -> drop."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Unspecified widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
46
+
47
+
48
+
49
class CrossAttention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``y``.

    ``dim1`` is the query (and output) width, ``dim2`` the key/value source
    width; keys and values are projected up/down to ``dim1``.
    """

    def __init__(self, dim1, dim2, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim1 // num_heads) ** -0.5

        self.q = nn.Linear(dim1, dim1, bias=qkv_bias)
        self.kv = nn.Linear(dim2, dim1 * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim1, dim1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        B, N, C = x.shape
        By, Ny = y.shape[0], y.shape[1]
        heads = self.num_heads
        head_dim = C // heads

        # (B, heads, N, head_dim)
        q = self.q(x).view(B, N, heads, head_dim).transpose(1, 2)
        # (2, B, heads, Ny, head_dim) — one slice each for k and v.
        kv = self.kv(y).view(By, Ny, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
82
+
83
+
84
+
85
class Block(nn.Module):
    """Pre-norm cross-attention transformer block with residual connections."""

    def __init__(self, dim1, dim2, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim1)
        self.norm2 = norm_layer(dim2)
        self.attn = CrossAttention(dim1, dim2, num_heads=num_heads, qkv_bias=qkv_bias,
                                   attn_drop=attn_drop, proj_drop=drop)
        # Identity unless stochastic depth is actually requested.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm3 = norm_layer(dim1)
        self.mlp = Mlp(in_features=dim1, hidden_features=int(dim1 * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x, y):
        # Queries come from x; keys/values from y.
        x = x + self.drop_path(self.attn(self.norm1(x), self.norm2(y)))
        return x + self.drop_path(self.mlp(self.norm3(x)))
101
+
102
+
103
+
104
class ContentAwareAggregation(nn.Module):
    """Fuse an upsampled high-level map into a low-level map via a learned gate."""

    def __init__(self, low_dim, high_dim):
        super().__init__()
        # 1x1 projection aligning high-level channels with the low-level map.
        self.project = nn.Sequential(
            nn.Conv2d(high_dim, low_dim, kernel_size=1),
            nn.BatchNorm2d(low_dim),
            nn.ReLU(inplace=True),
        )

        # Depthwise 3x3 + pointwise 1x1 produce a per-pixel, per-channel gate.
        self.attn_gen = nn.Sequential(
            nn.Conv2d(low_dim, low_dim, kernel_size=3, padding=1, groups=low_dim),
            nn.BatchNorm2d(low_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(low_dim, low_dim, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, low_feat, high_feat):
        up = F.interpolate(high_feat, size=low_feat.shape[2:], mode='bilinear', align_corners=False)
        up = self.project(up)
        gate = self.attn_gen(low_feat + up)
        # Gated low-level detail plus the projected high-level context.
        return gate * low_feat + up
127
+
128
+
129
+
130
class FeatureInjector(nn.Module):
    """Enrich the deepest feature c5 with details from shallower stages.

    c5 tokens attend (as queries) to each of c2/c3/c4 separately; the three
    results are fused back to one c5-sized map, and the two temporal branches
    then exchange context through content-aware aggregation.
    """

    def __init__(self, dim1=384, dim2=[64, 128, 256], num_heads=8, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
        super().__init__()

        # One cross-attention block per shallow stage (key/value source).
        self.c2_c5 = Block(dim1, dim2[0], num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)
        self.c3_c5 = Block(dim1, dim2[1], num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)
        self.c4_c5 = Block(dim1, dim2[2], num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)

        self.fuse = nn.Conv2d(dim1 * 3, dim1, 1, bias=False)
        self.caa = ContentAwareAggregation(dim1, dim1)

        weight_init(self)

    def base_forward(self, c2, c3, c4, c5):
        H, W = c5.shape[2:]

        # Flatten every map to a token sequence for attention.
        t2, t3, t4, t5 = (rearrange(t, 'b c h w -> b (h w) c') for t in (c2, c3, c4, c5))

        enriched = []
        for blk, source in ((self.c2_c5, t2), (self.c3_c5, t3), (self.c4_c5, t4)):
            out = blk(t5, source)
            enriched.append(rearrange(out, 'b (h w) c -> b c h w', h=H, w=W))

        return self.fuse(torch.cat(enriched, dim=1))

    def forward(self, fx, fy):
        cx = self.base_forward(fx[0], fx[1], fx[2], fx[3])
        cy = self.base_forward(fy[0], fy[1], fy[2], fy[3])

        # NOTE: sequential exchange — the second call sees the updated cx.
        cx = self.caa(cx, cy)
        cy = self.caa(cy, cx)

        return cx, cy
174
+
175
+
176
class DualAttentionGate(nn.Module):
    """Apply an SE-style channel gate and a mean/std spatial gate in sequence."""

    def __init__(self, channels, ratio=8):
        super().__init__()
        # Squeeze-and-excite: global pool, bottleneck, sigmoid -> [B, C, 1, 1].
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // ratio, channels, 1, bias=False),
            nn.Sigmoid(),
        )

        # Two input channels: per-pixel mean and std across the channel axis.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3, bias=False),
            nn.Sigmoid(),  # -> [B, 1, H, W]
        )

    def forward(self, x):
        channel_gate = self.channel_att(x)

        stats = torch.cat(
            [torch.mean(x, dim=1, keepdim=True), torch.std(x, dim=1, keepdim=True)],
            dim=1,
        )
        spatial_gate = self.spatial_att(stats)

        return x * channel_gate * spatial_gate
201
+
202
+
203
class SimplifiedFGFM(nn.Module):
    """Flow-guided feature merging.

    Aligns a low-resolution and a high-resolution feature map with learned
    2-D offset fields, then fuses the aligned maps under dual attention.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projection so both inputs share the same channel count.
        self.down = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        # Predicts 4 channels: a 2-D offset field for each of the two branches.
        self.flow_make = nn.Conv2d(out_channels * 2, 4, 3, padding=1, bias=False)
        self.dual_att = DualAttentionGate(out_channels)

    def flow_warp(self, input, flow, size):
        # Resample `input` at resolution `size`, displaced by `flow`.

        out_h, out_w = size
        n, c, h, w = input.size()

        # Scale the offsets into grid_sample's normalized coordinate range.
        norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device)
        grid = torch.meshgrid(
            torch.linspace(-1.0, 1.0, out_h),
            torch.linspace(-1.0, 1.0, out_w),
            indexing='ij'
        )
        # grid_sample expects (x, y) ordering, hence the (grid[1], grid[0]) swap.
        grid = torch.stack((grid[1], grid[0]), 2).repeat(n, 1, 1, 1).type_as(input)
        grid = grid + flow.permute(0, 2, 3, 1) / norm

        return F.grid_sample(input, grid, align_corners=True)

    def forward(self, lowres_feature, highres_feature):

        l_feature = self.down(lowres_feature)
        l_feature_up = F.interpolate(l_feature, size=highres_feature.shape[2:], mode='bilinear', align_corners=True)

        # Predict both flow fields jointly from the concatenated maps.
        flow = self.flow_make(torch.cat([l_feature_up, highres_feature], dim=1))
        flow_l, flow_h = flow[:, :2, :, :], flow[:, 2:, :, :]

        # Warp each branch onto the high-resolution geometry.
        l_warp = self.flow_warp(l_feature, flow_l, highres_feature.shape[2:])
        h_warp = self.flow_warp(highres_feature, flow_h, highres_feature.shape[2:])

        fused = self.dual_att(l_warp + h_warp)
        return fused
240
+
241
+
242
+
243
class Decoder(nn.Module):
    """Change-detection head.

    Per-scale difference modeling on [x, y, |x - y|], top-down flow-guided
    fusion, and a final 2x-upsampling pixel classifier with sigmoid output.
    """

    def __init__(self, in_dim=[64, 128, 256, 384], decay=4, num_class=1):
        super().__init__()
        c2_ch, c3_ch, c4_ch, c5_ch = in_dim

        self.structure_enhance = FeatureInjector(dim1=c5_ch)

        # Top-down fusion, deepest scale first.
        self.fgfm_c4 = SimplifiedFGFM(in_channels=c5_ch, out_channels=c4_ch)
        self.fgfm_c3 = SimplifiedFGFM(in_channels=c4_ch, out_channels=c3_ch)
        self.fgfm_c2 = SimplifiedFGFM(in_channels=c3_ch, out_channels=c2_ch)

        # Upsample x2 then predict the change map.
        # (Attribute name 'classfier' kept as-is for checkpoint compatibility.)
        self.classfier = nn.Sequential(
            nn.ConvTranspose2d(c2_ch, c2_ch, kernel_size=4, stride=2, padding=1),
            nn.Conv2d(c2_ch, num_class, 3, 1, padding=1, bias=False),
        )

        # One bottleneck conv-MLP per scale, fed with [x, y, |x - y|].
        def _diff_mlp(dim):
            return nn.Sequential(
                nn.Conv2d(dim * 3, dim // decay, 1, bias=False),
                nn.BatchNorm2d(dim // decay),
                nn.ReLU(),
                nn.Conv2d(dim // decay, dim // decay, 3, 1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(dim // decay, dim // decay, 3, 1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(dim // decay, dim, 3, 1, padding=1, bias=False),
            )

        self.mlp = nn.ModuleList([_diff_mlp(dim) for dim in in_dim])

    def difference_modeling(self, x, y, block):
        # Concatenate both temporal features and their absolute difference.
        return block(torch.cat([x, y, torch.abs(x - y)], dim=1))

    def forward(self, fx, fy):
        c2x, c3x, c4x = fx[:-1]
        c2y, c3y, c4y = fy[:-1]

        # Enhanced deepest features, already cross-informed between branches.
        c5x, c5y = self.structure_enhance(fx, fy)

        c2 = self.difference_modeling(c2x, c2y, self.mlp[0])
        c3 = self.difference_modeling(c3x, c3y, self.mlp[1])
        c4 = self.difference_modeling(c4x, c4y, self.mlp[2])
        c5 = self.difference_modeling(c5x, c5y, self.mlp[3])

        c4f = self.fgfm_c4(c5, c4)
        c3f = self.fgfm_c3(c4f, c3)
        c2f = self.fgfm_c2(c3f, c2)

        return torch.sigmoid(self.classfier(c2f))
model/encoder.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from einops import rearrange
20
+
21
+ from model.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22
+ from model.resnet import resnet18
23
+
24
+
25
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    With ``depth_first=True`` children are visited before their parent
    (post-order); ``include_root`` controls whether the root itself is visited.
    Child names are dot-qualified relative to ``name``. Returns ``module``.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        # Children are always visited (include_root=True below).
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
34
+
35
+
36
class BlockChunk(nn.ModuleList):
    """A ModuleList that is itself callable: runs its blocks sequentially."""

    def forward(self, x):
        out = x
        for block in self:
            out = block(out)
        return out
41
+
42
+
43
class DinoVisionTransformer(nn.Module):
    """DINOv2-style Vision Transformer backbone (adapted from the official repo).

    Supports optional register tokens, chunked block lists (for FSDP wrapping),
    masked-token replacement, and positional-embedding interpolation for input
    sizes other than the training resolution.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # NOTE(review): pos_embed covers patch tokens only (no extra cls slot)
        # — confirm against interpolate_pos_encoding's `- 1` assumption below.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            print("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            print("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            print("using Identity layer as FFN")

            # Factory returning an Identity regardless of the usual FFN kwargs.
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i: i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned token substituted for masked patches in prepare_tokens_with_masks.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Truncated-normal pos_embed, tiny-normal registers, timm init for Linears."""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resize the positional embedding to the token grid of (w, h)."""
        previous_dtype = x.dtype
        # NOTE(review): the `- 1` assumes a leading cls token in x, but
        # prepare_tokens_with_masks adds none here — verify intent.
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1]
        if npatch == N and w == h:
            return self.pos_embed
        patch_pos_embed = self.pos_embed.float()
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed.to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, optionally replace masked patches with mask_token, add pos enc and registers."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers are inserted after the first token.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Batched variant of forward for a list of inputs with matching mask list."""
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1: self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        # Returns the full normalized token sequence (no head, no token split).
        x_norm = self.norm(x)
        return x_norm

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return selected intermediate block outputs, optionally normalized/reshaped."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)
306
+
307
+
308
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)."""
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
314
+
315
+
316
class Encoder(nn.Module):
    """Siamese encoder combining a ViT branch (global context) with a
    ResNet-18 branch (local detail); the two are fused at the deepest scale.
    """

    def __init__(self, model_type='small'):
        super().__init__()
        if model_type == 'tiny':
            self.vit = DinoVisionTransformer(
                img_size=256,
                patch_size=16,
                embed_dim=192,
                depth=12,
                num_heads=6,
                mlp_ratio=4,
                block_fn=partial(Block, attn_class=MemEffAttention),
                num_register_tokens=0
            )
            path = "checkpoint/deit_tiny_patch16_224-a1311bcf.pth"

        elif model_type == 'small':
            self.vit = DinoVisionTransformer(
                img_size=256,
                patch_size=16,
                embed_dim=384,
                depth=12,
                num_heads=6,
                mlp_ratio=4,
                block_fn=partial(Block, attn_class=MemEffAttention),
                num_register_tokens=0
            )
            path = "checkpoint/dinov2_vits14_pretrain.pth"

        else:
            assert False, r'Encoder: check the vit model type'

        # DeiT checkpoints wrap the weights in a 'model' key; DINOv2 ones do not.
        state_dict = torch.load(path, map_location='cpu')['model'] \
            if model_type == 'tiny' else torch.load(path, map_location='cpu')

        # These entries are shape-incompatible with this configuration, so they
        # are dropped and the corresponding parameters keep their fresh init.
        for k in ['pos_embed', 'patch_embed.proj.weight']:
            del state_dict[k]
        msg = self.vit.load_state_dict(state_dict, strict=False)
        print(' missing_keys:{},\n unexpected_keys:{}'.format(msg.missing_keys, msg.unexpected_keys))
        print('model_type: {},\n checkpoint_path: {}'.format(model_type, path))

        self.resnet = resnet18(pretrained=True)
        self.drop = nn.Dropout(p=0.01)

        # Feature fusion module: merges the ResNet c5 map with the ViT map.
        self.fusion_conv = nn.Sequential(
            # NOTE(review): 384 assumes the 'small' ViT embed_dim — the 'tiny'
            # variant (192) would mismatch here; confirm intended usage.
            nn.Conv2d(512 + 384, 384, kernel_size=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
        )

    def detail_capture(self, x):
        """Run the ResNet stem and four stages, returning [c2, c3, c4, c5]."""
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)

        # Light dropout on the shallowest stage only.
        x2 = self.drop(self.resnet.layer1(x))
        x3 = self.resnet.layer2(x2)
        x4 = self.resnet.layer3(x3)
        x5 = self.resnet.layer4(x4)
        return [x2, x3, x4, x5]

    def forward(self, x, y):
        # Shared-weight (siamese) processing of both temporal images.
        v_x = self.vit(x)
        v_y = self.vit(y)

        # 16x16 token grid — assumes 256x256 input with patch_size=16; confirm.
        v_x = rearrange(v_x, 'b (h w) c -> b c h w', h=16, w=16)
        v_y = rearrange(v_y, 'b (h w) c -> b c h w', h=16, w=16)

        c_x = self.detail_capture(x)
        c_y = self.detail_capture(y)

        # Replace each branch's deepest CNN feature with the CNN+ViT fusion.
        fused_v_x = self.fusion_conv(torch.cat([c_x[-1], v_x], dim=1))
        fused_v_y = self.fusion_conv(torch.cat([c_y[-1], v_y], dim=1))
        return c_x[:-1] + [fused_v_x], c_y[:-1] + [fused_v_y]
model/layers/.gitkeep ADDED
File without changes
model/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
model/layers/__pycache__/.gitkeep ADDED
File without changes
model/layers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (444 Bytes). View file
 
model/layers/__pycache__/attention.cpython-39.pyc ADDED
Binary file (2.56 kB). View file
 
model/layers/__pycache__/block.cpython-39.pyc ADDED
Binary file (8.17 kB). View file
 
model/layers/__pycache__/dino_head.cpython-39.pyc ADDED
Binary file (1.96 kB). View file
 
model/layers/__pycache__/drop_path.cpython-39.pyc ADDED
Binary file (1.2 kB). View file
 
model/layers/__pycache__/layer_scale.cpython-39.pyc ADDED
Binary file (995 Bytes). View file
 
model/layers/__pycache__/mlp.cpython-39.pyc ADDED
Binary file (1.18 kB). View file
 
model/layers/__pycache__/patch_embed.cpython-39.pyc ADDED
Binary file (2.61 kB). View file
 
model/layers/__pycache__/swiglu_ffn.cpython-39.pyc ADDED
Binary file (2.16 kB). View file
 
model/layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
class Attention(nn.Module):
    """Standard multi-head self-attention over token sequences (B, N, C)."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Per-head scaling factor 1/sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
70
+
71
+
72
class MemEffAttention(Attention):
    """Attention variant that uses xFormers' memory-efficient kernel when
    available, and falls back to the plain implementation otherwise."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # A bias (block-diagonal mask for nested batches) can only be
            # honored by the xFormers kernel.
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        # (B, N, 3, heads, head_dim) — the layout xFormers expects.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        # Fused kernel applies the 1/sqrt(head_dim) scaling and softmax internally.
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
model/layers/block.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+
40
+ warnings.warn("xFormers is not available (Block)")
41
+
42
+
43
class Block(nn.Module):
    """Pre-norm transformer block (attention + FFN) with optional LayerScale
    and stochastic depth.

    While training, two stochastic-depth strategies are used: batch-subset
    dropping (``drop_add_residual_stochastic_depth``) when the drop rate is
    above 0.1 (where its overhead pays off), otherwise per-sample DropPath
    on each residual branch.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy; Identity otherwise.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fix: the FFN branch must use its own drop-path module
            # (was drop_path1, as flagged by the original FIXME).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
115
+
116
+
117
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Residual add with batch-level stochastic depth.

    A random subset of the batch is pushed through ``residual_func``; the
    resulting residual is added back (scaled by batch/subset so the
    expectation matches the full-batch residual).
    """
    # Pick a random subset of samples (at least one survives).
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]

    # Residual only for the kept samples.
    residual = residual_func(x[kept_idx]).flatten(1)

    # Scatter-add the scaled residual back at the kept positions.
    scale = batch / keep
    out = torch.index_add(x.flatten(1), 0, kept_idx, residual.to(dtype=x.dtype), alpha=scale)
    return out.view_as(x)
139
+
140
+
141
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return (kept batch indices, compensation scale) for stochastic depth."""
    batch = x.shape[0]
    # Keep at least one sample regardless of the drop ratio.
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]
    return kept_idx, batch / keep
147
+
148
+
149
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a (scaled) residual at batch indices ``brange``.

    Without a ``scaling_vector`` this returns the FLATTENED result (callers
    reshape with ``view_as``); with one it dispatches to xFormers'
    fused ``scaled_index_add``.
    """
    if scaling_vector is None:
        flat_res = residual.flatten(1).to(dtype=x.dtype)
        return torch.index_add(x.flatten(1), 0, brange, flat_res, alpha=residual_scale_factor)
    # Fused LayerScale + index_add path (xFormers).
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
159
+
160
+
161
# Memoizes one BlockDiagonalMask per combination of (batch, seq_len) shapes.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """Concatenate a list of token tensors into a single batch-1 sequence and
    return the matching block-diagonal attention bias (cached per shape).

    When ``branges`` is given, only those batch indices are selected from each
    tensor before concatenation (used by the stochastic-depth path).
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One sequence-length entry per (kept) sample of every tensor.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused xFormers select+cat: flattens all kept samples into one sequence.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
186
+
187
+
188
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual add over a LIST of tensors, batched through a
    single fused attention call (requires xFormers).

    Note: despite the annotation, this returns a list of Tensors.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
209
+
210
+
211
class NestedTensorBlock(Block):
    """Block that can also run on a LIST of tensors ("nested" batches) by
    concatenating them behind a block-diagonal attention bias (xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Training with stochastic depth: LayerScale is folded into the
            # fused scaled_index_add via scaling_vector, so the residual
            # functions below must NOT apply ls1/ls2 themselves.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fix: check ls2 (was ls1) so the FFN branch uses its own LayerScale.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:
            # Inference / no-drop path: run everything as one concatenated
            # sequence and split the result back afterwards.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input: a single Tensor uses the parent Block,
        a list of Tensors uses the nested (xFormers-only) path."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
model/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
class DINOHead(nn.Module):
    """DINO projection head: MLP -> L2 normalization -> weight-normalized linear."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Init is applied BEFORE last_layer exists, so last_layer keeps its own
        # init; its weight_g magnitude is then pinned to 1.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Larger eps for fp16 to avoid underflow in the norm.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        return self.last_layer(x)
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
model/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob  # probability of dropping a whole sample

    def forward(self, x):
        # Delegates to the functional form; self.training gates the behavior.
        return drop_path(x, self.drop_prob, self.training)
model/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
class LayerScale(nn.Module):
    """Learnable per-channel scaling (gamma), optionally applied in place."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
model/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
class Mlp(nn.Module):
    """Two-layer MLP: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Both widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
model/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """Normalize an int or 2-tuple to a 2-tuple (asserts on anything else)."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x
23
+
24
+
25
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)

        self.img_size = image_HW
        self.patch_size = patch_HW
        # Number of patches along each axis for the nominal image size.
        self.patches_resolution = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Patchification as a strided convolution (one step per patch).
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        grid_h, grid_w = x.size(2), x.size(3)
        x = self.norm(x.flatten(2).transpose(1, 2))  # B HW C
        if not self.flatten_embedding:
            x = x.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Rough FLOP count of the projection (+ norm) at the nominal size."""
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
model/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: w3(silu(x W1) * (x W2)), with W1/W2 fused in w12."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # One projection produces both the gate and the value branch.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN that uses xFormers' fused SwiGLU when available (the
    module-level ``SwiGLU`` alias falls back to ``SwiGLUFFN`` otherwise)."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,  # unused; kept for API parity with Mlp
        drop: float = 0.0,  # unused; kept for API parity with Mlp
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Shrink to 2/3 (SwiGLU has two input projections) and round up to a
        # multiple of 8 for kernel efficiency.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
model/metric_tool.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ ################### metrics ###################
5
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.initialized = False
        self.val = None    # most recent value
        self.avg = None    # running weighted average
        self.sum = None    # weighted sum of all values
        self.count = None  # total weight seen

    def initialize(self, val, weight):
        """Seed the meter with the first observation."""
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Fold in a new observation (initializes on first call)."""
        if self.initialized:
            self.add(val, weight)
        else:
            self.initialize(val, weight)

    def add(self, val, weight):
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        return self.val

    def average(self):
        return self.avg

    def get_scores(self):
        # Interprets the accumulated sum as a confusion matrix.
        return cm2score(self.sum)

    def clear(self):
        self.initialized = False
+
47
+
48
+ ################### cm metrics ###################
49
class ConfuseMatrixMeter(AverageMeter):
    """Accumulates a confusion matrix across batches; ``sum`` holds the
    running matrix and the score getters derive metrics from it."""

    def __init__(self, n_class):
        super(ConfuseMatrixMeter, self).__init__()
        self.n_class = n_class  # number of classes (2 for binary change detection)

    def update_cm(self, pr, gt, weight=1):
        """Compute the confusion matrix for (pr, gt), fold it into the running
        matrix, and return the F1 score of this batch."""
        val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
        self.update(val, weight)
        current_score = cm2F1(val)
        return current_score

    def get_scores(self):
        # Full metric dictionary from the accumulated matrix.
        scores_dict = cm2score(self.sum)
        return scores_dict
+
67
+
68
def harmonic_mean(xs):
    """Harmonic mean of ``xs`` (1e-6 added to each term for stability)."""
    reciprocal_sum = sum(1.0 / (v + 1e-6) for v in xs)
    return len(xs) / reciprocal_sum
71
+
72
+
73
def cm2F1(confusion_matrix):
    """F1 score of the positive class (index 1) from a 2x2 confusion matrix
    laid out as [[tn, fp], [fn, tp]]."""
    eps = np.finfo(np.float32).eps
    tp = confusion_matrix[1, 1]
    fn = confusion_matrix[1, 0]
    fp = confusion_matrix[0, 1]
    recall = tp / (tp + fn + eps)
    precision = tp / (tp + fp + eps)
    return 2 * recall * precision / (recall + precision + eps)
86
+
87
+
88
def cm2score(confusion_matrix):
    """Binary change-detection metrics from a 2x2 confusion matrix
    [[tn, fp], [fn, tp]]: Kappa, IoU, F1, OA, recall, precision, Pre."""
    eps = np.finfo(np.float32).eps
    tp = confusion_matrix[1, 1]
    fn = confusion_matrix[1, 0]
    fp = confusion_matrix[0, 1]
    tn = confusion_matrix[0, 0]
    total = tp + fn + fp + tn

    oa = (tp + tn) / (total + eps)                    # overall accuracy
    recall = tp / (tp + fn + eps)
    precision = tp / (tp + fp + eps)
    f1 = 2 * recall * precision / (recall + precision + eps)
    iou = tp / (tp + fp + fn + eps)
    # Expected agreement by chance (for Cohen's kappa).
    pre = ((tp + fn) * (tp + fp) + (tn + fp) * (tn + fn)) / total ** 2
    kappa = (oa - pre) / (1 - pre)
    return {'Kappa': kappa, 'IoU': iou, 'F1': f1, 'OA': oa, 'recall': recall, 'precision': precision, 'Pre': pre}
110
+
111
+
112
def get_confuse_matrix(num_classes, label_gts, label_preds):
    """Accumulate a confusion matrix over pairs of (gt, pred) label maps.

    For reference, see: https://en.wikipedia.org/wiki/Confusion_matrix
    Rows index ground truth, columns index predictions; ground-truth values
    outside [0, num_classes) are ignored.
    """
    cm = np.zeros((num_classes, num_classes))
    for gt_map, pred_map in zip(label_gts, label_preds):
        gt = gt_map.flatten()
        pred = pred_map.flatten()
        valid = (gt >= 0) & (gt < num_classes)
        # Encode each (gt, pred) pair as a single bin index, then histogram.
        bins = num_classes * gt[valid].astype(int) + pred[valid]
        cm += np.bincount(bins, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return cm
model/resnet.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import math
3
+ import torch
4
+ import torch.utils.model_zoo as model_zoo
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10
+ 'resnet152']
11
+
12
+
13
+ model_urls = {
14
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19
+ }
20
+
21
+
22
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias (BN follows it)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
26
+
27
+
28
+
29
+
30
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convs with BN/ReLU and a shortcut
    (identity, or ``downsample`` when shape changes)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
60
+
61
+
62
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4),
    with BN/ReLU and a shortcut connection."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
99
+
100
+
101
class ResNet(nn.Module):
    """Classification ResNet: stem -> 4 residual stages -> avgpool -> fc.

    ``block`` is BasicBlock or Bottleneck; ``layers`` gives the number of
    blocks per stage. Expects ImageNet-sized inputs (avgpool kernel is 7).
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.inplanes = 64  # running channel count while building stages
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; the first block carries the stride and, when the
        shape changes, a 1x1 projection shortcut."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(x.view(x.size(0), -1))
151
+
152
+
153
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        # strict=False: tolerate key mismatches with the reference checkpoint.
        state = model_zoo.load_url(model_urls['resnet18'])
        model.load_state_dict(state, strict=False)
    return model
162
+
163
+
164
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet34'])
        model.load_state_dict(state)
    return model
173
+
174
+
175
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet50'])
        model.load_state_dict(state)
    return model
184
+
185
+
186
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet101'])
        model.load_state_dict(state)
    return model
195
+
196
+
197
def resnet152(pretrained=False, **kwargs):
    """Construct a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet152'])
        model.load_state_dict(state)
    return model
206
+
207
+
208
if __name__ == '__main__':
    # Smoke test. The previous demo passed a `vit_dim` kwarg and ViT feature
    # tensors to `resnet18`/`m(...)`, which this plain classification ResNet
    # does not accept (TypeError), and used a 256x256 input that breaks the
    # fixed avgpool(7) -> fc shape. Run a valid forward pass instead.
    m = resnet18(pretrained=False)
    x = torch.rand(1, 3, 224, 224)  # 224x224 so avgpool(7) sees a 7x7 map
    logits = m(x)
    print(logits.shape)  # torch.Size([1, 1000])
model/trainer.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from model.encoder import Encoder
5
+ from model.decoder import Decoder
6
+
7
+ from model.utils import weight_init
8
+
9
+
10
class Trainer(nn.Module):
    """End-to-end change-detection network: shared Encoder -> Decoder.

    Args:
        model_type (str): ViT backbone size, ``'tiny'`` (embed dim 192) or
            ``'small'`` (embed dim 384).

    Raises:
        ValueError: if ``model_type`` is not one of the supported sizes.
    """

    # Embedding dimension of each supported ViT variant.
    _EMBED_DIMS = {'tiny': 192, 'small': 384}

    def __init__(self, model_type='small'):
        super().__init__()
        try:
            embed_dim = self._EMBED_DIMS[model_type]
        except KeyError:
            # Was `assert False`; an explicit exception survives `python -O`.
            raise ValueError('Trainer: check the vit model type') from None

        self.encoder = Encoder(model_type)

        # Decoder consumes the three CNN stages plus the ViT feature.
        self.decoder = Decoder(in_dim=[64, 128, 256, embed_dim])
        weight_init(self.decoder)

    def forward(self, x, y):
        """Return the change prediction for the bi-temporal inputs x and y."""
        fx, fy = self.encoder(x, y)
        pred = self.decoder(fx, fy)

        return pred
30
+
model/utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import random
6
+
7
+
8
def _init_layer(m):
    """Initialize a single leaf layer in place; unknown types are ignored.

    Conv2d/Linear get Kaiming-normal weights (fan_in, relu) and zero bias;
    BatchNorm2d/GroupNorm get unit weight and zero bias.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    # Pooling layers, ModuleList, losses, etc. have nothing to initialize here.


def weight_init(module):
    """Initialize the direct children of ``module`` (one extra level inside
    ``nn.Sequential`` containers).

    The traversal deliberately matches the original: only direct children are
    visited, plus the immediate children of any ``nn.Sequential``; deeper
    nesting is left untouched. Each visited child's name is printed.
    """
    for n, m in module.named_children():
        print('initialize: '+n)
        if isinstance(m, nn.Sequential):
            for f, g in m.named_children():
                print('initialize: ' + f)
                _init_layer(g)
        else:
            _init_layer(m)
42
+
43
+
44
def init_seed(seed):
    """Seed every RNG in use (torch CPU, torch CUDA, random, numpy) so that
    runs are reproducible."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed,
                   random.seed, np.random.seed):
        seeder(seed)
49
+
50
+
51
def BCEDiceLoss(inputs, targets):
    """Binary cross-entropy plus (1 - soft Dice) on probability maps.

    ``inputs`` are probabilities in [0, 1]; ``targets`` are binary masks of
    the same shape. Returns a scalar tensor.
    """
    eps = 1e-5
    bce_term = F.binary_cross_entropy(inputs, targets)
    overlap = (inputs * targets).sum()
    dice_coeff = (2 * overlap + eps) / (inputs.sum() + targets.sum() + eps)
    return bce_term + (1 - dice_coeff)
59
+
60
+
61
def BCE(inputs, targets):
    """Plain binary cross-entropy between probability maps and binary targets."""
    return F.binary_cross_entropy(inputs, targets)
65
+
66
+
67
def adjust_learning_rate(args, optimizer, epoch, iter, max_batches, lr_factor=1):
    """Compute the learning rate for this step, write it into every parameter
    group of ``optimizer``, and return it.

    Supports ``args.lr_mode`` of 'step' (decay 10x every ``args.step_loss``
    epochs) or 'poly' (polynomial decay over ``max_batches * args.max_epochs``
    iterations). A linear warm-up from 0.1*lr to lr is applied over the first
    200 iterations of epoch 0. The result is scaled by ``lr_factor``.

    Raises:
        ValueError: for an unknown ``args.lr_mode``.
    """
    if args.lr_mode == 'step':
        lr = args.lr * (0.1 ** (epoch // args.step_loss))
    elif args.lr_mode == 'poly':
        total_iters = max_batches * args.max_epochs
        lr = args.lr * (1 - iter * 1.0 / total_iters) ** 0.9
    else:
        raise ValueError('Unknown lr mode {}'.format(args.lr_mode))

    if epoch == 0 and iter < 200:
        # warm-up: ramp linearly from 0.1*lr to the scheduled lr.
        lr = args.lr * 0.9 * (iter + 1) / 200 + 0.1 * args.lr

    lr = lr * lr_factor
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ torchaudio==0.13.1
3
+ torchcam==0.3.2
4
+ torchgeo==0.4.1
5
+ torchmetrics==0.11.4
6
+ torchvision==0.14.1
7
+ numpy==1.21.6
8
+ Pillow==9.2.0
9
+ einops==0.6.0
10
+ opencv-python==4.6.0.66