jamino30 commited on
Commit
ecf0440
·
verified ·
1 Parent(s): 4cb4b28

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. u2net/data_loader.py +26 -11
  2. u2net/evaluate.py +6 -3
  3. u2net/model.py +8 -5
  4. u2net/train.py +44 -23
u2net/data_loader.py CHANGED
@@ -3,12 +3,13 @@ import random
3
  from PIL import Image
4
 
5
  import torch
 
6
  from torchvision import transforms
7
  from sklearn.model_selection import train_test_split
8
-
9
 
10
  class SaliencyDataset(torch.utils.data.Dataset):
11
- def __init__(self, split, img_size=512, val_split_ratio=0.05):
12
  self.img_size = img_size
13
  self.split = split
14
  self.image_dir, self.mask_dir = self.set_directories(split)
@@ -19,8 +20,14 @@ class SaliencyDataset(torch.utils.data.Dataset):
19
  self.images = train_imgs if split == 'train' else val_imgs
20
  else:
21
  self.images = all_images
 
 
 
 
 
22
 
23
- self.resize = transforms.Resize((img_size, img_size))
 
24
  self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
25
 
26
  def __len__(self):
@@ -33,30 +40,38 @@ class SaliencyDataset(torch.utils.data.Dataset):
33
  mask_path = os.path.join(self.mask_dir, mask_filename)
34
 
35
  img = Image.open(img_path).convert('RGB')
36
- mask = Image.open(mask_path).convert('L')
37
-
38
- img, mask = self.resize(img), self.resize(mask)
39
 
 
40
  if self.split == 'train':
41
  img, mask = self.apply_augmentations(img, mask)
42
 
43
  img = transforms.ToTensor()(img)
44
  img = self.normalize(img)
45
  mask = transforms.ToTensor()(mask).squeeze(0)
 
46
  return img, mask
47
 
48
  def apply_augmentations(self, img, mask):
49
- if random.random() > 0.5: # horizontal flip
50
  img = transforms.functional.hflip(img)
51
  mask = transforms.functional.hflip(mask)
52
 
53
- if random.random() > 0.5: # random resized crop
54
  resized_crop = transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.0))
55
  i, j, h, w = resized_crop.get_params(img, scale=(0.8, 1.0), ratio=(3/4, 4/3))
56
- img = transforms.functional.resized_crop(img, i, j, h, w, (self.img_size, self.img_size))
57
- mask = transforms.functional.resized_crop(mask, i, j, h, w, (self.img_size, self.img_size))
 
 
 
 
 
 
58
 
59
- if random.random() > 0.5: # color jitter
60
  color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
61
  img = color_jitter(img)
62
 
 
3
  from PIL import Image
4
 
5
  import torch
6
+ import numpy as np
7
  from torchvision import transforms
8
  from sklearn.model_selection import train_test_split
9
+
10
 
11
  class SaliencyDataset(torch.utils.data.Dataset):
12
+ def __init__(self, split, img_size=512, val_split_ratio=0.05, subset_ratio=None):
13
  self.img_size = img_size
14
  self.split = split
15
  self.image_dir, self.mask_dir = self.set_directories(split)
 
20
  self.images = train_imgs if split == 'train' else val_imgs
21
  else:
22
  self.images = all_images
23
+
24
+ if subset_ratio: # subsampling
25
+ total_samples = len(self.images)
26
+ indices = np.random.choice(total_samples, int(total_samples * subset_ratio), replace=False)
27
+ self.images = [self.images[i] for i in indices]
28
 
29
+ self.img_resize = transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.BILINEAR)
30
+ self.mask_resize = transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST)
31
  self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
32
 
33
  def __len__(self):
 
40
  mask_path = os.path.join(self.mask_dir, mask_filename)
41
 
42
  img = Image.open(img_path).convert('RGB')
43
+ mask = Image.open(mask_path)
44
+ if mask.mode != 'L': mask = mask.convert('L')
45
+ mask = mask.point(lambda p: 255 if p > 128 else 0)
46
 
47
+ img, mask = self.img_resize(img), self.mask_resize(mask)
48
  if self.split == 'train':
49
  img, mask = self.apply_augmentations(img, mask)
50
 
51
  img = transforms.ToTensor()(img)
52
  img = self.normalize(img)
53
  mask = transforms.ToTensor()(mask).squeeze(0)
54
+
55
  return img, mask
56
 
57
  def apply_augmentations(self, img, mask):
58
+ if random.random() > 0.5: # horizontal flip
59
  img = transforms.functional.hflip(img)
60
  mask = transforms.functional.hflip(mask)
61
 
62
+ if random.random() > 0.5: # random resized crop
63
  resized_crop = transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.0))
64
  i, j, h, w = resized_crop.get_params(img, scale=(0.8, 1.0), ratio=(3/4, 4/3))
65
+ img = transforms.functional.resized_crop(
66
+ img, i, j, h, w, (self.img_size, self.img_size),
67
+ interpolation=transforms.InterpolationMode.BILINEAR
68
+ )
69
+ mask = transforms.functional.resized_crop(
70
+ mask, i, j, h, w, (self.img_size, self.img_size),
71
+ interpolation=transforms.InterpolationMode.NEAREST
72
+ )
73
 
74
+ if random.random() > 0.5: # color jitter
75
  color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
76
  img = color_jitter(img)
77
 
u2net/evaluate.py CHANGED
@@ -3,6 +3,7 @@ from tqdm import tqdm
3
  import torch
4
  import torch.nn as nn
5
  from torch.utils.data import DataLoader
 
6
 
7
  from data_loader import PASCALSDataset
8
  from model import U2Net
@@ -11,7 +12,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  print('Device:', device)
12
 
13
  def load_model(model, model_path):
14
- state_dict = torch.load(model_path, map_location=device, weights_only=False)
15
  model.load_state_dict(state_dict)
16
  model.eval()
17
 
@@ -28,12 +29,14 @@ def eval(model, loader, criterion):
28
 
29
 
30
  if __name__ == '__main__':
31
- batch_size = 40
32
 
 
 
33
  loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
34
  model = U2Net().to(device)
35
  model = nn.DataParallel(model)
36
- load_model(model, 'results/inter-u2net-duts.pt')
37
 
38
  loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
39
 
 
3
  import torch
4
  import torch.nn as nn
5
  from torch.utils.data import DataLoader
6
+ from safetensors.torch import load_file
7
 
8
  from data_loader import PASCALSDataset
9
  from model import U2Net
 
12
  print('Device:', device)
13
 
14
  def load_model(model, model_path):
15
+ state_dict = load_file(model_path, device=device.type)
16
  model.load_state_dict(state_dict)
17
  model.eval()
18
 
 
29
 
30
 
31
  if __name__ == '__main__':
32
+ batch_size = 1
33
 
34
+ model_type = input('Model type [b,f]: ')
35
+ model_name = 'best-u2net-duts-msra.safetensors' if model_type == 'b' else 'u2net-duts-msra.safetensors'
36
  loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
37
  model = U2Net().to(device)
38
  model = nn.DataParallel(model)
39
+ load_model(model, f'results/{model_name}')
40
 
41
  loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
42
 
u2net/model.py CHANGED
@@ -9,17 +9,19 @@ def init_weight(layer):
9
 
10
 
11
  class ConvBlock(nn.Module):
12
- def __init__(self, in_channel, out_channel, dilation=1):
13
  super(ConvBlock, self).__init__()
14
  self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
15
  self.bn = nn.BatchNorm2d(out_channel)
16
  self.relu = nn.ReLU(inplace=True)
 
17
  init_weight(self.conv)
18
 
19
  def forward(self, x):
20
  x = self.conv(x)
21
  x = self.bn(x)
22
  x = self.relu(x)
 
23
  return x
24
 
25
 
@@ -93,7 +95,7 @@ class RSU4F(nn.Module):
93
 
94
 
95
  class U2Net(nn.Module):
96
- def __init__(self):
97
  super(U2Net, self).__init__()
98
  self.enc = nn.ModuleList([
99
  RSU(L=7, C_in=3, C_out=64, M=32),
@@ -123,6 +125,7 @@ class U2Net(nn.Module):
123
 
124
  self.lastconv = nn.Conv2d(6, 1, 1)
125
  self.downsample = nn.MaxPool2d(2, stride=2)
 
126
 
127
  init_weight(self.lastconv)
128
  for conv in self.convs:
@@ -143,10 +146,10 @@ class U2Net(nn.Module):
143
 
144
  side_out = []
145
  for i, conv in enumerate(self.convs):
146
- if i == 0: side_out.append(conv(dec_out[5]))
147
- else: side_out.append(self.upsample(conv(dec_out[5-i]), side_out[0]))
148
 
149
  side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
150
 
151
- # logits
152
  return [s.squeeze(1) for s in side_out]
 
9
 
10
 
11
  class ConvBlock(nn.Module):
12
+ def __init__(self, in_channel, out_channel, dilation=1, dropout_rate=0.3):
13
  super(ConvBlock, self).__init__()
14
  self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
15
  self.bn = nn.BatchNorm2d(out_channel)
16
  self.relu = nn.ReLU(inplace=True)
17
+ self.dropout = nn.Dropout2d(p=dropout_rate) # custom - add dropout layer
18
  init_weight(self.conv)
19
 
20
  def forward(self, x):
21
  x = self.conv(x)
22
  x = self.bn(x)
23
  x = self.relu(x)
24
+ self.dropout(x)
25
  return x
26
 
27
 
 
95
 
96
 
97
  class U2Net(nn.Module):
98
+ def __init__(self, dropout_rate=0.3):
99
  super(U2Net, self).__init__()
100
  self.enc = nn.ModuleList([
101
  RSU(L=7, C_in=3, C_out=64, M=32),
 
125
 
126
  self.lastconv = nn.Conv2d(6, 1, 1)
127
  self.downsample = nn.MaxPool2d(2, stride=2)
128
+ self.dropout = nn.Dropout(p=dropout_rate) # custom - add dropout layer
129
 
130
  init_weight(self.lastconv)
131
  for conv in self.convs:
 
146
 
147
  side_out = []
148
  for i, conv in enumerate(self.convs):
149
+ if i == 0: side_out.append(self.dropout(conv(dec_out[5])))
150
+ else: side_out.append(self.upsample(self.dropout(conv(dec_out[5-i])), side_out[0]))
151
 
152
  side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
153
 
154
+ # logits (no sigmoid)
155
  return [s.squeeze(1) for s in side_out]
u2net/train.py CHANGED
@@ -14,6 +14,20 @@ from model import U2Net
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  scaler = GradScaler()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def train_one_epoch(model, loader, criterion, optimizer):
18
  model.train()
19
  running_loss = 0.
@@ -43,47 +57,54 @@ def validate(model, loader, criterion):
43
  avg_loss = running_loss / len(loader)
44
  return avg_loss
45
 
 
 
 
 
 
46
 
47
  if __name__ == '__main__':
48
  batch_size = 40
49
  valid_batch_size = 80
50
- epochs = 100
51
 
52
  lr = 1e-3
53
- loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
 
 
 
54
 
55
- model_name = 'u2net-duts'
56
  model = U2Net()
57
- model = torch.nn.DataParallel(model.to(device))
58
- optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
59
 
60
  train_loader = DataLoader(
61
  ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
62
  batch_size=batch_size, shuffle=True, pin_memory=True,
63
- num_workers=16, persistent_workers=True
64
  )
65
  valid_loader = DataLoader(
66
  ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
67
  batch_size=valid_batch_size, shuffle=False, pin_memory=True,
68
- num_workers=16, persistent_workers=True
69
  )
70
 
71
  best_val_loss = float('inf')
72
  losses = {'train': [], 'val': []}
73
- for epoch in tqdm(range(epochs), desc='Epochs'):
74
- torch.cuda.empty_cache()
75
- train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
76
- val_loss = validate(model, valid_loader, loss_fn)
77
- losses['train'].append(train_loss)
78
- losses['val'].append(val_loss)
79
-
80
- if val_loss < best_val_loss:
81
- best_val_loss = val_loss
82
- save_file(model.state_dict(), f'results/best-{model_name}.safetensors')
83
- print('Best model saved.')
84
-
85
- print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
86
 
87
- save_file(model.state_dict(), f'results/{model_name}.safetensors')
88
- with open('results/loss.txt', 'wb') as f:
89
- pickle.dump(losses, f)
 
 
 
 
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  scaler = GradScaler()
16
 
17
+
18
+ class DiceLoss(nn.Module):
19
+ def __init__(self):
20
+ super(DiceLoss, self).__init__()
21
+
22
+ def forward(self, inputs, targets, smooth=1):
23
+ inputs = torch.sigmoid(inputs)
24
+ inputs = inputs.view(-1)
25
+ targets = targets.view(-1)
26
+ intersection = (inputs * targets).sum()
27
+ dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
28
+ return 1 - dice
29
+
30
+
31
  def train_one_epoch(model, loader, criterion, optimizer):
32
  model.train()
33
  running_loss = 0.
 
57
  avg_loss = running_loss / len(loader)
58
  return avg_loss
59
 
60
+ def save(model, model_name, losses):
61
+ save_file(model.state_dict(), f'results/{model_name}.safetensors')
62
+ with open('results/loss.txt', 'wb') as f:
63
+ pickle.dump(losses, f)
64
+
65
 
66
  if __name__ == '__main__':
67
  batch_size = 40
68
  valid_batch_size = 80
69
+ epochs = 200
70
 
71
  lr = 1e-3
72
+ loss_fn_bce = nn.BCEWithLogitsLoss(reduction='mean')
73
+ loss_fn_dice = DiceLoss()
74
+ alpha = 0.6
75
+ loss_fn = lambda o, m: alpha * loss_fn_bce(o, m) + (1 - alpha) * loss_fn_dice(o, m)
76
 
77
+ model_name = 'u2net-duts-msra'
78
  model = U2Net()
79
+ model = torch.nn.parallel.DataParallel(model.to(device))
80
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
81
 
82
  train_loader = DataLoader(
83
  ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
84
  batch_size=batch_size, shuffle=True, pin_memory=True,
85
+ num_workers=8, persistent_workers=True
86
  )
87
  valid_loader = DataLoader(
88
  ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
89
  batch_size=valid_batch_size, shuffle=False, pin_memory=True,
90
+ num_workers=8, persistent_workers=True
91
  )
92
 
93
  best_val_loss = float('inf')
94
  losses = {'train': [], 'val': []}
95
+
96
+ # training loop
97
+ try:
98
+ for epoch in range(epochs):
99
+ train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
100
+ val_loss = validate(model, valid_loader, loss_fn)
101
+ losses['train'].append(train_loss)
102
+ losses['val'].append(val_loss)
 
 
 
 
 
103
 
104
+ if val_loss < best_val_loss:
105
+ best_val_loss = val_loss
106
+ save_file(model.state_dict(), f'results/best-{model_name}.safetensors')
107
+
108
+ print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
109
+ finally:
110
+ save(model, model_name, losses)