Andres Felipe Ruiz-Hurtado committed
Commit: bc97962
1 Parent(s): c16482d
initial
- .gitattributes +2 -0
- README.md +1 -1
- __pycache__/processsors.cpython-312.pyc +0 -0
- dependecies/__init__.py +0 -0
- dependecies/__pycache__/__init__.cpython-312.pyc +0 -0
- dependecies/segroot/__init__.py +0 -0
- dependecies/segroot/__pycache__/__init__.cpython-312.pyc +0 -0
- dependecies/segroot/__pycache__/dataloader.cpython-312.pyc +0 -0
- dependecies/segroot/__pycache__/model.cpython-312.pyc +0 -0
- dependecies/segroot/__pycache__/paired_transforms_pt04.cpython-312.pyc +0 -0
- dependecies/segroot/binarize_crop.py +72 -0
- dependecies/segroot/dataloader.py +151 -0
- dependecies/segroot/main_segroot.py +112 -0
- dependecies/segroot/model.py +124 -0
- dependecies/segroot/paired_transforms_pt04.py +1027 -0
- dependecies/segroot/paired_weight_vgg16.plk +0 -0
- dependecies/segroot/predict_imgs.py +121 -0
- dependecies/segroot/run_all_experiments.sh +6 -0
- dependecies/segroot/utils.py +109 -0
- example_1.jpg +3 -0
- example_2.jpg +3 -0
- example_3.jpg +3 -0
- flagged/input_img/a7a20e8c8e03de5e007f/example_1.jpg +3 -0
- flagged/log.csv +2 -0
- logo.png +3 -0
- main.py +188 -0
- models/best_segnet-(8,5)-0.6441.pt +3 -0
- models/roots_model.onnx +3 -0
- models/segroot-(8,5)_finetuned.pt +3 -0
- processsors.py +210 -0
- requirements.txt +11 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
 colorTo: green
 sdk: gradio
 sdk_version: 5.6.0
-app_file:
+app_file: main.py
 pinned: false
 license: apache-2.0
 short_description: Root analysis using deep learning
__pycache__/processsors.cpython-312.pyc
ADDED
Binary file (8.33 kB)

dependecies/__init__.py
ADDED
File without changes

dependecies/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (160 Bytes)

dependecies/segroot/__init__.py
ADDED
File without changes

dependecies/segroot/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (168 Bytes)

dependecies/segroot/__pycache__/dataloader.cpython-312.pyc
ADDED
Binary file (9.13 kB)

dependecies/segroot/__pycache__/model.cpython-312.pyc
ADDED
Binary file (7.64 kB)

dependecies/segroot/__pycache__/paired_transforms_pt04.cpython-312.pyc
ADDED
Binary file (56.8 kB)
dependecies/segroot/binarize_crop.py
ADDED
@@ -0,0 +1,72 @@
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import skimage.io as io
from skimage.morphology import dilation
import pickle
import numpy as np
import argparse
from dataloader import pad_pair_256

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dilate",
    default=0,
    type=int,
    help="dilation degree of masks")

args = parser.parse_args()
data_dir = Path('../data/data_raw')
mask_dir = Path('../data/masks')
mask_dir.mkdir(exist_ok=True, parents=True)

imgs = sorted(list(data_dir.glob('*Untitled.jpg')))
print('original images count : ', len(imgs))
masks = sorted(list(data_dir.glob('*Untitled-mask.jpg')))
print('original masks count : ', len(masks))

# generate binary masks for every annotated images
for m in masks:
    mask = io.imread(m.as_posix(), as_gray=True)
    # binarize
    mask[mask > 0.5] = 1.0
    mask[mask <= 0.5] = 0.0
    for i in range(args.dilate):
        mask = dilation(mask)
    print('binary masks dilated !!!')
    plt.imsave((mask_dir / m.parts[-1]).as_posix(), mask, cmap='gray')
print('binary masks generated !!!')

# save idx info in a dictionary
info_dict = {k: v.parts[-1] for k, v in enumerate(imgs)}
with open('../data/info.pkl', 'wb') as handle:
    pickle.dump(info_dict, handle)
print('index info saved!!!')

# crop the padded image to generate 256*256 subimages
new_masks = sorted(list(mask_dir.glob('*Untitled-mask.jpg')))
print('new_mask length : ', len(new_masks))

subimg_path = Path('../data/subimg')
subimg_path.mkdir(exist_ok=True, parents=True)
submask_path = Path('../data/submask')
submask_path.mkdir(exist_ok=True, parents=True)

for idx, (mask_path, img_path) in enumerate(zip(new_masks, imgs)):
    mask = Image.open(mask_path)
    img = Image.open(img_path)
    new_img, new_mask = pad_pair_256(img, mask)
    new_img, new_mask = np.array(new_img), np.array(new_mask)
    # padded shape (2560, 2304)
    w, h, _ = new_img.shape
    for i in range(int(w/256)):
        for j in range(int(h/256)):
            subimg = new_img[i*256:(i+1)*256, j*256:(j+1)*256, :]
            subimg_fn = '{}/{}-{}-{}.png'.format(
                Path('../data/subimg').as_posix(), idx, i, j)
            plt.imsave(subimg_fn, subimg)
            submask_fn = '{}/{}-{}-{}.png'.format(
                Path('../data/submask').as_posix(), idx, i, j)
            submask = new_mask[i*256:(i+1)*256, j*256:(j+1)*256]
            plt.imsave(submask_fn, submask, cmap='gray')
    print('No.{} image & mask cropped!!!'.format(idx))
dependecies/segroot/dataloader.py
ADDED
@@ -0,0 +1,151 @@
import os
import itertools
import pickle
import torch
from torchvision import models
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler

import dependecies.segroot.paired_transforms_pt04 as p_tr

train_transform = p_tr.Compose([
    p_tr.RandomCrop(256),
    p_tr.RandomRotation((90, 90)),
    p_tr.RandomRotation((180, 180)),
    p_tr.RandomRotation((270, 270)),
    p_tr.RandomHorizontalFlip(),
    p_tr.RandomVerticalFlip(),
    p_tr.ToTensor()
])

# normalize = p_tr.Normalize([0.35042979, 0.44016893, 0.2340332],
#                            [0.20999724, 0.25972678, 0.13885915])
normalize = p_tr.Normalize([0.5, 0.5, 0.5],
                           [0.5, 0.5, 0.5])


def pad_pair_256(image, gt):
    w, h = image.size
    new_w = ((w - 1) // 256 + 1) * 256
    new_h = ((h - 1) // 256 + 1) * 256
    new_image = Image.new("RGB", (new_w, new_h))
    new_image.paste(image, ((new_w - w) // 2, (new_h - h) // 2))
    new_gt = Image.new("L", (new_w, new_h))
    new_gt.paste(gt, ((new_w - w) // 2, (new_h - h) // 2))
    return new_image, new_gt


def convert_png(image, gt):
    new_image = Image.new('RGB', (256, 256))
    new_image.paste(image)
    new_gt = Image.new('L', (256, 256))
    new_gt.paste(gt)
    return new_image, new_gt


def get_paths(root_dir, im_ids):
    imgs = []
    for i in im_ids:
        tmp = Path(root_dir).glob('*{}-*.png'.format(i))
        tmp = [p for p in tmp if p.parts[-1].startswith(str(i)+'-')]
        imgs = imgs + list(tmp)
    return imgs


class LoopSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return itertools.cycle(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class TrainDataset(Dataset):
    def __init__(self, im_ids):
        self.root_dir = '../data/data_raw'
        self.mask_dir = '../data/mask'
        self.im_ids = im_ids
        with open('../data/info.pkl', 'rb') as handle:
            self.info = pickle.load(handle)
        self.fns = [self.info[im_id] for im_id in im_ids]

    def __getitem__(self, index):
        im_fn = self.fns[index]
        im_name = os.path.join(self.root_dir, im_fn)
        gt_name = os.path.join(
            self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = pad_pair_256(image, gt)

        image, gt = train_transform(image, gt)
        image = normalize(image)

        return image, gt

    def __len__(self):
        return len(self.im_ids)


class StaticTrainDataset(Dataset):
    def __init__(self, im_ids):
        self.subimgs = sorted(get_paths('../data/subimg', im_ids))
        self.submasks = sorted(get_paths('../data/submask', im_ids))
        self.im_ids = im_ids

    def __getitem__(self, index):
        im_name = self.subimgs[index]
        gt_name = self.submasks[index]
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = convert_png(image, gt)

        image, gt = train_transform(image, gt)
        image = normalize(image)

        return image, gt

    def __len__(self):
        return len(self.im_ids * 90)


class TrainDataLoader():
    def __init__(self, dataset, batch_size, num_workers=0):
        self.dataset = dataset
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size,
                                     num_workers=num_workers, sampler=LoopSampler(self.dataset))
        self.dl = iter(self.dataloader)

    def next_batch(self):
        image, gt = next(self.dl)
        return image, gt


class TestDataset(Dataset):
    def __init__(self, im_ids):
        self.root_dir = '../data/data_raw'
        self.mask_dir = '../data/masks'
        with open('../data/info.pkl', 'rb') as handle:
            self.info = pickle.load(handle)
        self.im_ids = im_ids
        self.fns = [self.info[im_id] for im_id in im_ids]

    def __getitem__(self, index):
        im_fn = self.fns[index]
        im_name = os.path.join(self.root_dir, im_fn)
        gt_name = os.path.join(
            self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = pad_pair_256(image, gt)

        image, gt = p_tr.ToTensor()(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        return len(self.fns)
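The helper pad_pair_256 above is the glue between the raw scans and the 256x256 tiles used everywhere else: it pastes an image/mask pair onto a canvas whose sides are rounded up to the next multiple of 256 (binarize_crop.py notes a padded shape of (2560, 2304), i.e. a 10x9 grid of tiles, which is where StaticTrainDataset's factor of 90 sub-images per id comes from). A minimal sketch of calling it, assuming the Space root is on PYTHONPATH and using one of the bundled example images; the all-zero mask here is only a placeholder:

from PIL import Image
from dependecies.segroot.dataloader import pad_pair_256

img = Image.open("example_1.jpg").convert("RGB")  # bundled example scan
gt = Image.new("L", img.size)                     # placeholder all-zero mask
padded_img, padded_gt = pad_pair_256(img, gt)
# Both outputs now have sides that are multiples of 256 and can be tiled
# into 256x256 sub-images exactly as binarize_crop.py does.
print(padded_img.size, padded_gt.size)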
dependecies/segroot/main_segroot.py
ADDED
@@ -0,0 +1,112 @@
import torch
from torch.utils.data import DataLoader
import numpy as np
import random
from tqdm import tqdm
import argparse

from model import SegRoot
from dataloader import StaticTrainDataset, TestDataset, TrainDataset, LoopSampler
from utils import (
    dice_score,
    init_weights,
    evaluate,
    get_ids,
    load_vgg16,
    set_random_seed,
)


parser = argparse.ArgumentParser()
parser.add_argument("--seed", default=42, type=int, help="set random seed")
parser.add_argument("--width", default=8, type=int, help="width of SegRoot")
parser.add_argument("--depth", default=5, type=int, help="depth of SegRoot")
parser.add_argument("--bs", default=64, type=int, help="batch size of dataloaders")
parser.add_argument("--lr", default=1e-2, type=float, help="learning rate")
parser.add_argument("--epochs", default=200, type=int, help="max epochs of training")
parser.add_argument(
    "--verbose", default=5, type=int, help="intervals to save and validate model"
)
parser.add_argument(
    "--dynamic", action="store_true", help="use dynamic sub-images during training"
)


def train_one_epoch(model, train_iter, optimizer, device):
    model.train()
    for p in model.parameters():
        p.requires_grad = True
    for x, y in train_iter:
        x, y = x.to(device), y.to(device)
        bs = x.shape[0]
        optimizer.zero_grad()
        y_pred = model(x)
        loss = 1 - dice_score(y, y_pred)
        loss = torch.sum(loss) / bs
        loss.backward()
        optimizer.step()


if __name__ == "__main__":
    args = parser.parse_args()
    seed = args.seed
    bs = args.bs
    lr = args.lr
    width = args.width
    depth = args.depth
    epochs = args.epochs
    verbose = args.verbose

    # set random seed
    set_random_seed(seed)
    # define the device for training
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # get training ids
    train_ids, valid_ids, test_ids = get_ids(65)
    # define dataloaders
    if args.dynamic:
        train_data = TrainDataset(train_ids)
        train_iter = DataLoader(
            train_data, batch_size=bs, num_workers=6, sampler=LoopSampler
        )
    else:
        train_data = StaticTrainDataset(train_ids)
        train_iter = DataLoader(train_data, batch_size=bs, num_workers=6, shuffle=True)

    train_tdata = TestDataset(train_ids)
    valid_tdata = TestDataset(valid_ids)
    test_tdata = TestDataset(test_ids)

    # define model
    model = SegRoot(width, depth).to(device)
    model = model.apply(init_weights)

    # define optimizer and lr_scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, verbose=True, patience=5
    )

    print(f"Start training SegRoot-({width},{depth}))......")
    print(f"Random seed is {seed}, batch size is {bs}......")
    print(f"learning rate is {lr}, max epochs is {epochs}......")
    best_valid = float("-inf")
    for epoch in tqdm(range(epochs)):
        train_one_epoch(model, train_iter, optimizer, device)
        if epoch % verbose == 0:
            train_dice = evaluate(model, train_tdata, device)
            valid_dice = evaluate(model, valid_tdata, device)
            scheduler.step(valid_dice)
            print(
                "Epoch {:05d}, train dice: {:.4f}, valid dice: {:.4f}".format(
                    epoch, train_dice, valid_dice
                )
            )
            if valid_dice > best_valid:
                best_valid = valid_dice
                test_dice = evaluate(model, test_tdata, device)
                print("New best validation, test dice: {:.4f}".format(test_dice))
                torch.save(
                    model.state_dict(),
                    f"../weights/best_segroot-({args.width},{args.depth}).pt",
                )
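train_one_epoch above minimizes 1 - dice_score(y, y_pred) averaged over the batch, but utils.py itself is not part of the listing shown in this commit view. As a point of reference only, a generic per-sample soft Dice of the usual form looks like the sketch below; this is an assumption about the interface (one score per sample), not necessarily what utils.dice_score implements:

import torch

def soft_dice(y_true, y_pred, eps=1e-7):
    # Generic soft Dice for sigmoid outputs in [0, 1]; a stand-in, not
    # necessarily identical to utils.dice_score used by main_segroot.py.
    y_true = y_true.reshape(y_true.shape[0], -1)
    y_pred = y_pred.reshape(y_pred.shape[0], -1)
    inter = (y_true * y_pred).sum(dim=1)
    denom = y_true.sum(dim=1) + y_pred.sum(dim=1)
    return (2 * inter + eps) / (denom + eps)

# Mirrors the loss computation in train_one_epoch:
y = torch.randint(0, 2, (4, 1, 256, 256)).float()
y_hat = torch.rand(4, 1, 256, 256)
loss = torch.sum(1 - soft_dice(y, y_hat)) / y.shape[0]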
dependecies/segroot/model.py
ADDED
@@ -0,0 +1,124 @@
import torch
import torchvision
from torch import nn
from torch.nn import functional as F


class ConvBNRelu(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvBNRelu, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=True)
        self.bn = nn.BatchNorm2d(out_ch)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        # print(x.shape)
        return x


class FirstBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(FirstBlock, self).__init__()
        self.conv1 = ConvBNRelu(in_ch, out_ch)
        self.conv2 = ConvBNRelu(out_ch, out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DownBlock, self).__init__()
        self.conv1 = ConvBNRelu(in_ch, out_ch)
        self.conv2 = ConvBNRelu(out_ch, out_ch)

    def forward(self, x):
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class Encoder(nn.Module):
    def __init__(self, in_ch, out_ch, block_num=2):
        super(Encoder, self).__init__()
        layers = []
        layers += [ConvBNRelu(in_ch, out_ch)]
        for i in range(block_num-1):
            layers += [ConvBNRelu(out_ch, out_ch)]
        # layers += [nn.Dropout2d(0.5)]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        return x, indices


class Decoder(nn.Module):
    def __init__(self, in_ch, out_ch, block_num=2):
        super(Decoder, self).__init__()
        layers = []
        layers += [ConvBNRelu(in_ch, out_ch)]
        for i in range(block_num-1):
            layers += [ConvBNRelu(out_ch, out_ch)]
        # layers += [nn.Dropout2d(0.5)]
        self.features = nn.Sequential(*layers)

    def forward(self, x, indices):
        x = F.max_unpool2d(x, indices=indices, kernel_size=2, stride=2)
        x = self.features(x)
        return x


class SegRoot(nn.Module):
    def __init__(self, width=8, depth=5, num_classes=2):
        super(SegRoot, self).__init__()
        chs = []
        for i in range(depth-1):
            chs.append(width * (2**i))
        chs.append(chs[-1])
        self.e_ch_info = [3, ] + chs
        self.e_bl_info = [2, 2, 3, 3]
        for _ in range(depth - 4):
            self.e_bl_info += [3, ]
        self.d_ch_info = chs[::-1] + [4, ]
        self.d_bl_info = self.e_bl_info[::-1]
        # using same setup with Unet
        if width == 4:
            self.e_ch_info = [3, 4, 8, 16, 32, 64]
            self.d_ch_info = [64, 32, 16, 8, 4, 4]
        self.num_classes = num_classes
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        for i in range(1, len(self.e_ch_info)):
            self.encoders.append(Encoder(self.e_ch_info[i-1], self.e_ch_info[i], self.e_bl_info[i-1]))
            self.decoders.append(Decoder(self.d_ch_info[i-1], self.d_ch_info[i], self.d_bl_info[i-1]))

        # self.classifier = nn.Conv2d(self.d_ch_info[-1], num_classes, kernel_size=3, padding=1)
        self.classifier = nn.Conv2d(self.d_ch_info[-1], 1, 1)

    def forward(self, x):
        indices = []
        bs = x.shape[0]
        for i in range(len(self.e_bl_info)):
            x, ind = self.encoders[i](x)
            indices.append(ind)

        indices = indices[::-1]
        for i in range(len(self.e_bl_info)):
            x = self.decoders[i](x, indices[i])

        x = self.classifier(x)
        # x = F.softmax(x,dim=1)
        x = torch.sigmoid(x)
        return x


if __name__ == '__main__':
    x = torch.zeros((1, 3, 256, 256))
    net = SegRoot(8, 5)
    print(net(x).shape)
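For orientation, the commit also ships trained weights (models/segroot-(8,5)_finetuned.pt in the file list above). A minimal inference sketch, assuming that file stores a plain state_dict for the width=8, depth=5 configuration (main.py/processsors.py may load it differently):

import torch
from dependecies.segroot.model import SegRoot

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegRoot(width=8, depth=5).to(device)
# Assumption: the checkpoint is a plain state_dict for this configuration.
state = torch.load("models/segroot-(8,5)_finetuned.pt", map_location=device)
model.load_state_dict(state)
model.eval()

with torch.no_grad():
    tile = torch.zeros((1, 3, 256, 256), device=device)  # one normalized 256x256 RGB tile
    prob = model(tile)                                    # sigmoid root-probability map
print(prob.shape)  # torch.Size([1, 1, 256, 256])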
dependecies/segroot/paired_transforms_pt04.py
ADDED
@@ -0,0 +1,1027 @@
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

from torchvision.transforms import functional as F

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
}


class Compose(object):
    """Composes several transforms together.
    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target=None):
        if target is not None:
            for t in self.transforms:
                img, target = t(img, target)
            return img, target

        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic, pic2=None):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
            pic2 (PIL Image): (optional) Second image to be converted also.
        Returns:
            Tensor(s): Converted image(s).
        """
        if pic2 is not None:
            return F.to_tensor(pic), F.to_tensor(pic2)
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.
    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.
    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e,
            ``int``, ``float``, ``short``).
    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
    """
    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic, pic2=None):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        Returns:
            PIL Image: Image converted to PIL Image.
        """
        if pic2 is not None:
            return F.to_pil_image(pic), F.to_pil_image(pic2)
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        if self.mode is not None:
            format_string += 'mode={0}'.format(self.mode)
        format_string += ')'
        return format_string


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class Resize(object):
    """Resize the input PIL Image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
        interpolation_tg (int, optional): Desired interpolation for target. Default is
            ``PIL.Image.NEAREST``
    """

    def __init__(self, size, interpolation=Image.BILINEAR, interpolation_tg=Image.NEAREST):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation
        self.interpolation_tg = interpolation_tg

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Image to be scaled.
            target (PIL Image): (optional) Target to be scaled
        Returns:
            PIL Image: Rescaled image(s).
        """
        if target is not None:
            return F.resize(img, self.size, self.interpolation), F.resize(target, self.size, self.interpolation_tg)
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


class Scale(Resize):
    """
    Note: This transform is deprecated in favor of Resize.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)


class CenterCrop(object):
    """Crops the given PIL Image at the center.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Image to be cropped.
            target (PIL Image): (optional) Target to be cropped
        Returns:
            PIL Image: Cropped image(s).
        """
        if target is not None:
            return F.center_crop(img, self.size), F.center_crop(target, self.size)
        return F.center_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.
    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
            constant: pads with a constant value, this value is specified with fill
            edge: pads with the last value at the edge of the image
            reflect: pads with reflection of image (without repeating the last value on the edge)
                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]
            symmetric: pads with reflection of image (repeating the last value on the edge)
                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, padding, fill=0, padding_mode='constant'):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self):
        return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
            format(self.padding, self.fill, self.padding_mode)


class Lambda(object):
    """Apply a user-defined lambda as a transform.
    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class RandomTransforms(object):
    """Base class for a list of transformations with randomness
    Args:
        transforms (list or tuple): list of transformations
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class RandomApply(RandomTransforms):
    """Apply randomly a list of transformations with a given probability
    Args:
        transforms (list or tuple): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, img):
        if self.p < random.random():
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += '\n    p={}'.format(self.p)
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order
    """
    def __call__(self, img):
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list
    """
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.
        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Image to be cropped.
            target (PIL Image): (optional) Target to be cropped
        Returns:
            PIL Images: Cropped image(s).
        """
        if self.padding > 0:
            img = F.pad(img, self.padding)
            if target is not None:
                target = F.pad(target, self.padding)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
            if target is not None:
                target = F.pad(target, (int((1 + self.size[1] - target.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
            if target is not None:
                target = F.pad(target, (0, int((1 + self.size[0] - target.size[1]) / 2)))
        i, j, h, w = self.get_params(img, self.size)

        if target is not None:
            return F.crop(img, i, j, h, w), F.crop(target, i, j, h, w)
        else:
            return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Image to be flipped.
            target (PIL Image): (optional) Target to be flipped
        Returns:
            PIL Image: Randomly flipped image(s).
        """
        if random.random() < self.p:
            if target is not None:
                return F.hflip(img), F.hflip(target)
            else:
                return F.hflip(img)

        if target is not None:
            return img, target
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Image to be flipped.
            target (PIL Image): (optional) Target to be flipped
        Returns:
            PIL Image: Randomly flipped image(s).
        """
        if random.random() < self.p:
            if target is not None:
                return F.vflip(img), F.vflip(target)
            else:
                return F.vflip(img)

        if target is not None:
            return img, target
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.
    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.
    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation=Image.BILINEAR, interpolation_tg=Image.NEAREST):
        self.size = (size, size)
        self.interpolation = interpolation
        self.interpolation_tg = interpolation_tg
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.
        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.
            target (PIL Image): (optional) Target to be cropped and resized.
        Returns:
            PIL Image: Randomly cropped and resized image(s).
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if target is not None:
            return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
                   F.resized_crop(target, i, j, h, w, self.size, self.interpolation_tg)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


class RandomSizedCrop(RandomResizedCrop):
    """
    Note: This transform is deprecated in favor of RandomResizedCrop.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)


class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop
    .. Note::
        This transform returns a tuple of images and there may be a mismatch in the number of
        inputs and targets your Dataset returns. See below for an example of how to deal with
        this.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
    Example:
        >>> transform = Compose([
        >>>     FiveCrop(size), # this is a list of PIL Images
        >>>     Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
        >>> ])
        >>> #In your test loop you can do the following:
        >>> input, target = batch # input is a 5d tensor, target is 2d
        >>> bs, ncrops, c, h, w = input.size()
        >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
        >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        self.size = size
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img, target=None):
        if target is not None:
            return F.five_crop(img, self.size), F.five_crop(target, self.size)
        return F.five_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default)
    .. Note::
        This transform returns a tuple of images and there may be a mismatch in the number of
        inputs and targets your Dataset returns. See below for an example of how to deal with
        this.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip(bool): Use vertical flipping instead of horizontal
    Example:
        >>> transform = Compose([
        >>>     TenCrop(size), # this is a list of PIL Images
        >>>     Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
        >>> ])
        >>> #In your test loop you can do the following:
        >>> input, target = batch # input is a 5d tensor, target is 2d
        >>> bs, ncrops, c, h, w = input.size()
        >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
        >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        self.size = size
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img, target=None):
        if target is not None:
            return F.ten_crop(img, self.size), F.ten_crop(target, self.size)
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)


class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix computed
    offline.
    Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
    product with the transformation matrix and reshape the tensor to its
    original shape.
    Applications:
    - whitening: zero-center the data, compute the data covariance matrix
                 [D x D] with np.dot(X.T, X), perform SVD on this matrix and
                 pass it as transformation_matrix.
    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
    """

    def __init__(self, transformation_matrix):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix

    def __call__(self, tensor, target_tensor=None):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
        Returns:
            Tensor: Transformed image.
        """
        if target_tensor is not None:
            raise NotImplementedError("LinearTransformation not implemented for tensor pairs.")
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        flat_tensor = tensor.view(1, -1)
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        tensor = transformed_tensor.view(tensor.size())
        return tensor

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
        return format_string


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.
        Arguments are same as that of __init__.
        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img, target=None):
        """
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)

        if target is not None:
            return transform(img), target
        return transform(img)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string


class RandomRotation(object):
    """Rotate the image by angle.
    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter.
            See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
    """

    def __init__(self, degrees, resample=False, resample_tg=False, expand=False, center=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.resample_tg = resample_tg
        self.expand = expand
        self.center = center

    @staticmethod
    def get_params(degrees):
        """Get parameters for ``rotate`` for a random rotation.
        Returns:
            sequence: params to be passed to ``rotate`` for random rotation.
        """
        angle = random.uniform(degrees[0], degrees[1])

        return angle

    def __call__(self, img, target=None):
        """
        img (PIL Image): Image to be rotated.
        target (PIL Image): (optional) Target to be rotated
        Returns:
            PIL Image: Rotated image(s).
        """
828 |
+
|
829 |
+
angle = self.get_params(self.degrees)
|
830 |
+
|
831 |
+
if target is not None:
|
832 |
+
return F.rotate(img, angle, self.resample, self.expand, self.center), \
|
833 |
+
F.rotate(target, angle, self.resample_tg, self.expand, self.center)
|
834 |
+
# resample = False is by default nearest, appropriate for targets
return F.rotate(img, angle, self.resample, self.expand, self.center)
|
835 |
+
|
836 |
+
def __repr__(self):
|
837 |
+
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
|
838 |
+
format_string += ', resample={0}'.format(self.resample)
|
839 |
+
format_string += ', expand={0}'.format(self.expand)
|
840 |
+
if self.center is not None:
|
841 |
+
format_string += ', center={0}'.format(self.center)
|
842 |
+
format_string += ')'
|
843 |
+
return format_string
|
844 |
+
|
845 |
+
|
846 |
+
class RandomAffine(object):
|
847 |
+
"""Random affine transformation of the image keeping center invariant
|
848 |
+
Args:
|
849 |
+
degrees (sequence or float or int): Range of degrees to select from.
|
850 |
+
If degrees is a number instead of sequence like (min, max), the range of degrees
|
851 |
+
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
|
852 |
+
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
|
853 |
+
and vertical translations. For example translate=(a, b), then horizontal shift
|
854 |
+
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
|
855 |
+
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
|
856 |
+
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
|
857 |
+
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
|
858 |
+
shear (sequence or float or int, optional): Range of degrees to select from.
|
859 |
+
If degrees is a number instead of sequence like (min, max), the range of degrees
|
860 |
+
will be (-degrees, +degrees). Will not apply shear by default
|
861 |
+
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
|
862 |
+
An optional resampling filter.
|
863 |
+
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
|
864 |
+
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
|
865 |
+
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
|
866 |
+
"""
|
867 |
+
|
868 |
+
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, resample_tg=False, fillcolor=0):
|
869 |
+
if isinstance(degrees, numbers.Number):
|
870 |
+
if degrees < 0:
|
871 |
+
raise ValueError("If degrees is a single number, it must be positive.")
|
872 |
+
self.degrees = (-degrees, degrees)
|
873 |
+
else:
|
874 |
+
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
|
875 |
+
"degrees should be a list or tuple and it must be of length 2."
|
876 |
+
self.degrees = degrees
|
877 |
+
|
878 |
+
if translate is not None:
|
879 |
+
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
880 |
+
"translate should be a list or tuple and it must be of length 2."
|
881 |
+
for t in translate:
|
882 |
+
if not (0.0 <= t <= 1.0):
|
883 |
+
raise ValueError("translation values should be between 0 and 1")
|
884 |
+
self.translate = translate
|
885 |
+
|
886 |
+
if scale is not None:
|
887 |
+
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
888 |
+
"scale should be a list or tuple and it must be of length 2."
|
889 |
+
for s in scale:
|
890 |
+
if s <= 0:
|
891 |
+
raise ValueError("scale values should be positive")
|
892 |
+
self.scale = scale
|
893 |
+
|
894 |
+
if shear is not None:
|
895 |
+
if isinstance(shear, numbers.Number):
|
896 |
+
if shear < 0:
|
897 |
+
raise ValueError("If shear is a single number, it must be positive.")
|
898 |
+
self.shear = (-shear, shear)
|
899 |
+
else:
|
900 |
+
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
|
901 |
+
"shear should be a list or tuple and it must be of length 2."
|
902 |
+
self.shear = shear
|
903 |
+
else:
|
904 |
+
self.shear = shear
|
905 |
+
|
906 |
+
self.resample = resample
|
907 |
+
self.resample_tg = resample_tg
|
908 |
+
self.fillcolor = fillcolor
|
909 |
+
|
910 |
+
@staticmethod
|
911 |
+
def get_params(degrees, translate, scale_ranges, shears, img_size):
|
912 |
+
"""Get parameters for affine transformation
|
913 |
+
Returns:
|
914 |
+
sequence: params to be passed to the affine transformation
|
915 |
+
"""
|
916 |
+
angle = random.uniform(degrees[0], degrees[1])
|
917 |
+
if translate is not None:
|
918 |
+
max_dx = translate[0] * img_size[0]
|
919 |
+
max_dy = translate[1] * img_size[1]
|
920 |
+
translations = (np.round(random.uniform(-max_dx, max_dx)),
|
921 |
+
np.round(random.uniform(-max_dy, max_dy)))
|
922 |
+
else:
|
923 |
+
translations = (0, 0)
|
924 |
+
|
925 |
+
if scale_ranges is not None:
|
926 |
+
scale = random.uniform(scale_ranges[0], scale_ranges[1])
|
927 |
+
else:
|
928 |
+
scale = 1.0
|
929 |
+
|
930 |
+
if shears is not None:
|
931 |
+
shear = random.uniform(shears[0], shears[1])
|
932 |
+
else:
|
933 |
+
shear = 0.0
|
934 |
+
|
935 |
+
return angle, translations, scale, shear
|
936 |
+
|
937 |
+
def __call__(self, img, target=None):
|
938 |
+
"""
|
939 |
+
img (PIL Image): Image to be rotated.
|
940 |
+
target (PIL Image): (optional) Target to be rotated
|
941 |
+
Returns:
|
942 |
+
PIL Image: Rotated image(s).
|
943 |
+
"""
|
944 |
+
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
|
945 |
+
if target is not None:
|
946 |
+
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor), \
|
947 |
+
F.affine(target, *ret, resample=self.resample_tg, fillcolor=self.fillcolor)
|
948 |
+
# resample = False is by default nearest, appropriate for targets
|
949 |
+
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
|
950 |
+
|
951 |
+
def __repr__(self):
|
952 |
+
s = '{name}(degrees={degrees}'
|
953 |
+
if self.translate is not None:
|
954 |
+
s += ', translate={translate}'
|
955 |
+
if self.scale is not None:
|
956 |
+
s += ', scale={scale}'
|
957 |
+
if self.shear is not None:
|
958 |
+
s += ', shear={shear}'
|
959 |
+
if self.resample > 0:
|
960 |
+
s += ', resample={resample}'
|
961 |
+
if self.fillcolor != 0:
|
962 |
+
s += ', fillcolor={fillcolor}'
|
963 |
+
s += ')'
|
964 |
+
d = dict(self.__dict__)
|
965 |
+
d['resample'] = _pil_interpolation_to_str[d['resample']]
|
966 |
+
return s.format(name=self.__class__.__name__, **d)
|
967 |
+
|
968 |
+
|
969 |
+
class Grayscale(object):
|
970 |
+
"""Convert image to grayscale.
|
971 |
+
Args:
|
972 |
+
num_output_channels (int): (1 or 3) number of channels desired for output image
|
973 |
+
Returns:
|
974 |
+
PIL Image: Grayscale version of the input.
|
975 |
+
- If num_output_channels == 1 : returned image is single channel
|
976 |
+
- If num_output_channels == 3 : returned image is 3 channel with r == g == b
|
977 |
+
"""
|
978 |
+
|
979 |
+
def __init__(self, num_output_channels=1):
|
980 |
+
self.num_output_channels = num_output_channels
|
981 |
+
|
982 |
+
def __call__(self, img, target = None):
|
983 |
+
"""
|
984 |
+
Args:
|
985 |
+
img (PIL Image): Image to be converted to grayscale.
|
986 |
+
Returns:
|
987 |
+
PIL Image: Randomly grayscaled image.
|
988 |
+
"""
|
989 |
+
if target is not None:
|
990 |
+
return F.to_grayscale(img, num_output_channels=self.num_output_channels), target
|
991 |
+
return F.to_grayscale(img, num_output_channels=self.num_output_channels)
|
992 |
+
|
993 |
+
def __repr__(self):
|
994 |
+
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
|
995 |
+
|
996 |
+
|
997 |
+
class RandomGrayscale(object):
|
998 |
+
"""Randomly convert image to grayscale with a probability of p (default 0.1).
|
999 |
+
Args:
|
1000 |
+
p (float): probability that image should be converted to grayscale.
|
1001 |
+
Returns:
|
1002 |
+
PIL Image: Grayscale version of the input image with probability p and unchanged
|
1003 |
+
with probability (1-p).
|
1004 |
+
- If input image is 1 channel: grayscale version is 1 channel
|
1005 |
+
- If input image is 3 channel: grayscale version is 3 channel with r == g == b
|
1006 |
+
"""
|
1007 |
+
|
1008 |
+
def __init__(self, p=0.1):
|
1009 |
+
self.p = p
|
1010 |
+
|
1011 |
+
def __call__(self, img, target = None):
|
1012 |
+
"""
|
1013 |
+
Args:
|
1014 |
+
img (PIL Image): Image to be converted to grayscale.
|
1015 |
+
Returns:
|
1016 |
+
PIL Image: Randomly grayscaled image.
|
1017 |
+
"""
|
1018 |
+
num_output_channels = 1 if img.mode == 'L' else 3
|
1019 |
+
if random.random() < self.p:
|
1020 |
+
if target is not None:
|
1021 |
+
return F.to_grayscale(img, num_output_channels=num_output_channels), target
return F.to_grayscale(img, num_output_channels=num_output_channels)
|
1022 |
+
if target is not None:
|
1023 |
+
return img, target
|
1024 |
+
return img
|
1025 |
+
|
1026 |
+
def __repr__(self):
|
1027 |
+
return self.__class__.__name__ + '(p={0})'.format(self.p)
|
dependecies/segroot/paired_weight_vgg16.plk
ADDED
Binary file (3.22 kB). View file
|
|
dependecies/segroot/predict_imgs.py
ADDED
@@ -0,0 +1,121 @@
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
from skimage.morphology import erosion
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import time
|
9 |
+
|
10 |
+
from segroot.utils import init_weights
|
11 |
+
from segroot.dataloader import pad_pair_256, normalize
|
12 |
+
from segroot.model import SegRoot
|
13 |
+
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument(
|
16 |
+
"--image", default="test.jpg", type=str, help="filename of one test image"
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--thres", default=0.9, type=float, help="threshold of the final binarization"
|
20 |
+
)
|
21 |
+
parser.add_argument(
|
22 |
+
"--all", action="store_true", help="make prediction on all images in the folder"
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"--data_dir",
|
26 |
+
default="../data/prediction",
|
27 |
+
type=Path,
|
28 |
+
help="define the data directory",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--weights",
|
32 |
+
default="../weights/best_segnet-(8,5)-0.6441.pt",
|
33 |
+
type=Path,
|
34 |
+
help="path of pretrained weights",
|
35 |
+
)
|
36 |
+
parser.add_argument("--width", default=8, type=int, help="width of SegRoot")
|
37 |
+
parser.add_argument("--depth", default=5, type=int, help="depth of SegRoot")
|
38 |
+
|
39 |
+
|
40 |
+
def pad_256(img_path):
|
41 |
+
image = Image.open(img_path)
|
42 |
+
W, H = image.size
|
43 |
+
img, _ = pad_pair_256(image, image)
|
44 |
+
NW, NH = img.size
|
45 |
+
img = torchvision.transforms.ToTensor()(img)
|
46 |
+
img = normalize(img)
|
47 |
+
return img, (H, W, NH, NW)
|
48 |
+
|
49 |
+
|
50 |
+
def predict(model, test_img, device):
|
51 |
+
for p in model.parameters():
|
52 |
+
p.requires_grad = False
|
53 |
+
|
54 |
+
model.eval()
|
55 |
+
# test_img.shape = (3, 2304, 2560)
|
56 |
+
test_img = test_img.unsqueeze(0)
|
57 |
+
output = model(test_img)
|
58 |
+
# output.shape = (1, 1, 2304, 2560)
|
59 |
+
output = torch.squeeze(output)
|
60 |
+
torch.cuda.empty_cache()
|
61 |
+
return output
|
62 |
+
|
63 |
+
|
64 |
+
def predict_gen(model, img_path, thres, device, info):
|
65 |
+
img, dims = pad_256(img_path)
|
66 |
+
H, W, NH, NW = dims
|
67 |
+
img = img.to(device)
|
68 |
+
prediction = predict(model, img, device)
|
69 |
+
prediction[prediction >= thres] = 1.0
|
70 |
+
prediction[prediction < thres] = 0.0
|
71 |
+
if device.type == "cpu":
|
72 |
+
prediction = prediction.detach().numpy()
|
73 |
+
else:
|
74 |
+
prediction = prediction.cpu().detach().numpy()
|
75 |
+
prediction = erosion(prediction)
|
76 |
+
# reverse padding
|
77 |
+
prediction = prediction[
|
78 |
+
(NH - H) // 2 : (NH - H) // 2 + H, (NW - W) // 2 : (NW - W) // 2 + W
|
79 |
+
]
|
80 |
+
save_path = img_path.parent / (
|
81 |
+
img_path.parts[-1].split(".jpg")[0] + "-pre-mask-segnet-({},5).jpg".format(info)
|
82 |
+
)
|
83 |
+
plt.imsave(save_path.as_posix(), prediction, cmap="gray")
|
84 |
+
print("{} generated!".format(save_path.parts[-1]))
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
args = parser.parse_args()
|
89 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
90 |
+
# define model
|
91 |
+
print("using segnet, width : {}, depth : {}".format(args.width, args.depth))
|
92 |
+
model = SegRoot(args.width, args.depth).to(device)
|
93 |
+
weights_path = args.weights
|
94 |
+
|
95 |
+
if device.type == "cpu":
|
96 |
+
print("load weights to cpu")
|
97 |
+
print(weights_path.as_posix())
|
98 |
+
model.load_state_dict(torch.load(weights_path.as_posix(), map_location="cpu"))
|
99 |
+
else:
|
100 |
+
print("load weights to gpu")
|
101 |
+
print(weights_path.as_posix())
|
102 |
+
model.load_state_dict(torch.load(weights_path.as_posix()))
|
103 |
+
|
104 |
+
# define the prediction's saving directory
|
105 |
+
pre_dir = Path("../data/prediction")
|
106 |
+
pre_dir.mkdir(parents=True, exist_ok=True)
|
107 |
+
if not args.all:
|
108 |
+
# load and pad image
|
109 |
+
img_path = pre_dir / args.image
|
110 |
+
start_time = time.time()
|
111 |
+
predict_gen(model, img_path, args.thres, device, 8)
|
112 |
+
end_time = time.time()
|
113 |
+
print("{:.4f}s for one image".format(end_time - start_time))
|
114 |
+
else:
|
115 |
+
img_paths = args.data_dir.glob("*.jpg")
|
116 |
+
for img_path in img_paths:
|
117 |
+
start_time = time.time()
|
118 |
+
predict_gen(model, img_path, args.thres, device, 8)
|
119 |
+
end_time = time.time()
|
120 |
+
print("{:.4f}s for one image".format(end_time - start_time))
|
121 |
+
|
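Note: the lines below are a small usage sketch, not part of the commit, showing how predict_gen above can be called from Python instead of the CLI. It assumes the repository root is the working directory and uses the bundled weights and example image (both added in this commit):

import sys
sys.path.insert(0, "dependecies")            # make the segroot package importable
from pathlib import Path
import torch
from segroot.model import SegRoot
from segroot.predict_imgs import predict_gen

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegRoot(8, 5).to(device)             # width=8, depth=5, matching the script defaults
model.load_state_dict(torch.load("models/best_segnet-(8,5)-0.6441.pt", map_location=device))
predict_gen(model, Path("example_1.jpg"), 0.9, device, 8)   # thres=0.9, info=8; writes the mask next to the input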
dependecies/segroot/run_all_experiments.sh
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/bin/sh
|
2 |
+
python -u train_segroot.py --width 2 > "log_SegRoot(2,5).txt"
|
3 |
+
python -u train_segroot.py --width 16 --depth 4 --lr 1e-3 > "log_SegRoot(16,4).txt"
|
4 |
+
python -u train_segroot.py --width 32 --depth 5 --lr 1e-4 --bs 32 > "log_SegRoot(32,5).txt"
|
5 |
+
python -u train_segroot.py --width 64 --depth 4 --lr 1e-4 --bs 16 > "log_SegRoot(64,4).txt"
|
6 |
+
python -u train_segroot.py --width 64 --depth 5 --lr 2e-5 --bs 8 --epochs 100 --verbose 2 > "log_SegRoot(64,5).txt"
|
dependecies/segroot/utils.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
import pickle
|
2 |
+
import torch
|
3 |
+
from torchvision import models
|
4 |
+
import random
|
5 |
+
import logging
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
|
9 |
+
def set_random_seed(seed):
|
10 |
+
random.seed(seed)
|
11 |
+
np.random.seed(seed)
|
12 |
+
torch.manual_seed(seed)
|
13 |
+
torch.cuda.manual_seed(seed)
|
14 |
+
torch.backends.cudnn.deterministic = True
|
15 |
+
|
16 |
+
def set_logger(log_path):
|
17 |
+
logger = logging.getLogger()
|
18 |
+
logger.setLevel(logging.INFO)
|
19 |
+
|
20 |
+
if not logger.handlers:
|
21 |
+
# Logging to a file
|
22 |
+
file_handler = logging.FileHandler(log_path)
|
23 |
+
file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
|
24 |
+
logger.addHandler(file_handler)
|
25 |
+
|
26 |
+
# Logging to console
|
27 |
+
stream_handler = logging.StreamHandler()
|
28 |
+
stream_handler.setFormatter(logging.Formatter('%(message)s'))
|
29 |
+
logger.addHandler(stream_handler)
|
30 |
+
|
31 |
+
def to_np(x):
|
32 |
+
return x.data.cpu().numpy()
|
33 |
+
|
34 |
+
def get_ids(length_dataset):
|
35 |
+
ids = list(range(length_dataset))
|
36 |
+
|
37 |
+
random.shuffle(ids)
|
38 |
+
train_split = round(0.6 * length_dataset)
|
39 |
+
t_v_spplit = (length_dataset - train_split) // 2
|
40 |
+
train_ids = ids[:train_split]
|
41 |
+
valid_ids = ids[train_split:train_split+t_v_spplit]
|
42 |
+
test_ids = ids[train_split+t_v_spplit:]
|
43 |
+
return train_ids, valid_ids, test_ids
|
44 |
+
|
45 |
+
def dice_score(y, y_pred, smooth=1.0, thres=0.9):
|
46 |
+
n = y.shape[0]
|
47 |
+
y = y.view(n, -1)
|
48 |
+
y_pred = y_pred.view(n, -1)
|
49 |
+
# y_pred_[y_pred>=thres] = 1.0
|
50 |
+
# y_pred_[y_pred<thres] = 0.0
|
51 |
+
num = 2 * torch.sum(y * y_pred, dim=1, keepdim=True) + smooth
|
52 |
+
den = torch.sum(y, dim=1, keepdim=True) + \
|
53 |
+
torch.sum(y_pred, dim=1, keepdim=True) + smooth
|
54 |
+
score = num / den
|
55 |
+
return score
|
56 |
+
|
57 |
+
def init_weights(m):
|
58 |
+
if isinstance(m, torch.nn.Conv2d):
|
59 |
+
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
|
60 |
+
# torch.nn.init.constant_(m.bias, 0)
|
61 |
+
elif isinstance(m, torch.nn.BatchNorm2d):
|
62 |
+
torch.nn.init.constant_(m.weight, 1)
|
63 |
+
|
64 |
+
def load_vgg16(segnet):
|
65 |
+
vgg16 = models.vgg16_bn(pretrained=True)
|
66 |
+
with open('paired_weight_vgg16.plk', 'rb') as handle:
|
67 |
+
paired = pickle.load(handle)
|
68 |
+
segnet_p = dict(segnet.state_dict())
|
69 |
+
vgg16_p = vgg16.state_dict()
|
70 |
+
|
71 |
+
for k, v in paired.items():
|
72 |
+
for n, p in vgg16_p.items():
|
73 |
+
if n == v:
|
74 |
+
segnet_p[k].data.copy_(p.data)
|
75 |
+
segnet.load_state_dict(segnet_p)
|
76 |
+
return segnet
|
77 |
+
|
78 |
+
def train_one_epoch(model, train_iter, optimizer, device):
|
79 |
+
model.train()
|
80 |
+
for p in model.parameters():
|
81 |
+
p.requires_grad = True
|
82 |
+
for x, y in train_iter:
|
83 |
+
x, y = x.to(device), y.to(device)
|
84 |
+
bs = x.shape[0]
|
85 |
+
optimizer.zero_grad()
|
86 |
+
y_pred = model(x)
|
87 |
+
loss = 1 - dice_score(y, y_pred)
|
88 |
+
loss = torch.sum(loss) / bs
|
89 |
+
loss.backward()
|
90 |
+
optimizer.step()
|
91 |
+
|
92 |
+
def evaluate(model, dataset, device, thres=0.9):
|
93 |
+
model.eval()
|
94 |
+
torch.cuda.empty_cache()
|
95 |
+
num, den = 0, 0
|
96 |
+
# shutdown the autograd
|
97 |
+
with torch.no_grad():
|
98 |
+
for i in range(len(dataset)):
|
99 |
+
x, y = dataset[i]
|
100 |
+
x, y = x.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
|
101 |
+
y_pred = model(x)
|
102 |
+
y = y.cpu().detach().numpy()
|
103 |
+
y_pred = y_pred.cpu().detach().numpy()
|
104 |
+
y_pred[y_pred>=thres] = 1.0
|
105 |
+
y_pred[y_pred<thres] = 0.0
|
106 |
+
num += 2 * (y_pred * y).sum()
|
107 |
+
den += y_pred.sum() + y.sum()
|
108 |
+
torch.cuda.empty_cache()
|
109 |
+
return num / den
|
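Note: a toy sketch, not part of the commit, illustrating what dice_score above computes; the tensors are made up for illustration and dependecies/ is assumed to be on sys.path:

import torch
from segroot.utils import dice_score

y      = torch.tensor([[1., 1., 0., 0.]])    # ground-truth mask, batch of one
y_pred = torch.tensor([[1., 0., 0., 0.]])    # prediction
print(dice_score(y, y_pred, smooth=1.0))     # (2*1 + 1) / (2 + 1 + 1) = tensor([[0.75]])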
example_1.jpg
ADDED
Git LFS Details
|
example_2.jpg
ADDED
Git LFS Details
|
example_3.jpg
ADDED
Git LFS Details
|
flagged/input_img/a7a20e8c8e03de5e007f/example_1.jpg
ADDED
Git LFS Details
|
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
1 |
+
input_img,Model,output,flag,username,timestamp
|
2 |
+
flagged\input_img\a7a20e8c8e03de5e007f\example_1.jpg,segroot_finetuned,,,,2024-11-20 11:20:45.490192
|
logo.png
ADDED
Git LFS Details
|
main.py
ADDED
@@ -0,0 +1,188 @@
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from processsors import RootSegmentor
|
4 |
+
from processsors import *
|
5 |
+
|
6 |
+
from gradio_imageslider import ImageSlider
|
7 |
+
|
8 |
+
import cv2 as cv
|
9 |
+
|
10 |
+
PRELOAD_MODELS = False
|
11 |
+
|
12 |
+
if PRELOAD_MODELS:
|
13 |
+
root_segmentor = RootSegmentor()
|
14 |
+
|
15 |
+
|
16 |
+
def process(input_img, model_type):
|
17 |
+
|
18 |
+
print(model_type)
|
19 |
+
|
20 |
+
if PRELOAD_MODELS:
|
21 |
+
global root_segmentor
|
22 |
+
else:
|
23 |
+
root_segmentor = RootSegmentor(model_type)
|
24 |
+
|
25 |
+
result = root_segmentor.predict(input_img)
|
26 |
+
|
27 |
+
return result
|
28 |
+
|
29 |
+
def just_show(files, should_process, model_type):
|
30 |
+
|
31 |
+
imgs = []
|
32 |
+
|
33 |
+
img = merge_images(files)
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
imgs.append(img)
|
38 |
+
|
39 |
+
if should_process:
|
40 |
+
root_segmentor = RootSegmentor(model_type)
|
41 |
+
|
42 |
+
results = []
|
43 |
+
|
44 |
+
for file in files:
|
45 |
+
print(type(file))
|
46 |
+
print(file)
|
47 |
+
img = cv.imread(file)
|
48 |
+
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
|
49 |
+
#imgs.append(img)
|
50 |
+
|
51 |
+
if should_process:
|
52 |
+
|
53 |
+
result = root_segmentor.predict(img)
|
54 |
+
results.append(result)
|
55 |
+
#imgs.append(results)
|
56 |
+
|
57 |
+
if should_process:
|
58 |
+
img_res = merge_images(results)
|
59 |
+
imgs.append(img_res)
|
60 |
+
|
61 |
+
return imgs
|
62 |
+
|
63 |
+
def slider_test(img1, img2):
|
64 |
+
|
65 |
+
return [img1,img2]
|
66 |
+
|
67 |
+
def download_result():
|
68 |
+
|
69 |
+
#print(filepath)
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
def gui():
|
74 |
+
|
75 |
+
with gr.Blocks(title="Root analysis", theme=gr.themes.Soft()) as demo:
|
76 |
+
|
77 |
+
big_block = gr.HTML("""
|
78 |
+
|
79 |
+
<style>
|
80 |
+
body {
|
81 |
+
font-family: Arial, sans-serif;
|
82 |
+
background-color: white;
|
83 |
+
margin: 0;
|
84 |
+
}
|
85 |
+
|
86 |
+
header {
|
87 |
+
display: flex;
|
88 |
+
justify-content: space-between;
|
89 |
+
align-items: center;
|
90 |
+
padding: 5px;
|
91 |
+
color: #fff;
|
92 |
+
}
|
93 |
+
|
94 |
+
hr {
|
95 |
+
border: 1px solid #ddd;
|
96 |
+
margin: 5px;
|
97 |
+
}
|
98 |
+
|
99 |
+
</style>
|
100 |
+
|
101 |
+
<header>
|
102 |
+
<div style="display: flex; align-items: center;">
|
103 |
+
<div style="text-align: left;">
|
104 |
+
<h1>Root Analysis</h1>
|
105 |
+
<p>Root segmentation using underground root scanner images.</p>
|
106 |
+
<h3>Tropical Forages Program</h3>
|
107 |
+
<p><b>Authors: </b>Andres Felipe Ruiz-Hurtado, Juan Andrés Cardoso Arango</p>
|
108 |
+
<p></p>
|
109 |
+
</div>
|
110 |
+
</div>
|
111 |
+
<div style="background-color: white; padding: 5px; border-radius: 15px; box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);">
|
112 |
+
<img src='file/logo.png' alt="Logo" width="200" height="100">
|
113 |
+
</div>
|
114 |
+
</header>
|
115 |
+
|
116 |
+
""")
|
117 |
+
|
118 |
+
#<iframe style="height:600px;width: 100%;" src="/file=slides.html" title="description"></iframe>
|
119 |
+
|
120 |
+
|
121 |
+
#<iframe style="height:600px;width: 100%;" src="https://revealjs.com/demo/?view" title="description"></iframe>
|
122 |
+
|
123 |
+
with gr.Tab("Single Image"):
|
124 |
+
|
125 |
+
model_selector = gr.Dropdown(
|
126 |
+
["segroot_finetuned", "segroot", "segroot_finetuned_dec", "seg_model"], label="Model"
|
127 |
+
, info="AI model"
|
128 |
+
,value="segroot_finetuned"
|
129 |
+
)
|
130 |
+
|
131 |
+
input_img=gr.Image(render=False)
|
132 |
+
output_img=gr.Image(render=False)
|
133 |
+
|
134 |
+
gr.Interface(
|
135 |
+
fn=process,
|
136 |
+
inputs=[input_img,model_selector],
|
137 |
+
outputs=output_img,
|
138 |
+
examples=[["example_1.jpg"],["example_2.jpg"],["example_3.jpg"]]
|
139 |
+
)
|
140 |
+
|
141 |
+
#examples = gr.Examples([["Chicago"], ["Little Rock"], ["San Francisco"]], textbox)
|
142 |
+
|
143 |
+
with gr.Row():
|
144 |
+
img_comp = ImageSlider(label="Root Segmentation")
|
145 |
+
with gr.Row():
|
146 |
+
compare_button = gr.Button("Compare")
|
147 |
+
compare_button.click(fn=slider_test, inputs=[input_img,output_img], outputs=img_comp, api_name="slider_test")
|
148 |
+
|
149 |
+
with gr.Tab("Multiple Images"):
|
150 |
+
|
151 |
+
#img_comp = ImageSlider(label="Blur image", type="pil")
|
152 |
+
|
153 |
+
gallery = gr.Gallery(show_fullscreen_button=True, render=False)
|
154 |
+
|
155 |
+
gr.Interface(
|
156 |
+
fn=just_show
|
157 |
+
,inputs=[gr.File(file_count="multiple"),gr.Checkbox(label="Process", info="Check if you want to process"),model_selector]
|
158 |
+
,outputs= gallery
|
159 |
+
, examples=[[["example_1.jpg", "example_2.jpg", "example_3.jpg"]]]
|
160 |
+
)
|
161 |
+
|
162 |
+
with gr.Tab("Compare"):
|
163 |
+
|
164 |
+
img_comp = ImageSlider(label="Root Segmentation")
|
165 |
+
img_comp.upload(inputs=img_comp, outputs=img_comp)
|
166 |
+
|
167 |
+
|
168 |
+
#d = gr.DownloadButton("Download the file")
|
169 |
+
#d.click(download_result, gallery, None)
|
170 |
+
|
171 |
+
# with gr.Row():
|
172 |
+
# img1=gr.Image()
|
173 |
+
# img2=gr.Image()
|
174 |
+
# with gr.Row():
|
175 |
+
# img_comp = ImageSlider(label="Blur image", type="pil")
|
176 |
+
# with gr.Row():
|
177 |
+
# compare_button = gr.Button("Compare")
|
178 |
+
# compare_button.click(fn=slider_test, inputs=[img1,img2], outputs=img_comp, api_name="slider_test")
|
179 |
+
|
180 |
+
# with gr.Group():
|
181 |
+
# img_comp = ImageSlider(label="Blur image", type="pil")
|
182 |
+
# #img1.upload(slider_test, inputs=[img1,img2], outputs=img_comp)
|
183 |
+
# gr.Interface(slider_test, inputs=[img1,img2], outputs=img_comp)
|
184 |
+
|
185 |
+
demo.launch(allowed_paths=["logo.png"], share=False)
|
186 |
+
|
187 |
+
if __name__ == "__main__":
|
188 |
+
gui()
|
models/best_segnet-(8,5)-0.6441.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dffa166609b5ab3241d1b175bffbf454377beaa4a7fb46bd74e38605e2f71d03
|
3 |
+
size 1611034
|
models/roots_model.onnx
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48254d394d1b11fd9bcfd42bcc754bb1fba5a2052848f5ad70b259972bce4681
|
3 |
+
size 58655218
|
models/segroot-(8,5)_finetuned.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cbb992086ea1900ef24e110d7b454126d6214ac8f14687348ec021cf860f4eca
|
3 |
+
size 1640578
|
processsors.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from skimage.morphology import erosion
|
9 |
+
|
10 |
+
from dependecies.segroot.model import SegRoot
|
11 |
+
from dependecies.segroot.dataloader import pad_pair_256, normalize
|
12 |
+
from torchvision.transforms import v2 as transforms
|
13 |
+
|
14 |
+
|
15 |
+
import onnxruntime as ort
|
16 |
+
import cv2 as cv
|
17 |
+
|
18 |
+
import os
|
19 |
+
|
20 |
+
MODELS_PATH = r"./models"
|
21 |
+
|
22 |
+
def pad_256(img_path):
|
23 |
+
image = Image.open(img_path)
|
24 |
+
W, H = image.size
|
25 |
+
img, _ = pad_pair_256(image, image)
|
26 |
+
NW, NH = img.size
|
27 |
+
img = torchvision.transforms.ToTensor()(img)
|
28 |
+
img = normalize(img)
|
29 |
+
return img, (H, W, NH, NW)
|
30 |
+
|
31 |
+
def pad_256_np(np_img):
|
32 |
+
#image = Image.open(img_path)
|
33 |
+
image = Image.fromarray(np_img)
|
34 |
+
W, H = image.size
|
35 |
+
img, _ = pad_pair_256(image, image)
|
36 |
+
NW, NH = img.size
|
37 |
+
img = torchvision.transforms.ToTensor()(img)
|
38 |
+
img = normalize(img)
|
39 |
+
return img, (H, W, NH, NW)
|
40 |
+
|
41 |
+
def merge_images(files, path=""):
|
42 |
+
|
43 |
+
is_array = False
|
44 |
+
if type(files[0]) == np.ndarray:
|
45 |
+
is_array = True
|
46 |
+
|
47 |
+
|
48 |
+
final_img = []
|
49 |
+
resize_factor = 0.4
|
50 |
+
offset0 = 930
|
51 |
+
offset1 = 305
|
52 |
+
for index, file in enumerate(files):
|
53 |
+
|
54 |
+
if is_array:
|
55 |
+
img = file
|
56 |
+
else:
|
57 |
+
img = cv.imread(file)
|
58 |
+
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
|
59 |
+
#img = cv.resize(img, (0,0), fx=resize_factor, fy=resize_factor)
|
60 |
+
img = cv.rotate(img, cv.ROTATE_90_CLOCKWISE)
|
61 |
+
|
62 |
+
if index == 0:
|
63 |
+
img = img[0:img.shape[0]-offset0,0:img.shape[1]]
|
64 |
+
final_img = img
|
65 |
+
elif index == len(files)-1:
|
66 |
+
final_img = cv.vconcat([final_img, img])
|
67 |
+
else:
|
68 |
+
#final_img = np.concatenate((final_img, img), axis=1)
|
69 |
+
img = img[0:img.shape[0]-offset1,0:img.shape[1]]
|
70 |
+
final_img = cv.vconcat([final_img, img])
|
71 |
+
|
72 |
+
final_img = cv.resize(final_img, (0,0), fx=resize_factor, fy=resize_factor)
|
73 |
+
|
74 |
+
#cv.imwrite(path, final_img)
|
75 |
+
print(final_img.shape)
|
76 |
+
|
77 |
+
return final_img
|
78 |
+
|
79 |
+
class RootSegmentor():
|
80 |
+
|
81 |
+
def __init__(self, model_type):
|
82 |
+
|
83 |
+
|
84 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
85 |
+
|
86 |
+
self.model_type = model_type
|
87 |
+
|
88 |
+
if model_type != "seg_model":
|
89 |
+
self.initialize()
|
90 |
+
|
91 |
+
return
|
92 |
+
|
93 |
+
def initialize(self):
|
94 |
+
|
95 |
+
width = 8
|
96 |
+
depth = 5
|
97 |
+
|
98 |
+
if self.model_type == "segroot":
|
99 |
+
#weights_path = os.path.join(r"D:\local_mydev\roots_finetuning\SegRoot0\weights\best_segnet-(8,5)-0.6441.pt"
|
100 |
+
#weights_path = r"D:\local_mydev\SegRoot\weights\best_segnet-(8,5)-0.6441.pt"
|
101 |
+
#weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\best_segnet-(8,5)-0.6441.pt"
|
102 |
+
#weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\best_segnet-(8,5)-0.6441.pt")
|
103 |
+
weights_path = os.path.join(MODELS_PATH, r"best_segnet-(8,5)-0.6441.pt")
|
104 |
+
elif self.model_type == "segroot_finetuned":
|
105 |
+
#weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned.pt"
|
106 |
+
#weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned.pt")
|
107 |
+
weights_path = os.path.join(MODELS_PATH, r"segroot-(8,5)_finetuned.pt")
|
108 |
+
elif self.model_type == "segroot_finetuned_dec":
|
109 |
+
#weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_dec_full.pt"
|
110 |
+
#weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_clas.pt"
|
111 |
+
#weights_path = os.path.join(MODELS_PATH, r"AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\segroot-(8,5)_finetuned_clas.pt")
|
112 |
+
weights_path = os.path.join(MODELS_PATH, r"segroot-(8,5)_finetuned.pt")
|
113 |
+
|
114 |
+
self.model = SegRoot(width, depth).to(self.device)
|
115 |
+
|
116 |
+
if self.device.type == "cpu":
|
117 |
+
print("load weights to cpu")
|
118 |
+
#print(weights_path.as_posix())
|
119 |
+
self.model.load_state_dict(torch.load(weights_path, map_location="cpu"))
|
120 |
+
else:
|
121 |
+
print("load weights to gpu")
|
122 |
+
#print(weights_path.as_posix())
|
123 |
+
self.model.load_state_dict(torch.load(weights_path))
|
124 |
+
|
125 |
+
for p in self.model.parameters():
|
126 |
+
p.requires_grad = False
|
127 |
+
|
128 |
+
self.model.eval()
|
129 |
+
|
130 |
+
return
|
131 |
+
|
132 |
+
def predict(self, img_path):
|
133 |
+
|
134 |
+
if self.model_type == "seg_model":
|
135 |
+
|
136 |
+
print(str(type(img_path)))
|
137 |
+
|
138 |
+
if type(img_path) == np.ndarray:
|
139 |
+
img = img_path
|
140 |
+
else:
|
141 |
+
img = cv.imread(img_path)
|
142 |
+
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
|
143 |
+
|
144 |
+
weights_path = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\roots\roots_model.onnx"
|
145 |
+
weights_path = os.path.join(MODELS_PATH,"roots_model.onnx")
|
146 |
+
ort_sess = ort.InferenceSession(weights_path
|
147 |
+
,providers=ort.get_available_providers()
|
148 |
+
)
|
149 |
+
|
150 |
+
dim = img.shape
|
151 |
+
|
152 |
+
transforms_list = []
|
153 |
+
transforms_list.append(transforms.ToTensor())
|
154 |
+
transforms_list.append(transforms.Resize((800,800)))
|
155 |
+
#transforms_list.append(transforms.CenterCrop(800))
|
156 |
+
#transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
|
157 |
+
|
158 |
+
apply_t = transforms.Compose(transforms_list)
|
159 |
+
|
160 |
+
img = apply_t(img)
|
161 |
+
|
162 |
+
outputs = ort_sess.run(None, {'input': [img.numpy()]})
|
163 |
+
|
164 |
+
print(outputs)
|
165 |
+
|
166 |
+
#np_res = outputs[0][0]
|
167 |
+
|
168 |
+
output_image = outputs[0][:,:,1]
|
169 |
+
final = cv.resize(output_image, (dim[1], dim[0]))  # cv.resize expects (width, height)
|
170 |
+
|
171 |
+
return final
|
172 |
+
|
173 |
+
else:
|
174 |
+
|
175 |
+
thres = 0.9
|
176 |
+
|
177 |
+
print(str(type(img_path)))
|
178 |
+
|
179 |
+
if type(img_path) == np.ndarray:
|
180 |
+
img, dims = pad_256_np(img_path)
|
181 |
+
else:
|
182 |
+
img, dims = pad_256(img_path)
|
183 |
+
|
184 |
+
H, W, NH, NW = dims
|
185 |
+
|
186 |
+
img = img.to(self.device)
|
187 |
+
|
188 |
+
img = img.unsqueeze(0)
|
189 |
+
output = self.model(img)
|
190 |
+
|
191 |
+
output = torch.squeeze(output)
|
192 |
+
torch.cuda.empty_cache()
|
193 |
+
|
194 |
+
prediction = output
|
195 |
+
|
196 |
+
prediction[prediction >= thres] = 1.0
|
197 |
+
prediction[prediction < thres] = 0.0
|
198 |
+
|
199 |
+
if self.device.type == "cpu":
|
200 |
+
prediction = prediction.detach().numpy()
|
201 |
+
else:
|
202 |
+
prediction = prediction.cpu().detach().numpy()
|
203 |
+
|
204 |
+
prediction = erosion(prediction)
|
205 |
+
# reverse padding
|
206 |
+
prediction = prediction[
|
207 |
+
(NH - H) // 2 : (NH - H) // 2 + H, (NW - W) // 2 : (NW - W) // 2 + W
|
208 |
+
]
|
209 |
+
|
210 |
+
return prediction
|
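Note: a minimal usage sketch, not part of the commit, showing how RootSegmentor can be called outside the Gradio app. It assumes the weights under ./models are present and uses one of the bundled example images:

import cv2 as cv
from processsors import RootSegmentor

segmentor = RootSegmentor("segroot_finetuned")
img = cv.imread("example_1.jpg")
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
mask = segmentor.predict(img)                          # binary root mask as a float array
cv.imwrite("example_1_mask.png", (mask * 255).astype("uint8"))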
requirements.txt
ADDED
@@ -0,0 +1,11 @@
1 |
+
matplotlib
|
2 |
+
numpy
|
3 |
+
opencv-python
|
4 |
+
pillow
|
5 |
+
scikit-image
|
6 |
+
scikit-learn
|
7 |
+
torch
|
8 |
+
torchvision
|
9 |
+
gradio
|
10 |
+
onnxruntime
|
11 |
+
rasterio
|