Andres Felipe Ruiz-Hurtado committed on
Commit bc97962 · 1 Parent(s): c16482d
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.6.0
8
- app_file: app.py
8
+ app_file: main.py
9
  pinned: false
10
  license: apache-2.0
11
  short_description: Root analysis using deep learning
__pycache__/processsors.cpython-312.pyc ADDED
Binary file (8.33 kB).
 
dependecies/__init__.py ADDED
File without changes
dependecies/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (160 Bytes).
 
dependecies/segroot/__init__.py ADDED
File without changes
dependecies/segroot/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (168 Bytes).
 
dependecies/segroot/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (9.13 kB).
 
dependecies/segroot/__pycache__/model.cpython-312.pyc ADDED
Binary file (7.64 kB).
 
dependecies/segroot/__pycache__/paired_transforms_pt04.cpython-312.pyc ADDED
Binary file (56.8 kB).
 
dependecies/segroot/binarize_crop.py ADDED
@@ -0,0 +1,72 @@
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import skimage.io as io
5
+ from skimage.morphology import dilation
6
+ import pickle
7
+ import numpy as np
8
+ import argparse
9
+ from dataloader import pad_pair_256
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument(
13
+ "--dilate",
14
+ default=0,
15
+ type=int,
16
+ help="dilation degree of masks")
17
+
18
+ args = parser.parse_args()
19
+ data_dir = Path('../data/data_raw')
20
+ mask_dir = Path('../data/masks')
21
+ mask_dir.mkdir(exist_ok=True, parents=True)
22
+
23
+ imgs = sorted(list(data_dir.glob('*Untitled.jpg')))
24
+ print('original images count : ', len(imgs))
25
+ masks = sorted(list(data_dir.glob('*Untitled-mask.jpg')))
26
+ print('original masks count : ', len(masks))
27
+
28
+ # generate binary masks for every annotated images
29
+ for m in masks:
30
+ mask = io.imread(m.as_posix(), as_gray=True)
31
+ # binarize
32
+ mask[mask > 0.5 ] = 1.0
33
+ mask[mask <= 0.5] = 0.0
34
+ for i in range(args.dilate):
35
+ mask = dilation(mask)
36
+ print('binary masks dilated !!!')
37
+ plt.imsave((mask_dir / m.parts[-1]).as_posix(), mask, cmap='gray')
38
+ print('binary masks generated !!!')
39
+
40
+ # save idx info in a dictionary
41
+ info_dict = {k: v.parts[-1] for k, v in enumerate(imgs)}
42
+ with open('../data/info.pkl', 'wb') as handle:
43
+ pickle.dump(info_dict, handle)
44
+ print('index info saved!!!')
45
+
46
+ # crop the padded image to generate 256*256 subimages
47
+ new_masks = sorted(list(mask_dir.glob('*Untitled-mask.jpg')))
48
+ print('new_mask length : ',len(new_masks))
49
+
50
+ subimg_path = Path('../data/subimg')
51
+ subimg_path.mkdir(exist_ok=True, parents=True)
52
+ submask_path = Path('../data/submask')
53
+ submask_path.mkdir(exist_ok=True, parents=True)
54
+
55
+ for idx, (mask_path, img_path) in enumerate(zip(new_masks, imgs)):
56
+ mask = Image.open(mask_path)
57
+ img = Image.open(img_path)
58
+ new_img, new_mask = pad_pair_256(img, mask)
59
+ new_img, new_mask = np.array(new_img), np.array(new_mask)
60
+ # padded shape (2560, 2304)
61
+ w, h, _ = new_img.shape
62
+ for i in range(int(w/256)):
63
+ for j in range(int(h/256)):
64
+ subimg = new_img[i*256:(i+1)*256, j*256:(j+1)*256, :]
65
+ subimg_fn = '{}/{}-{}-{}.png'.format(
66
+ Path('../data/subimg').as_posix(), idx, i, j)
67
+ plt.imsave(subimg_fn, subimg)
68
+ submask_fn = '{}/{}-{}-{}.png'.format(
69
+ Path('../data/submask').as_posix(), idx, i, j)
70
+ submask = new_mask[i*256:(i+1)*256, j*256:(j+1)*256]
71
+ plt.imsave(submask_fn, submask, cmap='gray')
72
+ print('No.{} image & mask cropped!!!'.format(idx))
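
binarize_crop.py pads each raw scan and its mask to multiples of 256 and then slices them into 256×256 sub-images. A minimal sketch of that tiling arithmetic, assuming the padded shape of 2560×2304 noted in the comment above:

import numpy as np

def tile_256(arr):
    # arr: padded array whose first two dimensions are multiples of 256
    rows, cols = arr.shape[:2]
    tiles = []
    for i in range(rows // 256):
        for j in range(cols // 256):
            tiles.append(arr[i * 256:(i + 1) * 256, j * 256:(j + 1) * 256])
    return tiles

# hypothetical padded scan: (2560, 2304, 3) -> 10 * 9 = 90 tiles per image,
# matching the factor of 90 used by StaticTrainDataset in dataloader.py below
padded = np.zeros((2560, 2304, 3), dtype=np.uint8)
print(len(tile_256(padded)))  # 90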
dependecies/segroot/dataloader.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import itertools
3
+ import pickle
4
+ import torch
5
+ from torchvision import models
6
+ from pathlib import Path
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset, DataLoader, Sampler
9
+
10
+ import dependecies.segroot.paired_transforms_pt04 as p_tr
11
+
12
+ train_transform = p_tr.Compose([
13
+ p_tr.RandomCrop(256),
14
+ p_tr.RandomRotation((90, 90)),
15
+ p_tr.RandomRotation((180, 180)),
16
+ p_tr.RandomRotation((270, 270)),
17
+ p_tr.RandomHorizontalFlip(),
18
+ p_tr.RandomVerticalFlip(),
19
+ p_tr.ToTensor()
20
+ ])
21
+
22
+ # normalize = p_tr.Normalize([0.35042979, 0.44016893, 0.2340332],
23
+ # [0.20999724, 0.25972678, 0.13885915])
24
+ normalize = p_tr.Normalize([0.5, 0.5, 0.5],
25
+ [0.5, 0.5, 0.5])
26
+
27
+
28
+ def pad_pair_256(image, gt):
29
+ w, h = image.size
30
+ new_w = ((w - 1) // 256 + 1) * 256
31
+ new_h = ((h - 1) // 256 + 1) * 256
32
+ new_image = Image.new("RGB", (new_w, new_h))
33
+ new_image.paste(image, ((new_w - w) // 2, (new_h - h) // 2))
34
+ new_gt = Image.new("L", (new_w, new_h))
35
+ new_gt.paste(gt, ((new_w - w) // 2, (new_h - h) // 2))
36
+ return new_image, new_gt
37
+
38
+
39
+ def convert_png(image, gt):
40
+ new_image = Image.new('RGB', (256, 256))
41
+ new_image.paste(image)
42
+ new_gt = Image.new('L', (256, 256))
43
+ new_gt.paste(gt)
44
+ return new_image, new_gt
45
+
46
+
47
+ def get_paths(root_dir, im_ids):
48
+ imgs = []
49
+ for i in im_ids:
50
+ tmp = Path(root_dir).glob('*{}-*.png'.format(i))
51
+ tmp = [p for p in tmp if p.parts[-1].startswith(str(i)+'-')]
52
+ imgs = imgs + list(tmp)
53
+ return imgs
54
+
55
+
56
+ class LoopSampler(Sampler):
57
+ def __init__(self, data_source):
58
+ self.data_source = data_source
59
+
60
+ def __iter__(self):
61
+ return itertools.cycle(range(len(self.data_source)))
62
+
63
+ def __len__(self):
64
+ return len(self.data_source)
65
+
66
+
67
+ class TrainDataset(Dataset):
68
+ def __init__(self, im_ids):
69
+ self.root_dir = '../data/data_raw'
70
+ self.mask_dir = '../data/mask'
71
+ self.im_ids = im_ids
72
+ with open('../data/info.pkl', 'rb') as handle:
73
+ self.info = pickle.load(handle)
74
+ self.fns = [self.info[im_id] for im_id in im_ids]
75
+
76
+ def __getitem__(self, index):
77
+ im_fn = self.fns[index]
78
+ im_name = os.path.join(self.root_dir, im_fn)
79
+ gt_name = os.path.join(
80
+ self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
81
+ image = Image.open(im_name)
82
+ gt = Image.open(gt_name)
83
+ image, gt = pad_pair_256(image, gt)
84
+
85
+ image, gt = train_transform(image, gt)
86
+ image = normalize(image)
87
+
88
+ return image, gt
89
+
90
+ def __len__(self):
91
+ return len(self.im_ids)
92
+
93
+
94
+ class StaticTrainDataset(Dataset):
95
+ def __init__(self, im_ids):
96
+ self.subimgs = sorted(get_paths('../data/subimg', im_ids))
97
+ self.submasks = sorted(get_paths('../data/submask', im_ids))
98
+ self.im_ids = im_ids
99
+
100
+ def __getitem__(self, index):
101
+ im_name = self.subimgs[index]
102
+ gt_name = self.submasks[index]
103
+ image = Image.open(im_name)
104
+ gt = Image.open(gt_name)
105
+ image, gt = convert_png(image, gt)
106
+
107
+ image, gt = train_transform(image, gt)
108
+ image = normalize(image)
109
+
110
+ return image, gt
111
+
112
+ def __len__(self):
113
+ return len(self.im_ids * 90)
114
+
115
+
116
+ class TrainDataLoader():
117
+ def __init__(self, dataset, batch_size, num_workers=0):
118
+ self.dataset = dataset
119
+ self.dataloader = DataLoader(self.dataset, batch_size=batch_size,
120
+ num_workers=num_workers, sampler=LoopSampler(self.dataset))
121
+ self.dl = iter(self.dataloader)
122
+
123
+ def next_batch(self):
124
+ image, gt = next(self.dl)
125
+ return image, gt
126
+
127
+
128
+ class TestDataset(Dataset):
129
+ def __init__(self, im_ids):
130
+ self.root_dir = '../data/data_raw'
131
+ self.mask_dir = '../data/masks'
132
+ with open('../data/info.pkl', 'rb') as handle:
133
+ self.info = pickle.load(handle)
134
+ self.im_ids = im_ids
135
+ self.fns = [self.info[im_id] for im_id in im_ids]
136
+
137
+ def __getitem__(self, index):
138
+ im_fn = self.fns[index]
139
+ im_name = os.path.join(self.root_dir, im_fn)
140
+ gt_name = os.path.join(
141
+ self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
142
+ image = Image.open(im_name)
143
+ gt = Image.open(gt_name)
144
+ image, gt = pad_pair_256(image, gt)
145
+
146
+ image, gt = p_tr.ToTensor()(image, gt)
147
+ image = normalize(image)
148
+ return image, gt
149
+
150
+ def __len__(self):
151
+ return len(self.fns)
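
A minimal usage sketch of these datasets, assuming the ../data/subimg, ../data/submask and ../data/info.pkl artifacts produced by binarize_crop.py already exist; the image ids and batch size are illustrative only:

from dependecies.segroot.dataloader import StaticTrainDataset, TrainDataLoader, TestDataset

train_ids = [0, 1, 2, 3]                 # hypothetical image ids
train_data = StaticTrainDataset(train_ids)

# LoopSampler cycles over the 256x256 sub-images indefinitely
loader = TrainDataLoader(train_data, batch_size=8)
images, masks = loader.next_batch()
print(images.shape, masks.shape)         # e.g. [8, 3, 256, 256] and [8, 1, 256, 256]

# TestDataset returns whole padded scans rather than crops
test_data = TestDataset([4])
image, mask = test_data[0]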
dependecies/segroot/main_segroot.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ import numpy as np
4
+ import random
5
+ from tqdm import tqdm
6
+ import argparse
7
+
8
+ from model import SegRoot
9
+ from dataloader import StaticTrainDataset, TestDataset, TrainDataset, LoopSampler
10
+ from utils import (
11
+ dice_score,
12
+ init_weights,
13
+ evaluate,
14
+ get_ids,
15
+ load_vgg16,
16
+ set_random_seed,
17
+ )
18
+
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--seed", default=42, type=int, help="set random seed")
22
+ parser.add_argument("--width", default=8, type=int, help="width of SegRoot")
23
+ parser.add_argument("--depth", default=5, type=int, help="depth of SegRoot")
24
+ parser.add_argument("--bs", default=64, type=int, help="batch size of dataloaders")
25
+ parser.add_argument("--lr", default=1e-2, type=float, help="learning rate")
26
+ parser.add_argument("--epochs", default=200, type=int, help="max epochs of training")
27
+ parser.add_argument(
28
+ "--verbose", default=5, type=int, help="intervals to save and validate model"
29
+ )
30
+ parser.add_argument(
31
+ "--dynamic", action="store_true", help="use dynamic sub-images during training"
32
+ )
33
+
34
+
35
+ def train_one_epoch(model, train_iter, optimizer, device):
36
+ model.train()
37
+ for p in model.parameters():
38
+ p.requires_grad = True
39
+ for x, y in train_iter:
40
+ x, y = x.to(device), y.to(device)
41
+ bs = x.shape[0]
42
+ optimizer.zero_grad()
43
+ y_pred = model(x)
44
+ loss = 1 - dice_score(y, y_pred)
45
+ loss = torch.sum(loss) / bs
46
+ loss.backward()
47
+ optimizer.step()
48
+
49
+
50
+ if __name__ == "__main__":
51
+ args = parser.parse_args()
52
+ seed = args.seed
53
+ bs = args.bs
54
+ lr = args.lr
55
+ width = args.width
56
+ depth = args.depth
57
+ epochs = args.epochs
58
+ verbose = args.verbose
59
+
60
+ # set random seed
61
+ set_random_seed(seed)
62
+ # define the device for training
63
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
64
+ # get training ids
65
+ train_ids, valid_ids, test_ids = get_ids(65)
66
+ # define dataloaders
67
+ if args.dynamic:
68
+ train_data = TrainDataset(train_ids)
69
+ train_iter = DataLoader(
70
+ train_data, batch_size=bs, num_workers=6, sampler=LoopSampler
71
+ )
72
+ else:
73
+ train_data = StaticTrainDataset(train_ids)
74
+ train_iter = DataLoader(train_data, batch_size=bs, num_workers=6, shuffle=True)
75
+
76
+ train_tdata = TestDataset(train_ids)
77
+ valid_tdata = TestDataset(valid_ids)
78
+ test_tdata = TestDataset(test_ids)
79
+
80
+ # define model
81
+ model = SegRoot(width, depth).to(device)
82
+ model = model.apply(init_weights)
83
+
84
+ # define optimizer and lr_scheduler
85
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
86
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
87
+ optimizer, mode="max", factor=0.5, verbose=True, patience=5
88
+ )
89
+
90
+ print(f"Start training SegRoot-({width},{depth}))......")
91
+ print(f"Random seed is {seed}, batch size is {bs}......")
92
+ print(f"learning rate is {lr}, max epochs is {epochs}......")
93
+ best_valid = float("-inf")
94
+ for epoch in tqdm(range(epochs)):
95
+ train_one_epoch(model, train_iter, optimizer, device)
96
+ if epoch % verbose == 0:
97
+ train_dice = evaluate(model, train_tdata, device)
98
+ valid_dice = evaluate(model, valid_tdata, device)
99
+ scheduler.step(valid_dice)
100
+ print(
101
+ "Epoch {:05d}, train dice: {:.4f}, valid dice: {:.4f}".format(
102
+ epoch, train_dice, valid_dice
103
+ )
104
+ )
105
+ if valid_dice > best_valid:
106
+ best_valid = valid_dice
107
+ test_dice = evaluate(model, test_tdata, device)
108
+ print("New best validation, test dice: {:.4f}".format(test_dice))
109
+ torch.save(
110
+ model.state_dict(),
111
+ f"../weights/best_segroot-({args.width},{args.depth}).pt",
112
+ )
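
The training loss above is 1 - dice_score(y, y_pred), where dice_score comes from a companion utils module that is not shown above. A minimal per-sample soft Dice sketch with that behaviour (the exact smoothing constant and reduction in the repository may differ):

import torch

def dice_score(y_true, y_pred, eps=1e-7):
    # y_true, y_pred: (N, 1, H, W) tensors with values in [0, 1]
    y_true = y_true.view(y_true.shape[0], -1)
    y_pred = y_pred.view(y_pred.shape[0], -1)
    inter = (y_true * y_pred).sum(dim=1)
    union = y_true.sum(dim=1) + y_pred.sum(dim=1)
    return (2.0 * inter + eps) / (union + eps)   # one score per sample

# used above as: loss = torch.sum(1 - dice_score(y, y_pred)) / bs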
dependecies/segroot/model.py ADDED
@@ -0,0 +1,124 @@
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+
7
+ class ConvBNRelu(nn.Module):
8
+ def __init__(self, in_ch, out_ch):
9
+ super(ConvBNRelu, self).__init__()
10
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=True)
11
+ self.bn = nn.BatchNorm2d(out_ch)
12
+ self.activation = nn.ReLU()
13
+
14
+ def forward(self, x):
15
+ x = self.conv(x)
16
+ x = self.bn(x)
17
+ x = self.activation(x)
18
+ # print(x.shape)
19
+ return x
20
+
21
+
22
+ class FirstBlock(nn.Module):
23
+ def __init__(self, in_ch, out_ch):
24
+ super(FirstBlock, self).__init__()
25
+ self.conv1 = ConvBNRelu(in_ch, out_ch)
26
+ self.conv2 = ConvBNRelu(out_ch, out_ch)
27
+
28
+ def forward(self, x):
29
+ x = self.conv1(x)
30
+ x = self.conv2(x)
31
+ return x
32
+
33
+
34
+ class DownBlock(nn.Module):
35
+ def __init__(self, in_ch, out_ch):
36
+ super(DownBlock, self).__init__()
37
+ self.conv1 = ConvBNRelu(in_ch, out_ch)
38
+ self.conv2 = ConvBNRelu(out_ch, out_ch)
39
+
40
+ def forward(self, x):
41
+ x = F.max_pool2d(x,kernel_size=2,stride=2)
42
+ x = self.conv1(x)
43
+ x = self.conv2(x)
44
+ return x
45
+
46
+ class Encoder(nn.Module):
47
+ def __init__(self, in_ch, out_ch, block_num=2):
48
+ super(Encoder, self).__init__()
49
+ layers = []
50
+ layers += [ConvBNRelu(in_ch, out_ch)]
51
+ for i in range(block_num-1):
52
+ layers += [ConvBNRelu(out_ch, out_ch)]
53
+ # layers += [nn.Dropout2d(0.5)]
54
+ self.features = nn.Sequential(*layers)
55
+
56
+ def forward(self, x):
57
+ x = self.features(x)
58
+ x, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
59
+ return x, indices
60
+
61
+ class Decoder(nn.Module):
62
+ def __init__(self, in_ch, out_ch, block_num=2):
63
+ super(Decoder, self).__init__()
64
+ layers = []
65
+ layers += [ConvBNRelu(in_ch, out_ch)]
66
+ for i in range(block_num-1):
67
+ layers += [ConvBNRelu(out_ch, out_ch)]
68
+ # layers += [nn.Dropout2d(0.5)]
69
+ self.features = nn.Sequential(*layers)
70
+
71
+ def forward(self, x, indices):
72
+ x = F.max_unpool2d(x, indices=indices, kernel_size=2, stride=2)
73
+ x = self.features(x)
74
+ return x
75
+
76
+ class SegRoot(nn.Module):
77
+ def __init__(self, width=8, depth=5, num_classes=2):
78
+ super(SegRoot, self).__init__()
79
+ chs = []
80
+ for i in range(depth-1):
81
+ chs.append(width * (2**i))
82
+ chs.append(chs[-1])
83
+ self.e_ch_info = [3,] + chs
84
+ self.e_bl_info = [2,2,3,3]
85
+ for _ in range(depth - 4):
86
+ self.e_bl_info += [3,]
87
+ self.d_ch_info = chs[::-1] + [4,]
88
+ self.d_bl_info = self.e_bl_info[::-1]
89
+ # using same setup with Unet
90
+ if width == 4:
91
+ self.e_ch_info = [3,4,8,16,32,64]
92
+ self.d_ch_info = [64,32,16,8,4,4]
93
+ self.num_classes = num_classes
94
+ self.encoders = nn.ModuleList()
95
+ self.decoders = nn.ModuleList()
96
+
97
+ for i in range(1,len(self.e_ch_info)):
98
+ self.encoders.append(Encoder(self.e_ch_info[i-1], self.e_ch_info[i], self.e_bl_info[i-1]))
99
+ self.decoders.append(Decoder(self.d_ch_info[i-1], self.d_ch_info[i], self.d_bl_info[i-1]))
100
+
101
+ # self.classifier = nn.Conv2d(self.d_ch_info[-1], num_classes, kernel_size=3, padding=1)
102
+ self.classifier = nn.Conv2d(self.d_ch_info[-1], 1, 1)
103
+
104
+ def forward(self, x):
105
+ indices = []
106
+ bs = x.shape[0]
107
+ for i in range(len(self.e_bl_info)):
108
+ x, ind = self.encoders[i](x)
109
+ indices.append(ind)
110
+
111
+ indices = indices[::-1]
112
+ for i in range(len(self.e_bl_info)):
113
+ x = self.decoders[i](x, indices[i])
114
+
115
+ x = self.classifier(x)
116
+ # x = F.softmax(x,dim=1)
117
+ x = torch.sigmoid(x)
118
+ return x
119
+
120
+
121
+ if __name__ == '__main__':
122
+ x = torch.zeros((1, 3, 256, 256))
123
+ net = SegRoot(8,5)
124
+ print(net(x).shape)
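
A short inference sketch combining SegRoot with pad_pair_256 and normalize from dataloader.py; the checkpoint path follows the naming convention used in main_segroot.py and is an assumption, as are the input filename and the 0.9 threshold (which mirrors the default in predict_imgs.py below):

import torch
from PIL import Image

from dependecies.segroot.model import SegRoot
from dependecies.segroot.dataloader import pad_pair_256, normalize
import dependecies.segroot.paired_transforms_pt04 as p_tr

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegRoot(8, 5).to(device)
# hypothetical checkpoint path, following the naming used in main_segroot.py
model.load_state_dict(torch.load("weights/best_segroot-(8,5).pt", map_location=device))
model.eval()

img = Image.open("root_scan.jpg")                        # hypothetical input scan
padded, _ = pad_pair_256(img, Image.new("L", img.size))  # pad to multiples of 256
x = normalize(p_tr.ToTensor()(padded)).unsqueeze(0).to(device)

with torch.no_grad():
    prob = model(x)[0, 0]              # (H, W) root probabilities in [0, 1]
mask = (prob > 0.9).cpu().numpy()      # binarize with an assumed threshold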
dependecies/segroot/paired_transforms_pt04.py ADDED
@@ -0,0 +1,1027 @@
1
+ from __future__ import division
2
+ import torch
3
+ import math
4
+ import random
5
+ from PIL import Image, ImageOps, ImageEnhance
6
+ try:
7
+ import accimage
8
+ except ImportError:
9
+ accimage = None
10
+ import numpy as np
11
+ import numbers
12
+ import types
13
+ import collections
14
+ import warnings
15
+
16
+ from torchvision.transforms import functional as F
17
+
18
+ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
19
+ "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
20
+ "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
21
+ "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
22
+
23
+ _pil_interpolation_to_str = {
24
+ Image.NEAREST: 'PIL.Image.NEAREST',
25
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
26
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
27
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
28
+ }
29
+
30
+
31
+ class Compose(object):
32
+ """Composes several transforms together.
33
+ Args:
34
+ transforms (list of ``Transform`` objects): list of transforms to compose.
35
+ Example:
36
+ >>> transforms.Compose([
37
+ >>> transforms.CenterCrop(10),
38
+ >>> transforms.ToTensor(),
39
+ >>> ])
40
+ """
41
+
42
+ def __init__(self, transforms):
43
+ self.transforms = transforms
44
+
45
+ def __call__(self, img, target = None):
46
+ if target is not None:
47
+ for t in self.transforms:
48
+ img, target = t(img, target)
49
+ return img, target
50
+
51
+ for t in self.transforms:
52
+ img = t(img)
53
+ return img
54
+
55
+ def __repr__(self):
56
+ format_string = self.__class__.__name__ + '('
57
+ for t in self.transforms:
58
+ format_string += '\n'
59
+ format_string += ' {0}'.format(t)
60
+ format_string += '\n)'
61
+ return format_string
62
+
63
+
64
+ class ToTensor(object):
65
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
66
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
67
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
68
+ """
69
+
70
+ def __call__(self, pic, pic2=None):
71
+ """
72
+ Args:
73
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
74
+ pic2 (PIL Image): (optional) Second image to be converted also.
75
+ Returns:
76
+ Tensor(s): Converted image(s).
77
+ """
78
+ if pic2 is not None:
79
+ return F.to_tensor(pic), F.to_tensor(pic2)
80
+ return F.to_tensor(pic)
81
+
82
+ def __repr__(self):
83
+ return self.__class__.__name__ + '()'
84
+
85
+
86
+ class ToPILImage(object):
87
+ """Convert a tensor or an ndarray to PIL Image.
88
+ Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
89
+ H x W x C to a PIL Image while preserving the value range.
90
+ Args:
91
+ mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
92
+ If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
93
+ 1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
94
+ 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
95
+ 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e,
96
+ ``int``, ``float``, ``short``).
97
+ .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
98
+ """
99
+ def __init__(self, mode=None):
100
+ self.mode = mode
101
+
102
+ def __call__(self, pic, pic2=None):
103
+ """
104
+ Args:
105
+ pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
106
+ Returns:
107
+ PIL Image: Image converted to PIL Image.
108
+ """
109
+ if pic2 is not None:
110
+ return F.to_pil_image(pic), F.to_pil_image(pic2)
111
+ return F.to_pil_image(pic, self.mode)
112
+
113
+ def __repr__(self):
114
+ format_string = self.__class__.__name__ + '('
115
+ if self.mode is not None:
116
+ format_string += 'mode={0}'.format(self.mode)
117
+ format_string += ')'
118
+ return format_string
119
+
120
+
121
+ class Normalize(object):
122
+ """Normalize a tensor image with mean and standard deviation.
123
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
124
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
125
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
126
+ Args:
127
+ mean (sequence): Sequence of means for each channel.
128
+ std (sequence): Sequence of standard deviations for each channel.
129
+ """
130
+
131
+ def __init__(self, mean, std):
132
+ self.mean = mean
133
+ self.std = std
134
+
135
+ def __call__(self, tensor):
136
+ """
137
+ Args:
138
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
139
+ Returns:
140
+ Tensor: Normalized Tensor image.
141
+ """
142
+ return F.normalize(tensor, self.mean, self.std)
143
+
144
+ def __repr__(self):
145
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
146
+
147
+
148
+ class Resize(object):
149
+ """Resize the input PIL Image to the given size.
150
+ Args:
151
+ size (sequence or int): Desired output size. If size is a sequence like
152
+ (h, w), output size will be matched to this. If size is an int,
153
+ smaller edge of the image will be matched to this number.
154
+ i.e, if height > width, then image will be rescaled to
155
+ (size * height / width, size)
156
+ interpolation (int, optional): Desired interpolation. Default is
157
+ ``PIL.Image.BILINEAR``
158
+ interpolation_tg (int, optional): Desired interpolation for target. Default is
159
+ ``PIL.Image.NEAREST``
160
+ """
161
+
162
+ def __init__(self, size, interpolation=Image.BILINEAR, interpolation_tg = Image.NEAREST):
163
+ assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
164
+ self.size = size
165
+ self.interpolation = interpolation
166
+ self.interpolation_tg = interpolation_tg
167
+
168
+ def __call__(self, img, target = None):
169
+ """
170
+ Args:
171
+ img (PIL Image): Image to be scaled.
172
+ target (PIL Image): (optional) Target to be scaled
173
+ Returns:
174
+ PIL Image: Rescaled image(s).
175
+ """
176
+ if target is not None:
177
+ return F.resize(img, self.size, self.interpolation), F.resize(target, self.size, self.interpolation_tg)
178
+ return F.resize(img, self.size, self.interpolation)
179
+
180
+ def __repr__(self):
181
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
182
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
183
+
184
+
185
+ class Scale(Resize):
186
+ """
187
+ Note: This transform is deprecated in favor of Resize.
188
+ """
189
+ def __init__(self, *args, **kwargs):
190
+ warnings.warn("The use of the transforms.Scale transform is deprecated, " +
191
+ "please use transforms.Resize instead.")
192
+ super(Scale, self).__init__(*args, **kwargs)
193
+
194
+
195
+ class CenterCrop(object):
196
+ """Crops the given PIL Image at the center.
197
+ Args:
198
+ size (sequence or int): Desired output size of the crop. If size is an
199
+ int instead of sequence like (h, w), a square crop (size, size) is
200
+ made.
201
+ """
202
+
203
+ def __init__(self, size):
204
+ if isinstance(size, numbers.Number):
205
+ self.size = (int(size), int(size))
206
+ else:
207
+ self.size = size
208
+
209
+ def __call__(self, img, target=None):
210
+ """
211
+ Args:
212
+ img (PIL Image): Image to be cropped.
213
+ target (PIL Image): (optional) Target to be cropped
214
+ Returns:
215
+ PIL Image: Cropped image(s).
216
+ """
217
+ if target is not None:
218
+ return F.center_crop(img, self.size), F.center_crop(target, self.size)
219
+ return F.center_crop(img, self.size)
220
+
221
+ def __repr__(self):
222
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
223
+
224
+
225
+ class Pad(object):
226
+ """Pad the given PIL Image on all sides with the given "pad" value.
227
+ Args:
228
+ padding (int or tuple): Padding on each border. If a single int is provided this
229
+ is used to pad all borders. If tuple of length 2 is provided this is the padding
230
+ on left/right and top/bottom respectively. If a tuple of length 4 is provided
231
+ this is the padding for the left, top, right and bottom borders
232
+ respectively.
233
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
234
+ length 3, it is used to fill R, G, B channels respectively.
235
+ This value is only used when the padding_mode is constant
236
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
237
+ constant: pads with a constant value, this value is specified with fill
238
+ edge: pads with the last value at the edge of the image
239
+ reflect: pads with reflection of image (without repeating the last value on the edge)
240
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
241
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
242
+ symmetric: pads with reflection of image (repeating the last value on the edge)
243
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
244
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
245
+ """
246
+
247
+ def __init__(self, padding, fill=0, padding_mode='constant'):
248
+ assert isinstance(padding, (numbers.Number, tuple))
249
+ assert isinstance(fill, (numbers.Number, str, tuple))
250
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
251
+ if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
252
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
253
+ "{} element tuple".format(len(padding)))
254
+
255
+ self.padding = padding
256
+ self.fill = fill
257
+ self.padding_mode = padding_mode
258
+
259
+ def __call__(self, img):
260
+ """
261
+ Args:
262
+ img (PIL Image): Image to be padded.
263
+ Returns:
264
+ PIL Image: Padded image.
265
+ """
266
+ return F.pad(img, self.padding, self.fill, self.padding_mode)
267
+
268
+ def __repr__(self):
269
+ return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
270
+ format(self.padding, self.fill, self.padding_mode)
271
+
272
+
273
+ class Lambda(object):
274
+ """Apply a user-defined lambda as a transform.
275
+ Args:
276
+ lambd (function): Lambda/function to be used for transform.
277
+ """
278
+
279
+ def __init__(self, lambd):
280
+ assert isinstance(lambd, types.LambdaType)
281
+ self.lambd = lambd
282
+
283
+ def __call__(self, img):
284
+ return self.lambd(img)
285
+
286
+ def __repr__(self):
287
+ return self.__class__.__name__ + '()'
288
+
289
+
290
+ class RandomTransforms(object):
291
+ """Base class for a list of transformations with randomness
292
+ Args:
293
+ transforms (list or tuple): list of transformations
294
+ """
295
+
296
+ def __init__(self, transforms):
297
+ assert isinstance(transforms, (list, tuple))
298
+ self.transforms = transforms
299
+
300
+ def __call__(self, *args, **kwargs):
301
+ raise NotImplementedError()
302
+
303
+ def __repr__(self):
304
+ format_string = self.__class__.__name__ + '('
305
+ for t in self.transforms:
306
+ format_string += '\n'
307
+ format_string += ' {0}'.format(t)
308
+ format_string += '\n)'
309
+ return format_string
310
+
311
+
312
+ class RandomApply(RandomTransforms):
313
+ """Apply randomly a list of transformations with a given probability
314
+ Args:
315
+ transforms (list or tuple): list of transformations
316
+ p (float): probability
317
+ """
318
+
319
+ def __init__(self, transforms, p=0.5):
320
+ super(RandomApply, self).__init__(transforms)
321
+ self.p = p
322
+
323
+ def __call__(self, img):
324
+ if self.p < random.random():
325
+ return img
326
+ for t in self.transforms:
327
+ img = t(img)
328
+ return img
329
+
330
+ def __repr__(self):
331
+ format_string = self.__class__.__name__ + '('
332
+ format_string += '\n p={}'.format(self.p)
333
+ for t in self.transforms:
334
+ format_string += '\n'
335
+ format_string += ' {0}'.format(t)
336
+ format_string += '\n)'
337
+ return format_string
338
+
339
+
340
+ class RandomOrder(RandomTransforms):
341
+ """Apply a list of transformations in a random order
342
+ """
343
+ def __call__(self, img):
344
+ order = list(range(len(self.transforms)))
345
+ random.shuffle(order)
346
+ for i in order:
347
+ img = self.transforms[i](img)
348
+ return img
349
+
350
+
351
+ class RandomChoice(RandomTransforms):
352
+ """Apply single transformation randomly picked from a list
353
+ """
354
+ def __call__(self, img):
355
+ t = random.choice(self.transforms)
356
+ return t(img)
357
+
358
+
359
+ class RandomCrop(object):
360
+ """Crop the given PIL Image at a random location.
361
+ Args:
362
+ size (sequence or int): Desired output size of the crop. If size is an
363
+ int instead of sequence like (h, w), a square crop (size, size) is
364
+ made.
365
+ padding (int or sequence, optional): Optional padding on each border
366
+ of the image. Default is 0, i.e no padding. If a sequence of length
367
+ 4 is provided, it is used to pad left, top, right, bottom borders
368
+ respectively.
369
+ pad_if_needed (boolean): It will pad the image if smaller than the
370
+ desired size to avoid raising an exception.
371
+ """
372
+
373
+ def __init__(self, size, padding=0, pad_if_needed=False):
374
+ if isinstance(size, numbers.Number):
375
+ self.size = (int(size), int(size))
376
+ else:
377
+ self.size = size
378
+ self.padding = padding
379
+ self.pad_if_needed = pad_if_needed
380
+
381
+ @staticmethod
382
+ def get_params(img, output_size):
383
+ """Get parameters for ``crop`` for a random crop.
384
+ Args:
385
+ img (PIL Image): Image to be cropped.
386
+ output_size (tuple): Expected output size of the crop.
387
+ Returns:
388
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
389
+ """
390
+ w, h = img.size
391
+ th, tw = output_size
392
+ if w == tw and h == th:
393
+ return 0, 0, h, w
394
+
395
+ i = random.randint(0, h - th)
396
+ j = random.randint(0, w - tw)
397
+ return i, j, th, tw
398
+
399
+ def __call__(self, img, target = None):
400
+ """
401
+ Args:
402
+ img (PIL Image): Image to be cropped.
403
+ target (PIL Image): (optional) Target to be cropped
404
+ Returns:
405
+ PIL Images: Cropped image(s).
406
+ """
407
+ if self.padding > 0:
408
+ img = F.pad(img, self.padding)
409
+ if target is not None:
410
+ target = F.pad(target, self.padding)
411
+
412
+ # pad the width if needed
413
+ if self.pad_if_needed and img.size[0] < self.size[1]:
414
+ img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
415
+ if target is not None:
416
+ target = F.pad(target, (int((1 + self.size[1] - target.size[0]) / 2), 0))
417
+ # pad the height if needed
418
+ if self.pad_if_needed and img.size[1] < self.size[0]:
419
+ img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
420
+ if target is not None:
421
+ target = F.pad(target, (0, int((1 + self.size[0] - target.size[1]) / 2)))
422
+ i, j, h, w = self.get_params(img, self.size)
423
+
424
+ if target is not None:
425
+ return F.crop(img, i, j, h, w), F.crop(target, i, j, h, w)
426
+ else:
427
+ return F.crop(img, i, j, h, w)
428
+
429
+ def __repr__(self):
430
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
431
+
432
+
433
+ class RandomHorizontalFlip(object):
434
+ """Horizontally flip the given PIL Image randomly with a given probability.
435
+ Args:
436
+ p (float): probability of the image being flipped. Default value is 0.5
437
+ """
438
+
439
+ def __init__(self, p=0.5):
440
+ self.p = p
441
+
442
+ def __call__(self, img, target=None):
443
+ """
444
+ Args:
445
+ img (PIL Image): Image to be flipped.
446
+ target (PIL Image): (optional) Target to be flipped
447
+ Returns:
448
+ PIL Image: Randomly flipped image(s).
449
+ """
450
+ if random.random() < self.p:
451
+ if target is not None:
452
+ return F.hflip(img), F.hflip(target)
453
+ else:
454
+ return F.hflip(img)
455
+
456
+ if target is not None:
457
+ return img, target
458
+ return img
459
+
460
+ def __repr__(self):
461
+ return self.__class__.__name__ + '(p={})'.format(self.p)
462
+
463
+
464
+ class RandomVerticalFlip(object):
465
+ """Vertically flip the given PIL Image randomly with a given probability.
466
+ Args:
467
+ p (float): probability of the image being flipped. Default value is 0.5
468
+ """
469
+
470
+ def __init__(self, p=0.5):
471
+ self.p = p
472
+
473
+ def __call__(self, img, target=None):
474
+ """
475
+ Args:
476
+ img (PIL Image): Image to be flipped.
477
+ target (PIL Image): (optional) Target to be flipped
478
+ Returns:
479
+ PIL Image: Randomly flipped image(s).
480
+ """
481
+ if random.random() < self.p:
482
+ if target is not None:
483
+ return F.vflip(img), F.vflip(target)
484
+ else:
485
+ return F.vflip(img)
486
+
487
+ if target is not None:
488
+ return img, target
489
+ return img
490
+
491
+ def __repr__(self):
492
+ return self.__class__.__name__ + '(p={})'.format(self.p)
493
+
494
+
495
+ class RandomResizedCrop(object):
496
+ """Crop the given PIL Image to random size and aspect ratio.
497
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
498
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
499
+ is finally resized to given size.
500
+ This is popularly used to train the Inception networks.
501
+ Args:
502
+ size: expected output size of each edge
503
+ scale: range of size of the origin size cropped
504
+ ratio: range of aspect ratio of the origin aspect ratio cropped
505
+ interpolation: Default: PIL.Image.BILINEAR
506
+ """
507
+
508
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
509
+ interpolation=Image.BILINEAR, interpolation_tg = Image.NEAREST):
510
+ self.size = (size, size)
511
+ self.interpolation = interpolation
512
+ self.interpolation_tg = interpolation_tg
513
+ self.scale = scale
514
+ self.ratio = ratio
515
+
516
+ @staticmethod
517
+ def get_params(img, scale, ratio):
518
+ """Get parameters for ``crop`` for a random sized crop.
519
+ Args:
520
+ img (PIL Image): Image to be cropped.
521
+ scale (tuple): range of size of the origin size cropped
522
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
523
+ Returns:
524
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
525
+ sized crop.
526
+ """
527
+ for attempt in range(10):
528
+ area = img.size[0] * img.size[1]
529
+ target_area = random.uniform(*scale) * area
530
+ aspect_ratio = random.uniform(*ratio)
531
+
532
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
533
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
534
+
535
+ if random.random() < 0.5:
536
+ w, h = h, w
537
+
538
+ if w <= img.size[0] and h <= img.size[1]:
539
+ i = random.randint(0, img.size[1] - h)
540
+ j = random.randint(0, img.size[0] - w)
541
+ return i, j, h, w
542
+
543
+ # Fallback
544
+ w = min(img.size[0], img.size[1])
545
+ i = (img.size[1] - w) // 2
546
+ j = (img.size[0] - w) // 2
547
+ return i, j, w, w
548
+
549
+ def __call__(self, img, target = None):
550
+ """
551
+ Args:
552
+ img (PIL Image): Image to be cropped and resized.
553
+ target (PIL Image): (optional) Target to be cropped and resized.
554
+ Returns:
555
+ PIL Image: Randomly cropped and resized image(s).
556
+ """
557
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
558
+ if target is not None:
559
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
560
+ F.resized_crop(target, i, j, h, w, self.size, self.interpolation_tg)
561
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
562
+
563
+ def __repr__(self):
564
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
565
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
566
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
567
+ format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
568
+ format_string += ', interpolation={0})'.format(interpolate_str)
569
+ return format_string
570
+
571
+
572
+ class RandomSizedCrop(RandomResizedCrop):
573
+ """
574
+ Note: This transform is deprecated in favor of RandomResizedCrop.
575
+ """
576
+ def __init__(self, *args, **kwargs):
577
+ warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
578
+ "please use transforms.RandomResizedCrop instead.")
579
+ super(RandomSizedCrop, self).__init__(*args, **kwargs)
580
+
581
+
582
+ class FiveCrop(object):
583
+ """Crop the given PIL Image into four corners and the central crop
584
+ .. Note::
585
+ This transform returns a tuple of images and there may be a mismatch in the number of
586
+ inputs and targets your Dataset returns. See below for an example of how to deal with
587
+ this.
588
+ Args:
589
+ size (sequence or int): Desired output size of the crop. If size is an ``int``
590
+ instead of sequence like (h, w), a square crop of size (size, size) is made.
591
+ Example:
592
+ >>> transform = Compose([
593
+ >>> FiveCrop(size), # this is a list of PIL Images
594
+ >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
595
+ >>> ])
596
+ >>> #In your test loop you can do the following:
597
+ >>> input, target = batch # input is a 5d tensor, target is 2d
598
+ >>> bs, ncrops, c, h, w = input.size()
599
+ >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
600
+ >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
601
+ """
602
+
603
+ def __init__(self, size):
604
+ self.size = size
605
+ if isinstance(size, numbers.Number):
606
+ self.size = (int(size), int(size))
607
+ else:
608
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
609
+ self.size = size
610
+
611
+ def __call__(self, img, target=None):
612
+ if target is not None:
613
+ return F.five_crop(img, self.size), F.five_crop(target, self.size)
614
+ return F.five_crop(img, self.size)
615
+
616
+ def __repr__(self):
617
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
618
+
619
+
620
+ class TenCrop(object):
621
+ """Crop the given PIL Image into four corners and the central crop plus the flipped version of
622
+ these (horizontal flipping is used by default)
623
+ .. Note::
624
+ This transform returns a tuple of images and there may be a mismatch in the number of
625
+ inputs and targets your Dataset returns. See below for an example of how to deal with
626
+ this.
627
+ Args:
628
+ size (sequence or int): Desired output size of the crop. If size is an
629
+ int instead of sequence like (h, w), a square crop (size, size) is
630
+ made.
631
+ vertical_flip(bool): Use vertical flipping instead of horizontal
632
+ Example:
633
+ >>> transform = Compose([
634
+ >>> TenCrop(size), # this is a list of PIL Images
635
+ >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
636
+ >>> ])
637
+ >>> #In your test loop you can do the following:
638
+ >>> input, target = batch # input is a 5d tensor, target is 2d
639
+ >>> bs, ncrops, c, h, w = input.size()
640
+ >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
641
+ >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
642
+ """
643
+
644
+ def __init__(self, size, vertical_flip=False):
645
+ self.size = size
646
+ if isinstance(size, numbers.Number):
647
+ self.size = (int(size), int(size))
648
+ else:
649
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
650
+ self.size = size
651
+ self.vertical_flip = vertical_flip
652
+
653
+ def __call__(self, img, target = None):
654
+ if target is not None:
655
+ return F.ten_crop(img, self.size), F.ten_crop(target, self.size)
656
+ return F.ten_crop(img, self.size, self.vertical_flip)
657
+
658
+ def __repr__(self):
659
+ return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
660
+
661
+
662
+ class LinearTransformation(object):
663
+ """Transform a tensor image with a square transformation matrix computed
664
+ offline.
665
+ Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
666
+ product with the transformation matrix and reshape the tensor to its
667
+ original shape.
668
+ Applications:
669
+ - whitening: zero-center the data, compute the data covariance matrix
670
+ [D x D] with np.dot(X.T, X), perform SVD on this matrix and
671
+ pass it as transformation_matrix.
672
+ Args:
673
+ transformation_matrix (Tensor): tensor [D x D], D = C x H x W
674
+ """
675
+
676
+ def __init__(self, transformation_matrix):
677
+ if transformation_matrix.size(0) != transformation_matrix.size(1):
678
+ raise ValueError("transformation_matrix should be square. Got " +
679
+ "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
680
+ self.transformation_matrix = transformation_matrix
681
+
682
+ def __call__(self, tensor, target_tensor=None):
683
+ """
684
+ Args:
685
+ tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
686
+ Returns:
687
+ Tensor: Transformed image.
688
+ """
689
+ if target_tensor is not None:
690
+ raise NotImplementedError("LinearTransformation not implemented for tensor pairs.")
691
+ if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
692
+ raise ValueError("tensor and transformation matrix have incompatible shape." +
693
+ "[{} x {} x {}] != ".format(*tensor.size()) +
694
+ "{}".format(self.transformation_matrix.size(0)))
695
+ flat_tensor = tensor.view(1, -1)
696
+ transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
697
+ tensor = transformed_tensor.view(tensor.size())
698
+ return tensor
699
+
700
+ def __repr__(self):
701
+ format_string = self.__class__.__name__ + '('
702
+ format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
703
+ return format_string
704
+
705
+
706
+ class ColorJitter(object):
707
+ """Randomly change the brightness, contrast and saturation of an image.
708
+ Args:
709
+ brightness (float): How much to jitter brightness. brightness_factor
710
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
711
+ contrast (float): How much to jitter contrast. contrast_factor
712
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
713
+ saturation (float): How much to jitter saturation. saturation_factor
714
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
715
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
716
+ [-hue, hue]. Should be >=0 and <= 0.5.
717
+ """
718
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
719
+ self.brightness = brightness
720
+ self.contrast = contrast
721
+ self.saturation = saturation
722
+ self.hue = hue
723
+
724
+ @staticmethod
725
+ def get_params(brightness, contrast, saturation, hue):
726
+ """Get a randomized transform to be applied on image.
727
+ Arguments are same as that of __init__.
728
+ Returns:
729
+ Transform which randomly adjusts brightness, contrast and
730
+ saturation in a random order.
731
+ """
732
+ transforms = []
733
+ if brightness > 0:
734
+ brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
735
+ transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
736
+
737
+ if contrast > 0:
738
+ contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
739
+ transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
740
+
741
+ if saturation > 0:
742
+ saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
743
+ transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
744
+
745
+ if hue > 0:
746
+ hue_factor = random.uniform(-hue, hue)
747
+ transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
748
+
749
+ random.shuffle(transforms)
750
+ transform = Compose(transforms)
751
+
752
+ return transform
753
+
754
+ def __call__(self, img, target = None):
755
+ """
756
+ Args:
757
+ img (PIL Image): Input image.
758
+ Returns:
759
+ PIL Image: Color jittered image.
760
+ """
761
+ transform = self.get_params(self.brightness, self.contrast,
762
+ self.saturation, self.hue)
763
+
764
+ if target is not None:
765
+ return transform(img), target
766
+ return transform(img)
767
+
768
+ def __repr__(self):
769
+ format_string = self.__class__.__name__ + '('
770
+ format_string += 'brightness={0}'.format(self.brightness)
771
+ format_string += ', contrast={0}'.format(self.contrast)
772
+ format_string += ', saturation={0}'.format(self.saturation)
773
+ format_string += ', hue={0})'.format(self.hue)
774
+ return format_string
775
+
776
+
777
+ class RandomRotation(object):
778
+ """Rotate the image by angle.
779
+ Args:
780
+ degrees (sequence or float or int): Range of degrees to select from.
781
+ If degrees is a number instead of sequence like (min, max), the range of degrees
782
+ will be (-degrees, +degrees).
783
+ resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
784
+ An optional resampling filter.
785
+ See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
786
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
787
+ expand (bool, optional): Optional expansion flag.
788
+ If true, expands the output to make it large enough to hold the entire rotated image.
789
+ If false or omitted, make the output image the same size as the input image.
790
+ Note that the expand flag assumes rotation around the center and no translation.
791
+ center (2-tuple, optional): Optional center of rotation.
792
+ Origin is the upper left corner.
793
+ Default is the center of the image.
794
+ """
795
+
796
+ def __init__(self, degrees, resample=False, resample_tg=False, expand=False, center=None):
797
+ if isinstance(degrees, numbers.Number):
798
+ if degrees < 0:
799
+ raise ValueError("If degrees is a single number, it must be positive.")
800
+ self.degrees = (-degrees, degrees)
801
+ else:
802
+ if len(degrees) != 2:
803
+ raise ValueError("If degrees is a sequence, it must be of len 2.")
804
+ self.degrees = degrees
805
+
806
+ self.resample = resample
807
+ self.resample_tg = resample_tg
808
+ self.expand = expand
809
+ self.center = center
810
+
811
+ @staticmethod
812
+ def get_params(degrees):
813
+ """Get parameters for ``rotate`` for a random rotation.
814
+ Returns:
815
+ sequence: params to be passed to ``rotate`` for random rotation.
816
+ """
817
+ angle = random.uniform(degrees[0], degrees[1])
818
+
819
+ return angle
820
+
821
+ def __call__(self, img, target=None):
822
+ """
823
+ img (PIL Image): Image to be rotated.
824
+ target (PIL Image): (optional) Target to be rotated
825
+ Returns:
826
+ PIL Image: Rotated image(s).
827
+ """
828
+
829
+ angle = self.get_params(self.degrees)
830
+
831
+ if target is not None:
832
+ return F.rotate(img, angle, self.resample, self.expand, self.center), \
833
+ F.rotate(target, angle, self.resample_tg, self.expand, self.center)
834
+ # resample = False is by default nearest, appropriate for targets
835
+
836
+ def __repr__(self):
837
+ format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
838
+ format_string += ', resample={0}'.format(self.resample)
839
+ format_string += ', expand={0}'.format(self.expand)
840
+ if self.center is not None:
841
+ format_string += ', center={0}'.format(self.center)
842
+ format_string += ')'
843
+ return format_string
844
+
845
+
846
+ class RandomAffine(object):
847
+ """Random affine transformation of the image keeping center invariant
848
+ Args:
849
+ degrees (sequence or float or int): Range of degrees to select from.
850
+ If degrees is a number instead of sequence like (min, max), the range of degrees
851
+ will be (-degrees, +degrees). Set to 0 to desactivate rotations.
852
+ translate (tuple, optional): tuple of maximum absolute fraction for horizontal
853
+ and vertical translations. For example translate=(a, b), then horizontal shift
854
+ is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
855
+ randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
856
+ scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
857
+ randomly sampled from the range a <= scale <= b. Will keep original scale by default.
858
+ shear (sequence or float or int, optional): Range of degrees to select from.
859
+ If degrees is a number instead of sequence like (min, max), the range of degrees
860
+ will be (-degrees, +degrees). Will not apply shear by default
861
+ resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
862
+ An optional resampling filter.
863
+ See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
864
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
865
+ fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
866
+ """
867
+
868
+ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, resample_tg=False, fillcolor=0):
869
+ if isinstance(degrees, numbers.Number):
870
+ if degrees < 0:
871
+ raise ValueError("If degrees is a single number, it must be positive.")
872
+ self.degrees = (-degrees, degrees)
873
+ else:
874
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
875
+ "degrees should be a list or tuple and it must be of length 2."
876
+ self.degrees = degrees
877
+
878
+ if translate is not None:
879
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
880
+ "translate should be a list or tuple and it must be of length 2."
881
+ for t in translate:
882
+ if not (0.0 <= t <= 1.0):
883
+ raise ValueError("translation values should be between 0 and 1")
884
+ self.translate = translate
885
+
886
+ if scale is not None:
887
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
888
+ "scale should be a list or tuple and it must be of length 2."
889
+ for s in scale:
890
+ if s <= 0:
891
+ raise ValueError("scale values should be positive")
892
+ self.scale = scale
893
+
894
+ if shear is not None:
895
+ if isinstance(shear, numbers.Number):
896
+ if shear < 0:
897
+ raise ValueError("If shear is a single number, it must be positive.")
898
+ self.shear = (-shear, shear)
899
+ else:
900
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
901
+ "shear should be a list or tuple and it must be of length 2."
902
+ self.shear = shear
903
+ else:
904
+ self.shear = shear
905
+
906
+ self.resample = resample
907
+ self.resample_tg = resample_tg
908
+ self.fillcolor = fillcolor
909
+
910
+ @staticmethod
911
+ def get_params(degrees, translate, scale_ranges, shears, img_size):
912
+ """Get parameters for affine transformation
913
+ Returns:
914
+ sequence: params to be passed to the affine transformation
915
+ """
916
+ angle = random.uniform(degrees[0], degrees[1])
917
+ if translate is not None:
918
+ max_dx = translate[0] * img_size[0]
919
+ max_dy = translate[1] * img_size[1]
920
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
921
+ np.round(random.uniform(-max_dy, max_dy)))
922
+ else:
923
+ translations = (0, 0)
924
+
925
+ if scale_ranges is not None:
926
+ scale = random.uniform(scale_ranges[0], scale_ranges[1])
927
+ else:
928
+ scale = 1.0
929
+
930
+ if shears is not None:
931
+ shear = random.uniform(shears[0], shears[1])
932
+ else:
933
+ shear = 0.0
934
+
935
+ return angle, translations, scale, shear
936
+
937
+ def __call__(self, img, target=None):
938
+ """
939
+ img (PIL Image): Image to be rotated.
940
+ target (PIL Image): (optional) Target to be rotated
941
+ Returns:
942
+ PIL Image: Rotated image(s).
943
+ """
944
+ ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
945
+ if target is not None:
946
+ return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor), \
947
+ F.affine(target, *ret, resample=self.resample_tg, fillcolor=self.fillcolor)
948
+ # resample = False is by default nearest, appropriate for targets
949
+ return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
950
+
951
+ def __repr__(self):
952
+ s = '{name}(degrees={degrees}'
953
+ if self.translate is not None:
954
+ s += ', translate={translate}'
955
+ if self.scale is not None:
956
+ s += ', scale={scale}'
957
+ if self.shear is not None:
958
+ s += ', shear={shear}'
959
+ if self.resample > 0:
960
+ s += ', resample={resample}'
961
+ if self.fillcolor != 0:
962
+ s += ', fillcolor={fillcolor}'
963
+ s += ')'
964
+ d = dict(self.__dict__)
965
+ d['resample'] = _pil_interpolation_to_str[d['resample']]
966
+ return s.format(name=self.__class__.__name__, **d)
967
+
968
+
969
+ class Grayscale(object):
970
+ """Convert image to grayscale.
971
+ Args:
972
+ num_output_channels (int): (1 or 3) number of channels desired for output image
973
+ Returns:
974
+ PIL Image: Grayscale version of the input.
975
+ - If num_output_channels == 1 : returned image is single channel
976
+ - If num_output_channels == 3 : returned image is 3 channel with r == g == b
977
+ """
978
+
979
+ def __init__(self, num_output_channels=1):
980
+ self.num_output_channels = num_output_channels
981
+
982
+ def __call__(self, img, target = None):
983
+ """
984
+ Args:
985
+ img (PIL Image): Image to be converted to grayscale.
986
+ Returns:
987
+ PIL Image: Grayscale version of the image (and the unchanged target, if given).
988
+ """
989
+ if target is not None:
990
+ return F.to_grayscale(img, num_output_channels=self.num_output_channels), target
991
+ return F.to_grayscale(img, num_output_channels=self.num_output_channels)
992
+
993
+ def __repr__(self):
994
+ return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
995
+
996
+
997
+ class RandomGrayscale(object):
998
+ """Randomly convert image to grayscale with a probability of p (default 0.1).
999
+ Args:
1000
+ p (float): probability that image should be converted to grayscale.
1001
+ Returns:
1002
+ PIL Image: Grayscale version of the input image with probability p and unchanged
1003
+ with probability (1-p).
1004
+ - If input image is 1 channel: grayscale version is 1 channel
1005
+ - If input image is 3 channel: grayscale version is 3 channel with r == g == b
1006
+ """
1007
+
1008
+ def __init__(self, p=0.1):
1009
+ self.p = p
1010
+
1011
+ def __call__(self, img, target = None):
1012
+ """
1013
+ Args:
1014
+ img (PIL Image): Image to be converted to grayscale.
1015
+ Returns:
1016
+ PIL Image: Randomly grayscaled image.
1017
+ """
1018
+ num_output_channels = 1 if img.mode == 'L' else 3
1019
+ if random.random() < self.p:
1020
+ if target is not None:
1021
+ return F.to_grayscale(img, num_output_channels=num_output_channels), target
+ return F.to_grayscale(img, num_output_channels=num_output_channels)  # no target given: still convert the image
1022
+ if target is not None:
1023
+ return img, target
1024
+ return img
1025
+
1026
+ def __repr__(self):
1027
+ return self.__class__.__name__ + '(p={0})'.format(self.p)
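
A note on usage: these paired transforms differ from the stock torchvision ones mainly in that __call__ takes an optional target and applies one set of randomly drawn parameters to both inputs, which is what keeps a root image and its mask geometrically aligned. A minimal sketch, assuming the module is imported from this repo's layout (dependecies/segroot/paired_transforms_pt04.py) and a torchvision version matching the functional API used above (F.affine with resample/fillcolor):

from PIL import Image
from dependecies.segroot import paired_transforms_pt04 as p_tr

# one transform object, one set of random parameters per call
affine = p_tr.RandomAffine(degrees=(-10, 10), translate=(0.05, 0.05), scale=(0.9, 1.1))

img = Image.new("RGB", (256, 256))     # stand-in for a root scan tile
mask = Image.new("L", (256, 256))      # stand-in for its binary mask
img_aug, mask_aug = affine(img, mask)  # same angle/shift/scale applied to both
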
dependecies/segroot/paired_weight_vgg16.plk ADDED
Binary file (3.22 kB). View file
 
dependecies/segroot/predict_imgs.py ADDED
@@ -0,0 +1,121 @@
1
+ import argparse
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import torch
5
+ import torchvision
6
+ from skimage.morphology import erosion
7
+ import matplotlib.pyplot as plt
8
+ import time
9
+
10
+ from segroot.utils import init_weights
11
+ from segroot.dataloader import pad_pair_256, normalize
12
+ from segroot.model import SegRoot
13
+
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "--image", default="test.jpg", type=str, help="filename of one test image"
17
+ )
18
+ parser.add_argument(
19
+ "--thres", default=0.9, type=float, help="threshold of the final binarization"
20
+ )
21
+ parser.add_argument(
22
+ "--all", action="store_true", help="make prediction on all images in the folder"
23
+ )
24
+ parser.add_argument(
25
+ "--data_dir",
26
+ default="../data/prediction",
27
+ type=Path,
28
+ help="define the data directory",
29
+ )
30
+ parser.add_argument(
31
+ "--weights",
32
+ default="../weights/best_segnet-(8,5)-0.6441.pt",
33
+ type=Path,
34
+ help="path of pretrained weights",
35
+ )
36
+ parser.add_argument("--width", default=8, type=int, help="width of SegRoot")
37
+ parser.add_argument("--depth", default=5, type=int, help="depth of SegRoot")
38
+
39
+
40
+ def pad_256(img_path):
41
+ image = Image.open(img_path)
42
+ W, H = image.size
43
+ img, _ = pad_pair_256(image, image)
44
+ NW, NH = img.size
45
+ img = torchvision.transforms.ToTensor()(img)
46
+ img = normalize(img)
47
+ return img, (H, W, NH, NW)
48
+
49
+
50
+ def predict(model, test_img, device):
51
+ for p in model.parameters():
52
+ p.requires_grad = False
53
+
54
+ model.eval()
55
+ # test_img.shape = (3, 2304, 2560)
56
+ test_img = test_img.unsqueeze(0)
57
+ output = model(test_img)
58
+ # output.shape = (1, 1, 2304, 2560)
59
+ output = torch.squeeze(output)
60
+ torch.cuda.empty_cache()
61
+ return output
62
+
63
+
64
+ def predict_gen(model, img_path, thres, device, info):
65
+ img, dims = pad_256(img_path)
66
+ H, W, NH, NW = dims
67
+ img = img.to(device)
68
+ prediction = predict(model, img, device)
69
+ prediction[prediction >= thres] = 1.0
70
+ prediction[prediction < thres] = 0.0
71
+ if device.type == "cpu":
72
+ prediction = prediction.detach().numpy()
73
+ else:
74
+ prediction = prediction.cpu().detach().numpy()
75
+ prediction = erosion(prediction)
76
+ # reverse padding
77
+ prediction = prediction[
78
+ (NH - H) // 2 : (NH - H) // 2 + H, (NW - W) // 2 : (NW - W) // 2 + W
79
+ ]
80
+ save_path = img_path.parent / (
81
+ img_path.parts[-1].split(".jpg")[0] + "-pre-mask-segnet-({},5).jpg".format(info)
82
+ )
83
+ plt.imsave(save_path.as_posix(), prediction, cmap="gray")
84
+ print("{} generated!".format(save_path.parts[-1]))
85
+
86
+
87
+ if __name__ == "__main__":
88
+ args = parser.parse_args()
89
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
90
+ # define model
91
+ print("using segnet, width : {}, depth : {}".format(args.width, args.depth))
92
+ model = SegRoot(args.width, args.depth).to(device)
93
+ weights_path = args.weights
94
+
95
+ if device.type == "cpu":
96
+ print("load weights to cpu")
97
+ print(weights_path.as_posix())
98
+ model.load_state_dict(torch.load(weights_path.as_posix(), map_location="cpu"))
99
+ else:
100
+ print("load weights to gpu")
101
+ print(weights_path.as_posix())
102
+ model.load_state_dict(torch.load(weights_path.as_posix()))
103
+
104
+ # define the prediction's saving directory
105
+ pre_dir = Path("../data/prediction")
106
+ pre_dir.mkdir(parents=True, exist_ok=True)
107
+ if not args.all:
108
+ # load and pad image
109
+ img_path = pre_dir / args.image
110
+ start_time = time.time()
111
+ predict_gen(model, img_path, args.thres, device, 8)
112
+ end_time = time.time()
113
+ print("{:.4f}s for one image".format(end_time - start_time))
114
+ else:
115
+ img_paths = args.data_dir.glob("*.jpg")
116
+ for img_path in img_paths:
117
+ start_time = time.time()
118
+ predict_gen(model, img_path, args.thres, device, 8)
119
+ end_time = time.time()
120
+ print("{:.4f}s for one image".format(end_time - start_time))
121
+
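
For quick experiments the prediction routine can also be driven directly instead of through argparse. A hedged sketch using the same defaults as the script (the (8, 5) width/depth and the weights path come from the argument defaults above; it assumes the original SegRoot layout where the segroot package is importable, whereas in this Space it lives under dependecies/):

from pathlib import Path
import torch

from segroot.model import SegRoot
from predict_imgs import predict_gen

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegRoot(8, 5).to(device)
model.load_state_dict(torch.load("../weights/best_segnet-(8,5)-0.6441.pt", map_location=device))

# writes <name>-pre-mask-segnet-(8,5).jpg next to the input image
predict_gen(model, Path("../data/prediction/test.jpg"), thres=0.9, device=device, info=8)
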
dependecies/segroot/run_all_experiments.sh ADDED
@@ -0,0 +1,6 @@
1
+ #!/bin/sh
2
+ python -u train_segroot.py --width 2 > "log_SegRoot(2,5).txt"
3
+ python -u train_segroot.py --width 16 --depth 4 --lr 1e-3 > "log_SegRoot(16,4).txt"
4
+ python -u train_segroot.py --width 32 --depth 5 --lr 1e-4 --bs 32 > "log_SegRoot(32,5).txt"
5
+ python -u train_segroot.py --width 64 --depth 4 --lr 1e-4 --bs 16 > "log_SegRoot(64,4).txt"
6
+ python -u train_segroot.py --width 64 --depth 5 --lr 2e-5 --bs 8 --epochs 100 --verbose 2 > "log_SegRoot(64,5).txt"
dependecies/segroot/utils.py ADDED
@@ -0,0 +1,109 @@
1
+ import pickle
2
+ import torch
3
+ from torchvision import models
4
+ import random
5
+ import logging
6
+ import numpy as np
7
+ import json
8
+
9
+ def set_random_seed(seed):
10
+ random.seed(seed)
11
+ np.random.seed(seed)
12
+ torch.manual_seed(seed)
13
+ torch.cuda.manual_seed(seed)
14
+ torch.backends.cudnn.deterministic = True
15
+
16
+ def set_logger(log_path):
17
+ logger = logging.getLogger()
18
+ logger.setLevel(logging.INFO)
19
+
20
+ if not logger.handlers:
21
+ # Logging to a file
22
+ file_handler = logging.FileHandler(log_path)
23
+ file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
24
+ logger.addHandler(file_handler)
25
+
26
+ # Logging to console
27
+ stream_handler = logging.StreamHandler()
28
+ stream_handler.setFormatter(logging.Formatter('%(message)s'))
29
+ logger.addHandler(stream_handler)
30
+
31
+ def to_np(x):
32
+ return x.data.cpu().numpy()
33
+
34
+ def get_ids(length_dataset):
35
+ ids = list(range(length_dataset))
36
+
37
+ random.shuffle(ids)
38
+ train_split = round(0.6 * length_dataset)
39
+ t_v_split = (length_dataset - train_split) // 2
40
+ train_ids = ids[:train_split]
41
+ valid_ids = ids[train_split:train_split+t_v_split]
42
+ test_ids = ids[train_split+t_v_split:]
43
+ return train_ids, valid_ids, test_ids
44
+
45
+ def dice_score(y, y_pred, smooth=1.0, thres=0.9):
46
+ n = y.shape[0]
47
+ y = y.view(n, -1)
48
+ y_pred = y_pred.view(n, -1)
49
+ # y_pred_[y_pred>=thres] = 1.0
50
+ # y_pred_[y_pred<thres] = 0.0
51
+ num = 2 * torch.sum(y * y_pred, dim=1, keepdim=True) + smooth
52
+ den = torch.sum(y, dim=1, keepdim=True) + \
53
+ torch.sum(y_pred, dim=1, keepdim=True) + smooth
54
+ score = num / den
55
+ return score
56
+
57
+ def init_weights(m):
58
+ if isinstance(m, torch.nn.Conv2d):
59
+ torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
60
+ # torch.nn.init.constant_(m.bias, 0)
61
+ elif isinstance(m, torch.nn.BatchNorm2d):
62
+ torch.nn.init.constant_(m.weight, 1)
63
+
64
+ def load_vgg16(segnet):
65
+ vgg16 = models.vgg16_bn(pretrained=True)
66
+ with open('paired_weight_vgg16.plk', 'rb') as handle:
67
+ paired = pickle.load(handle)
68
+ segnet_p = dict(segnet.state_dict())
69
+ vgg16_p = vgg16.state_dict()
70
+
71
+ for k, v in paired.items():
72
+ for n, p in vgg16_p.items():
73
+ if n == v:
74
+ segnet_p[k].data.copy_(p.data)
75
+ segnet.load_state_dict(segnet_p)
76
+ return segnet
77
+
78
+ def train_one_epoch(model, train_iter, optimizer, device):
79
+ model.train()
80
+ for p in model.parameters():
81
+ p.requires_grad = True
82
+ for x, y in train_iter:
83
+ x, y = x.to(device), y.to(device)
84
+ bs = x.shape[0]
85
+ optimizer.zero_grad()
86
+ y_pred = model(x)
87
+ loss = 1 - dice_score(y, y_pred)
88
+ loss = torch.sum(loss) / bs
89
+ loss.backward()
90
+ optimizer.step()
91
+
92
+ def evaluate(model, dataset, device, thres=0.9):
93
+ model.eval()
94
+ torch.cuda.empty_cache()
95
+ num, den = 0, 0
96
+ # shutdown the autograd
97
+ with torch.no_grad():
98
+ for i in range(len(dataset)):
99
+ x, y = dataset[i]
100
+ x, y = x.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
101
+ y_pred = model(x)
102
+ y = y.cpu().detach().numpy()
103
+ y_pred = y_pred.cpu().detach().numpy()
104
+ y_pred[y_pred>=thres] = 1.0
105
+ y_pred[y_pred<thres] = 0.0
106
+ num += 2 * (y_pred * y).sum()
107
+ den += y_pred.sum() + y.sum()
108
+ torch.cuda.empty_cache()
109
+ return num / den
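
As a sanity check on dice_score: each sample is flattened and scored as (2 * sum(y * y_pred) + smooth) / (sum(y) + sum(y_pred) + smooth), one score per row. A small worked example, assuming the repo root as working directory so that dependecies is importable as a package:

import torch
from dependecies.segroot.utils import dice_score

y      = torch.tensor([[1., 1., 0., 1.]])
y_pred = torch.tensor([[1., 1., 1., 0.]])

# numerator = 2*2 + 1 = 5, denominator = 3 + 3 + 1 = 7, so the score is 5/7
print(dice_score(y, y_pred))   # tensor([[0.7143]])
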
example_1.jpg ADDED

Git LFS Details

  • SHA256: 7cc73230caa75bc91bda46f9158ef92d9d746b69c1d8eed5a7ba9374105b5d13
  • Pointer size: 131 Bytes
  • Size of remote file: 999 kB
example_2.jpg ADDED

Git LFS Details

  • SHA256: d05f1b7fef6657b3639e217c3e17f797a2cc0369c28efd66577214fbac6b68d2
  • Pointer size: 131 Bytes
  • Size of remote file: 871 kB
example_3.jpg ADDED

Git LFS Details

  • SHA256: 5c98df3cb589224d08c8d6fad3a309f25ab57575fa40f5c122cadf930cc5413f
  • Pointer size: 131 Bytes
  • Size of remote file: 722 kB
flagged/input_img/a7a20e8c8e03de5e007f/example_1.jpg ADDED

Git LFS Details

  • SHA256: 7cc73230caa75bc91bda46f9158ef92d9d746b69c1d8eed5a7ba9374105b5d13
  • Pointer size: 131 Bytes
  • Size of remote file: 999 kB
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
1
+ input_img,Model,output,flag,username,timestamp
2
+ flagged\input_img\a7a20e8c8e03de5e007f\example_1.jpg,segroot_finetuned,,,,2024-11-20 11:20:45.490192
logo.png ADDED

Git LFS Details

  • SHA256: 1c4e90f6cbc8f1b5395af452ff65aa7dd6bd155072ef619c014eb19a3760199c
  • Pointer size: 130 Bytes
  • Size of remote file: 48.2 kB
main.py ADDED
@@ -0,0 +1,188 @@
1
+ import gradio as gr
2
+
3
+ from processsors import RootSegmentor
4
+ from processsors import merge_images
5
+
6
+ from gradio_imageslider import ImageSlider
7
+
8
+ import cv2 as cv
9
+
10
+ PRELOAD_MODELS = False
11
+
12
+ if PRELOAD_MODELS:
13
+ root_segmentor = RootSegmentor()
14
+
15
+
16
+ def process(input_img, model_type):
17
+
18
+ print(model_type)
19
+
20
+ if PRELOAD_MODELS:
21
+ global root_segmentor
22
+ else:
23
+ root_segmentor = RootSegmentor(model_type)
24
+
25
+ result = root_segmentor.predict(input_img)
26
+
27
+ return result
28
+
29
+ def just_show(files, should_process, model_type):
30
+
31
+ imgs = []
32
+
33
+ img = merge_images(files)
34
+
35
+
36
+
37
+ imgs.append(img)
38
+
39
+ if should_process:
40
+ root_segmentor = RootSegmentor(model_type)
41
+
42
+ results = []
43
+
44
+ for file in files:
45
+ print(type(file))
46
+ print(file)
47
+ img = cv.imread(file)
48
+ img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
49
+ #imgs.append(img)
50
+
51
+ if should_process:
52
+
53
+ result = root_segmentor.predict(img)
54
+ results.append(result)
55
+ #imgs.append(results)
56
+
57
+ if should_process:
58
+ img_res = merge_images(results)
59
+ imgs.append(img_res)
60
+
61
+ return imgs
62
+
63
+ def slider_test(img1, img2):
64
+
65
+ return [img1,img2]
66
+
67
+ def download_result():
68
+
69
+ #print(filepath)
70
+ return
71
+
72
+
73
+ def gui():
74
+
75
+ with gr.Blocks(title="Root analysis", theme=gr.themes.Soft()) as demo:
76
+
77
+ big_block = gr.HTML("""
78
+
79
+ <style>
80
+ body {
81
+ font-family: Arial, sans-serif;
82
+ background-color: white;
83
+ margin: 0;
84
+ }
85
+
86
+ header {
87
+ display: flex;
88
+ justify-content: space-between;
89
+ align-items: center;
90
+ padding: 5px;
91
+ color: #fff;
92
+ }
93
+
94
+ hr {
95
+ border: 1px solid #ddd;
96
+ margin: 5px;
97
+ }
98
+
99
+ </style>
100
+
101
+ <header>
102
+ <div style="display: flex; align-items: center;">
103
+ <div style="text-align: left;">
104
+ <h1>Root Analysis</h1>
105
+ <p>Root segmentation using underground root scanner images.</p>
106
+ <h3>Tropical Forages Program</h3>
107
+ <p><b>Authors: </b>Andres Felipe Ruiz-Hurtado, Juan Andrés Cardoso Arango</p>
108
+ <p></p>
109
+ </div>
110
+ </div>
111
+ <div style="background-color: white; padding: 5px; border-radius: 15px; box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);">
112
+ <img src='file/logo.png' alt="Logo" width="200" height="100">
113
+ </div>
114
+ </header>
115
+
116
+ """)
117
+
118
+ #<iframe style="height:600px;width: 100%;" src="/file=slides.html" title="description"></iframe>
119
+
120
+
121
+ #<iframe style="height:600px;width: 100%;" src="https://revealjs.com/demo/?view" title="description"></iframe>
122
+
123
+ with gr.Tab("Single Image"):
124
+
125
+ model_selector = gr.Dropdown(
126
+ ["segroot_finetuned", "segroot", "segroot_finetuned_dec", "seg_model"], label="Model"
127
+ , info="AI model"
128
+ ,value="segroot_finetuned"
129
+ )
130
+
131
+ input_img=gr.Image(render=False)
132
+ output_img=gr.Image(render=False)
133
+
134
+ gr.Interface(
135
+ fn=process,
136
+ inputs=[input_img,model_selector],
137
+ outputs=output_img,
138
+ examples=[["example_1.jpg"],["example_2.jpg"],["example_3.jpg"]]
139
+ )
140
+
141
+ #examples = gr.Examples([["Chicago"], ["Little Rock"], ["San Francisco"]], textbox)
142
+
143
+ with gr.Row():
144
+ img_comp = ImageSlider(label="Root Segmentation")
145
+ with gr.Row():
146
+ compare_button = gr.Button("Compare")
147
+ compare_button.click(fn=slider_test, inputs=[input_img,output_img], outputs=img_comp, api_name="slider_test")
148
+
149
+ with gr.Tab("Multiple Images"):
150
+
151
+ #img_comp = ImageSlider(label="Blur image", type="pil")
152
+
153
+ gallery = gr.Gallery(show_fullscreen_button=True, render=False)
154
+
155
+ gr.Interface(
156
+ fn=just_show
157
+ ,inputs=[gr.File(file_count="multiple"),gr.Checkbox(label="Process", info="Check if you want to process"),model_selector]
158
+ ,outputs= gallery
159
+ , examples=[[["example_1.jpg", "example_2.jpg", "example_3.jpg"]]]
160
+ )
161
+
162
+ with gr.Tab("Compare"):
163
+
164
+ img_comp = ImageSlider(label="Root Segmentation")
165
+ img_comp.upload(inputs=img_comp, outputs=img_comp)
166
+
167
+
168
+ #d = gr.DownloadButton("Download the file")
169
+ #d.click(download_result, gallery, None)
170
+
171
+ # with gr.Row():
172
+ # img1=gr.Image()
173
+ # img2=gr.Image()
174
+ # with gr.Row():
175
+ # img_comp = ImageSlider(label="Blur image", type="pil")
176
+ # with gr.Row():
177
+ # compare_button = gr.Button("Compare")
178
+ # compare_button.click(fn=slider_test, inputs=[img1,img2], outputs=img_comp, api_name="slider_test")
179
+
180
+ # with gr.Group():
181
+ # img_comp = ImageSlider(label="Blur image", type="pil")
182
+ # #img1.upload(slider_test, inputs=[img1,img2], outputs=img_comp)
183
+ # gr.Interface(slider_test, inputs=[img1,img2], outputs=img_comp)
184
+
185
+ demo.launch(allowed_paths=["logo.png"], share=False)
186
+
187
+ if __name__ == "__main__":
188
+ gui()
models/best_segnet-(8,5)-0.6441.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dffa166609b5ab3241d1b175bffbf454377beaa4a7fb46bd74e38605e2f71d03
3
+ size 1611034
models/roots_model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48254d394d1b11fd9bcfd42bcc754bb1fba5a2052848f5ad70b259972bce4681
3
+ size 58655218
models/segroot-(8,5)_finetuned.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbb992086ea1900ef24e110d7b454126d6214ac8f14687348ec021cf860f4eca
3
+ size 1640578
processsors.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ import torchvision
3
+
4
+
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ from skimage.morphology import erosion
9
+
10
+ from dependecies.segroot.model import SegRoot
11
+ from dependecies.segroot.dataloader import pad_pair_256, normalize
12
+ from torchvision.transforms import v2 as transforms
13
+
14
+
15
+ import onnxruntime as ort
16
+ import cv2 as cv
17
+
18
+ import os
19
+
20
+ MODELS_PATH = r"./models"
21
+
22
+ def pad_256(img_path):
23
+ image = Image.open(img_path)
24
+ W, H = image.size
25
+ img, _ = pad_pair_256(image, image)
26
+ NW, NH = img.size
27
+ img = torchvision.transforms.ToTensor()(img)
28
+ img = normalize(img)
29
+ return img, (H, W, NH, NW)
30
+
31
+ def pad_256_np(np_img):
32
+ #image = Image.open(img_path)
33
+ image = Image.fromarray(np_img)
34
+ W, H = image.size
35
+ img, _ = pad_pair_256(image, image)
36
+ NW, NH = img.size
37
+ img = torchvision.transforms.ToTensor()(img)
38
+ img = normalize(img)
39
+ return img, (H, W, NH, NW)
40
+
41
+ def merge_images(files, path=""):
42
+
43
+ is_array = False
44
+ if type(files[0]) == np.ndarray:
45
+ is_array = True
46
+
47
+
48
+ final_img = []
49
+ resize_factor = 0.4
50
+ offset0 = 930
51
+ offset1 = 305
52
+ for index, file in enumerate(files):
53
+
54
+ if is_array:
55
+ img = file
56
+ else:
57
+ img = cv.imread(file)
58
+ img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
59
+ #img = cv.resize(img, (0,0), fx=resize_factor, fy=resize_factor)
60
+ img = cv.rotate(img, cv.ROTATE_90_CLOCKWISE)
61
+
62
+ if index == 0:
63
+ img = img[0:img.shape[0]-offset0,0:img.shape[1]]
64
+ final_img = img
65
+ elif index == len(files)-1:
66
+ final_img = cv.vconcat([final_img, img])
67
+ else:
68
+ #final_img = np.concatenate((final_img, img), axis=1)
69
+ img = img[0:img.shape[0]-offset1,0:img.shape[1]]
70
+ final_img = cv.vconcat([final_img, img])
71
+
72
+ final_img = cv.resize(final_img, (0,0), fx=resize_factor, fy=resize_factor)
73
+
74
+ #cv.imwrite(path, final_img)
75
+ print(final_img.shape)
76
+
77
+ return final_img
78
+
79
+ class RootSegmentor():
80
+
81
+ def __init__(self, model_type):
82
+
83
+
84
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85
+
86
+ self.model_type = model_type
87
+
88
+ if model_type != "seg_model":
89
+ self.initialize()
90
+
91
+ return
92
+
93
+ def initialize(self):
94
+
95
+ width = 8
96
+ depth = 5
97
+
98
+ if self.model_type == "segroot":
99
+ #weights_path = os.path.join(r"D:\local_mydev\roots_finetuning\SegRoot0\weights\best_segnet-(8,5)-0.6441.pt"
100
+ #weights_path = r"D:\local_mydev\SegRoot\weights\best_segnet-(8,5)-0.6441.pt"
101
+ #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\best_segnet-(8,5)-0.6441.pt"
102
+ #weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\best_segnet-(8,5)-0.6441.pt")
103
+ weights_path = os.path.join(MODELS_PATH, r"best_segnet-(8,5)-0.6441.pt")
104
+ elif self.model_type == "segroot_finetuned":
105
+ #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned.pt"
106
+ #weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned.pt")
107
+ weights_path = os.path.join(MODELS_PATH, r"segroot-(8,5)_finetuned.pt")
108
+ elif self.model_type == "segroot_finetuned_dec":
109
+ #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_dec_full.pt"
110
+ #weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_clas.pt"
111
+ #weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_clas.pt")
112
+ weights_path = os.path.join(MODELS_PATH, r"segroot-(8,5)_finetuned.pt")
113
+
114
+ self.model = SegRoot(width, depth).to(self.device)
115
+
116
+ if self.device.type == "cpu":
117
+ print("load weights to cpu")
118
+ #print(weights_path.as_posix())
119
+ self.model.load_state_dict(torch.load(weights_path, map_location="cpu"))
120
+ else:
121
+ print("load weights to gpu")
122
+ #print(weights_path.as_posix())
123
+ self.model.load_state_dict(torch.load(weights_path))
124
+
125
+ for p in self.model.parameters():
126
+ p.requires_grad = False
127
+
128
+ self.model.eval()
129
+
130
+ return
131
+
132
+ def predict(self, img_path):
133
+
134
+ if self.model_type == "seg_model":
135
+
136
+ print(str(type(img_path)))
137
+
138
+ if type(img_path) == np.ndarray:
139
+ img = img_path
140
+ else:
141
+ img = cv.imread(img_path)
142
+ img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
143
+
144
+ weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\roots_model.onnx"
145
+ weights_path = os.path.join(MODELS_PATH,"roots_model.onnx")
146
+ ort_sess = ort.InferenceSession(weights_path
147
+ ,providers=ort.get_available_providers()
148
+ )
149
+
150
+ dim = img.shape
151
+
152
+ transforms_list = []
153
+ transforms_list.append(transforms.ToTensor())
154
+ transforms_list.append(transforms.Resize((800,800)))
155
+ #transforms_list.append(transforms.CenterCrop(800))
156
+ #transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
157
+
158
+ apply_t = transforms.Compose(transforms_list)
159
+
160
+ img = apply_t(img)
161
+
162
+ outputs = ort_sess.run(None, {'input': [img.numpy()]})
163
+
164
+ print(outputs)
165
+
166
+ #np_res = outputs[0][0]
167
+
168
+ output_image = outputs[0][:,:,1]
169
+ final = cv.resize(output_image, (dim[1], dim[0]))  # cv.resize expects (width, height)
170
+
171
+ return final
172
+
173
+ else:
174
+
175
+ thres = 0.9
176
+
177
+ print(str(type(img_path)))
178
+
179
+ if type(img_path) == np.ndarray:
180
+ img, dims = pad_256_np(img_path)
181
+ else:
182
+ img, dims = pad_256(img_path)
183
+
184
+ H, W, NH, NW = dims
185
+
186
+ img = img.to(self.device)
187
+
188
+ img = img.unsqueeze(0)
189
+ output = self.model(img)
190
+
191
+ output = torch.squeeze(output)
192
+ torch.cuda.empty_cache()
193
+
194
+ prediction = output
195
+
196
+ prediction[prediction >= thres] = 1.0
197
+ prediction[prediction < thres] = 0.0
198
+
199
+ if self.device.type == "cpu":
200
+ prediction = prediction.detach().numpy()
201
+ else:
202
+ prediction = prediction.cpu().detach().numpy()
203
+
204
+ prediction = erosion(prediction)
205
+ # reverse padding
206
+ prediction = prediction[
207
+ (NH - H) // 2 : (NH - H) // 2 + H, (NW - W) // 2 : (NW - W) // 2 + W
208
+ ]
209
+
210
+ return prediction
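
For the SegRoot-based entries in the model dropdown, RootSegmentor wraps the same pad / threshold / erosion pipeline as predict_imgs.py. A hedged usage sketch, assuming the repo root as working directory and the checkpoints present under ./models as listed above:

import cv2 as cv
from processsors import RootSegmentor

seg = RootSegmentor("segroot_finetuned")   # loads ./models/segroot-(8,5)_finetuned.pt

img = cv.cvtColor(cv.imread("example_1.jpg"), cv.COLOR_BGR2RGB)
mask = seg.predict(img)                    # 2-D float array of 0.0/1.0, cropped back to the input size

cv.imwrite("example_1_mask.png", (mask * 255).astype("uint8"))
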
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ matplotlib
2
+ numpy
3
+ opencv-python
4
+ pillow
5
+ scikit-image
6
+ scikit-learn
7
+ torch
8
+ torchvision
9
+ gradio
10
+ onnxruntime
11
+ rasterio