curt-park committed on
Commit 1615d09
1 Parent(s): 4c746e8

Refactor code

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. app.py +90 -57
  2. isegm/data/base.py +32 -25
  3. isegm/data/compose.py +13 -7
  4. isegm/data/datasets/__init__.py +5 -4
  5. isegm/data/datasets/ade20k.py +13 -11
  6. isegm/data/datasets/berkeley.py +3 -1
  7. isegm/data/datasets/coco.py +21 -17
  8. isegm/data/datasets/coco_lvis.py +35 -23
  9. isegm/data/datasets/davis.py +6 -6
  10. isegm/data/datasets/grabcut.py +13 -7
  11. isegm/data/datasets/images_dir.py +22 -16
  12. isegm/data/datasets/lvis.py +27 -24
  13. isegm/data/datasets/openimages.py +22 -13
  14. isegm/data/datasets/pascalvoc.py +22 -10
  15. isegm/data/datasets/sbd.py +31 -23
  16. isegm/data/points_sampler.py +120 -57
  17. isegm/data/sample.py +50 -27
  18. isegm/data/transforms.py +73 -44
  19. isegm/engine/optimizer.py +10 -8
  20. isegm/engine/trainer.py +259 -122
  21. isegm/inference/clicker.py +15 -6
  22. isegm/inference/evaluation.py +22 -6
  23. isegm/inference/predictors/__init__.py +78 -56
  24. isegm/inference/predictors/base.py +49 -24
  25. isegm/inference/predictors/brs.py +157 -66
  26. isegm/inference/predictors/brs_functors.py +22 -13
  27. isegm/inference/predictors/brs_losses.py +7 -5
  28. isegm/inference/transforms/__init__.py +2 -2
  29. isegm/inference/transforms/crops.py +20 -9
  30. isegm/inference/transforms/flip.py +7 -3
  31. isegm/inference/transforms/zoom_in.py +69 -25
  32. isegm/inference/utils.py +59 -41
  33. isegm/model/initializer.py +32 -17
  34. isegm/model/is_deeplab_model.py +28 -9
  35. isegm/model/is_hrnet_model.py +19 -6
  36. isegm/model/is_model.py +86 -31
  37. isegm/model/losses.py +101 -31
  38. isegm/model/metrics.py +35 -9
  39. isegm/model/modeling/basic_blocks.py +68 -22
  40. isegm/model/modeling/deeplab_v3.py +106 -45
  41. isegm/model/modeling/hrnet_ocr.py +292 -132
  42. isegm/model/modeling/ocr.py +91 -45
  43. isegm/model/modeling/resnet.py +27 -13
  44. isegm/model/modeling/resnetv1b.py +227 -61
  45. isegm/model/modifiers.py +3 -5
  46. isegm/model/ops.py +38 -15
  47. isegm/utils/cython/__init__.py +1 -1
  48. isegm/utils/cython/_get_dist_maps.pyx +2 -1
  49. isegm/utils/cython/dist_maps.py +4 -2
  50. isegm/utils/distributed.py +11 -2
app.py CHANGED
@@ -1,94 +1,127 @@
+ import os
+ from typing import Dict, List
+
+ import cv2
+ import numpy as np
  import streamlit as st
  import torch
- import numpy as np
- import cv2
  import wget
- import os
-
  from PIL import Image
  from streamlit_drawable_canvas import st_canvas

  from isegm.inference import clicker as ck
  from isegm.inference import utils
- from isegm.inference.predictors import get_predictor
+ from isegm.inference.predictors import BasePredictor, get_predictor
+
+ ###################################
+ # Global scope objects.
+ ###################################
+ URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
+ CANVAS_HEIGHT, CANVAS_WIDTH = 600, 600
+ POS_COLOR, NEG_COLOR = "#3498DB", "#C70039"
+ ERR_X, ERR_Y = 5.5, 1.0
+ MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ clicker = ck.Clicker()
+ predictor = None
+ image = None

- @st.experimental_memo
- def load_model(model_path, device):
+ ###################################
+ # Functions.
+ ###################################
+ # @st.cache_resource
+ def load_model(model_path: str, device: torch.device) -> BasePredictor:
      model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
      predictor_params = {"brs_mode": "NoBRS"}
      predictor = get_predictor(model, device=device, **predictor_params)
      return predictor


- # Objects in the global scope
- url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
- models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
- clicker = ck.Clicker()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- pos_color, neg_color = "#3498DB", "#C70039"
- canvas_height, canvas_width = 600, 600
- err_x, err_y = 5.5, 1.0
- predictor = None
- image = None
+ def feed_clicks(
+     clicker: ck.Clicker,
+     clicks: List[Dict[str, float]],
+     image_width: int,
+     image_height: int,
+ ) -> None:
+     ratio_h, ratio_w = image_height / CANVAS_HEIGHT, image_width / CANVAS_WIDTH
+     for click in clicks:
+         x, y = (click["left"] + ERR_X) * ratio_w, (click["top"] + ERR_Y) * ratio_h
+         x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
+
+         is_positive = click["stroke"] == POS_COLOR
+         click = ck.Click(is_positive=is_positive, coords=(y, x))
+         clicker.add_click(click)
+
+
+ def predict(
+     image: Image, mask: torch.Tensor, threshold: float = 0.5
+ ) -> torch.Tensor:
+     predictor.set_input_image(np.array(image))
+     with st.spinner("Wait for prediction..."):
+         pred = predictor.get_prediction(clicker, prev_mask=mask)
+         pred = cv2.resize(
+             pred,
+             dsize=(CANVAS_HEIGHT, CANVAS_WIDTH),
+             interpolation=cv2.INTER_CUBIC,
+         )
+         pred = np.where(pred > threshold, 1.0, 0)
+     return pred
+
+
+ ###################################
+ # Sidebar GUI
+ ###################################
  # Items in the sidebar.
- model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
+ model = st.sidebar.selectbox("Select a Model:", tuple(MODELS.keys()))
  threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
- marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
+ marking_type = st.sidebar.radio("Click Type:", ("Positive", "Negative"))
  image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
+ if image_path:
+     image = Image.open(image_path).convert("RGB")

- # Objects for prediction.
+ ###################################
+ # Preparation
+ ###################################
+ # Model.
  with st.spinner("Wait for downloading a model..."):
-     if not os.path.exists(models[model]):
-         _ = wget.download(f"{url_prefix}/{models[model]}")
-
+     if not os.path.exists(MODELS[model]):
+         _ = wget.download(f"{URL_PREFIX}/{MODELS[model]}")
+ # Predictor.
  with st.spinner("Wait for loading a model..."):
-     predictor = load_model(models[model], device)
+     predictor = load_model(MODELS[model], device)

+ ###################################
+ # GUI
+ ###################################
  # Create a canvas component.
- if image_path:
-     image = Image.open(image_path).convert("RGB")
-
  st.title("Canvas:")
  canvas_result = st_canvas(
      fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
      stroke_width=3,
-     stroke_color=pos_color if marking_type == "positive" else neg_color,
+     stroke_color=POS_COLOR if marking_type == "Positive" else NEG_COLOR,
      background_color="#eee",
      background_image=image,
      update_streamlit=True,
      drawing_mode="point",
      point_display_radius=3,
      key="canvas",
-     width=canvas_width,
-     height=canvas_height,
+     width=CANVAS_WIDTH,
+     height=CANVAS_HEIGHT,
  )

+ ###################################
+ # Prediction
+ ###################################
  # Check the user inputs ans execute predictions.
  st.title("Prediction:")
  if canvas_result.json_data and canvas_result.json_data["objects"] and image:
-     objects = canvas_result.json_data["objects"]
      image_width, image_height = image.size
-     ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
-
-     pos_clicks, neg_clicks = [], []
-     for click in objects:
-         x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
-         x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
-
-         is_positive = click["stroke"] == pos_color
-         click = ck.Click(is_positive=is_positive, coords=(y, x))
-         clicker.add_click(click)
+     feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)

      # Run prediction.
-     pred = None
-     predictor.set_input_image(np.array(image))
-     init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
-
-     with st.spinner("Wait for prediction..."):
-         pred = predictor.get_prediction(clicker, prev_mask=init_mask)
-         pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
-         pred = np.where(pred > threshold, 1.0, 0)
+     mask = torch.zeros((1, 1, image_width, image_height), device=device)
+     pred = predict(image, mask, threshold)

      # Show the prediction result.
      st.image(pred, caption="")
isegm/data/base.py CHANGED
@@ -1,22 +1,26 @@
- import random
  import pickle
+ import random
+
  import numpy as np
  import torch
  from torchvision import transforms
+
  from .points_sampler import MultiPointSampler
  from .sample import DSample


  class ISDataset(torch.utils.data.dataset.Dataset):
-     def __init__(self,
-                  augmentator=None,
-                  points_sampler=MultiPointSampler(max_num_points=12),
-                  min_object_area=0,
-                  keep_background_prob=0.0,
-                  with_image_info=False,
-                  samples_scores_path=None,
-                  samples_scores_gamma=1.0,
-                  epoch_len=-1):
+     def __init__(
+         self,
+         augmentator=None,
+         points_sampler=MultiPointSampler(max_num_points=12),
+         min_object_area=0,
+         keep_background_prob=0.0,
+         with_image_info=False,
+         samples_scores_path=None,
+         samples_scores_gamma=1.0,
+         epoch_len=-1,
+     ):
          super(ISDataset, self).__init__()
          self.epoch_len = epoch_len
         self.augmentator = augmentator
@@ -24,15 +28,19 @@ class ISDataset(torch.utils.data.dataset.Dataset):
          self.keep_background_prob = keep_background_prob
          self.points_sampler = points_sampler
          self.with_image_info = with_image_info
-         self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma)
+         self.samples_precomputed_scores = self._load_samples_scores(
+             samples_scores_path, samples_scores_gamma
+         )
          self.to_tensor = transforms.ToTensor()

          self.dataset_samples = None

      def __getitem__(self, index):
          if self.samples_precomputed_scores is not None:
-             index = np.random.choice(self.samples_precomputed_scores['indices'],
-                                      p=self.samples_precomputed_scores['probs'])
+             index = np.random.choice(
+                 self.samples_precomputed_scores["indices"],
+                 p=self.samples_precomputed_scores["probs"],
+             )
          else:
              if self.epoch_len > 0:
                  index = random.randrange(0, len(self.dataset_samples))
@@ -46,13 +54,13 @@ class ISDataset(torch.utils.data.dataset.Dataset):
          mask = self.points_sampler.selected_mask

          output = {
-             'images': self.to_tensor(sample.image),
-             'points': points.astype(np.float32),
-             'instances': mask
+             "images": self.to_tensor(sample.image),
+             "points": points.astype(np.float32),
+             "instances": mask,
          }

          if self.with_image_info:
-             output['image_info'] = sample.sample_id
+             output["image_info"] = sample.sample_id

          return output

@@ -63,8 +71,10 @@ class ISDataset(torch.utils.data.dataset.Dataset):
          valid_augmentation = False
          while not valid_augmentation:
              sample.augment(self.augmentator)
-             keep_sample = (self.keep_background_prob < 0.0 or
-                            random.random() < self.keep_background_prob)
+             keep_sample = (
+                 self.keep_background_prob < 0.0
+                 or random.random() < self.keep_background_prob
+             )
              valid_augmentation = len(sample) > 0 or keep_sample

          return sample
@@ -86,14 +96,11 @@ class ISDataset(torch.utils.data.dataset.Dataset):
          if samples_scores_path is None:
              return None

-         with open(samples_scores_path, 'rb') as f:
+         with open(samples_scores_path, "rb") as f:
              images_scores = pickle.load(f)

          probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
          probs /= probs.sum()
-         samples_scores = {
-             'indices': [x[0] for x in images_scores],
-             'probs': probs
-         }
-         print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}')
+         samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs}
+         print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}")
          return samples_scores
isegm/data/compose.py CHANGED
@@ -1,5 +1,7 @@
- import numpy as np
  from math import isclose
+
+ import numpy as np
+
  from .base import ISDataset

@@ -10,7 +12,9 @@ class ComposeDataset(ISDataset):
          self._datasets = datasets
          self.dataset_samples = []
          for dataset_indx, dataset in enumerate(self._datasets):
-             self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
+             self.dataset_samples.extend(
+                 [(dataset_indx, i) for i in range(len(dataset))]
+             )

      def get_sample(self, index):
          dataset_indx, sample_indx = self.dataset_samples[index]
@@ -21,16 +25,18 @@ class ProportionalComposeDataset(ISDataset):
      def __init__(self, datasets, ratios, **kwargs):
          super().__init__(**kwargs)

-         assert len(ratios) == len(datasets),\
-             "The number of datasets must match the number of ratios"
-         assert isclose(sum(ratios), 1.0),\
-             "The sum of ratios must be equal to 1"
+         assert len(ratios) == len(
+             datasets
+         ), "The number of datasets must match the number of ratios"
+         assert isclose(sum(ratios), 1.0), "The sum of ratios must be equal to 1"

          self._ratios = ratios
          self._datasets = datasets
          self.dataset_samples = []
          for dataset_indx, dataset in enumerate(self._datasets):
-             self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
+             self.dataset_samples.extend(
+                 [(dataset_indx, i) for i in range(len(dataset))]
+             )

      def get_sample(self, index):
          dataset_indx = np.random.choice(len(self._datasets), p=self._ratios)
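
The assertions added to ProportionalComposeDataset above make its contract explicit: one ratio per dataset, and the ratios must sum to 1. A short usage sketch follows; the dataset paths are placeholders and the 70/30 mix is only an example.

from isegm.data.compose import ProportionalComposeDataset
from isegm.data.datasets import CocoLvisDataset, SBDDataset

# Mix two training sources 70/30: each sample draw first picks a dataset
# according to the ratios, then takes an item from the chosen dataset.
coco_lvis = CocoLvisDataset("/data/cocolvis", split="train")
sbd = SBDDataset("/data/sbd", split="train")

mixed = ProportionalComposeDataset(
    datasets=[coco_lvis, sbd],
    ratios=[0.7, 0.3],  # must match len(datasets) and sum to 1.0
)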
isegm/data/datasets/__init__.py CHANGED
@@ -1,12 +1,13 @@
  from isegm.data.compose import ComposeDataset, ProportionalComposeDataset
+
+ from .ade20k import ADE20kDataset
  from .berkeley import BerkeleyDataset
  from .coco import CocoDataset
+ from .coco_lvis import CocoLvisDataset
  from .davis import DavisDataset
  from .grabcut import GrabCutDataset
- from .coco_lvis import CocoLvisDataset
+ from .images_dir import ImagesDirDataset
  from .lvis import LvisDataset
  from .openimages import OpenImagesDataset
- from .sbd import SBDDataset, SBDEvaluationDataset
- from .images_dir import ImagesDirDataset
- from .ade20k import ADE20kDataset
  from .pascalvoc import PascalVocDataset
+ from .sbd import SBDDataset, SBDEvaluationDataset
isegm/data/datasets/ade20k.py CHANGED
@@ -1,6 +1,6 @@
  import os
- import random
  import pickle as pkl
+ import random
  from pathlib import Path

  import cv2
@@ -12,18 +12,18 @@ from isegm.utils.misc import get_labels_with_sizes


  class ADE20kDataset(ISDataset):
-     def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
+     def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
          super().__init__(**kwargs)
-         assert split in {'train', 'val'}
+         assert split in {"train", "val"}

          self.dataset_path = Path(dataset_path)
          self.dataset_split = split
-         self.dataset_split_folder = 'training' if split == 'train' else 'validation'
+         self.dataset_split_folder = "training" if split == "train" else "validation"
          self.stuff_prob = stuff_prob

-         anno_path = self.dataset_path / f'{split}-annotations-object-segmentation.pkl'
+         anno_path = self.dataset_path / f"{split}-annotations-object-segmentation.pkl"
          if os.path.exists(anno_path):
-             with anno_path.open('rb') as f:
+             with anno_path.open("rb") as f:
                  annotations = pkl.load(f)
          else:
              raise RuntimeError(f"Can't find annotations at {anno_path}")
@@ -34,21 +34,23 @@ class ADE20kDataset(ISDataset):
          image_id = self.dataset_samples[index]
          sample_annos = self.annotations[image_id]

-         image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg')
+         image_path = str(self.dataset_path / sample_annos["folder"] / f"{image_id}.jpg")
          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

          # select random mask for an image
-         layer = random.choice(sample_annos['layers'])
-         mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name'])
-         instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0]  # the B channel holds instances
+         layer = random.choice(sample_annos["layers"])
+         mask_path = str(self.dataset_path / sample_annos["folder"] / layer["mask_name"])
+         instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[
+             :, :, 0
+         ]  # the B channel holds instances
          instances_mask = instances_mask.astype(np.int32)
          object_ids, _ = get_labels_with_sizes(instances_mask)

          if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob):
              # remove stuff objects
              for i, object_id in enumerate(object_ids):
-                 if i in layer['stuff_instances']:
+                 if i in layer["stuff_instances"]:
                      instances_mask[instances_mask == object_id] = 0
              object_ids, _ = get_labels_with_sizes(instances_mask)

isegm/data/datasets/berkeley.py CHANGED
@@ -3,4 +3,6 @@ from .grabcut import GrabCutDataset

  class BerkeleyDataset(GrabCutDataset):
      def __init__(self, dataset_path, **kwargs):
-         super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs)
+         super().__init__(
+             dataset_path, images_dir_name="images", masks_dir_name="masks", **kwargs
+         )
isegm/data/datasets/coco.py CHANGED
@@ -1,14 +1,16 @@
- import cv2
  import json
  import random
- import numpy as np
  from pathlib import Path
+
+ import cv2
+ import numpy as np
+
  from isegm.data.base import ISDataset
  from isegm.data.sample import DSample


  class CocoDataset(ISDataset):
-     def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
+     def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
          super(CocoDataset, self).__init__(**kwargs)
          self.split = split
          self.dataset_path = Path(dataset_path)
@@ -17,26 +19,28 @@ class CocoDataset(ISDataset):
          self.load_samples()

      def load_samples(self):
-         annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json'
-         self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}'
+         annotation_path = (
+             self.dataset_path / "annotations" / f"panoptic_{self.split}.json"
+         )
+         self.labels_path = self.dataset_path / "annotations" / f"panoptic_{self.split}"
          self.images_path = self.dataset_path / self.split

-         with open(annotation_path, 'r') as f:
+         with open(annotation_path, "r") as f:
              annotation = json.load(f)

-         self.dataset_samples = annotation['annotations']
+         self.dataset_samples = annotation["annotations"]

-         self._categories = annotation['categories']
-         self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0]
-         self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1]
+         self._categories = annotation["categories"]
+         self._stuff_labels = [x["id"] for x in self._categories if x["isthing"] == 0]
+         self._things_labels = [x["id"] for x in self._categories if x["isthing"] == 1]
          self._things_labels_set = set(self._things_labels)
          self._stuff_labels_set = set(self._stuff_labels)

      def get_sample(self, index) -> DSample:
          dataset_sample = self.dataset_samples[index]

-         image_path = self.images_path / self.get_image_name(dataset_sample['file_name'])
-         label_path = self.labels_path / dataset_sample['file_name']
+         image_path = self.images_path / self.get_image_name(dataset_sample["file_name"])
+         label_path = self.labels_path / dataset_sample["file_name"]

          image = cv2.imread(str(image_path))
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -47,11 +51,11 @@ class CocoDataset(ISDataset):
          things_ids = []
          stuff_ids = []

-         for segment in dataset_sample['segments_info']:
-             class_id = segment['category_id']
-             obj_id = segment['id']
+         for segment in dataset_sample["segments_info"]:
+             class_id = segment["category_id"]
+             obj_id = segment["id"]
              if class_id in self._things_labels_set:
-                 if segment['iscrowd'] == 1:
+                 if segment["iscrowd"] == 1:
                      continue
                  things_ids.append(obj_id)
              else:
@@ -71,4 +75,4 @@ class CocoDataset(ISDataset):

      @classmethod
      def get_image_name(cls, panoptic_name):
-         return panoptic_name.replace('.png', '.jpg')
+         return panoptic_name.replace(".png", ".jpg")
isegm/data/datasets/coco_lvis.py CHANGED
@@ -1,66 +1,78 @@
- from pathlib import Path
+ import json
  import pickle
  import random
- import numpy as np
- import json
- import cv2
  from copy import deepcopy
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
  from isegm.data.base import ISDataset
  from isegm.data.sample import DSample


  class CocoLvisDataset(ISDataset):
-     def __init__(self, dataset_path, split='train', stuff_prob=0.0,
-                  allow_list_name=None, anno_file='hannotation.pickle', **kwargs):
+     def __init__(
+         self,
+         dataset_path,
+         split="train",
+         stuff_prob=0.0,
+         allow_list_name=None,
+         anno_file="hannotation.pickle",
+         **kwargs,
+     ):
          super(CocoLvisDataset, self).__init__(**kwargs)
          dataset_path = Path(dataset_path)
          self._split_path = dataset_path / split
          self.split = split
-         self._images_path = self._split_path / 'images'
-         self._masks_path = self._split_path / 'masks'
+         self._images_path = self._split_path / "images"
+         self._masks_path = self._split_path / "masks"
          self.stuff_prob = stuff_prob

-         with open(self._split_path / anno_file, 'rb') as f:
+         with open(self._split_path / anno_file, "rb") as f:
              self.dataset_samples = sorted(pickle.load(f).items())

          if allow_list_name is not None:
              allow_list_path = self._split_path / allow_list_name
-             with open(allow_list_path, 'r') as f:
+             with open(allow_list_path, "r") as f:
                  allow_images_ids = json.load(f)
                  allow_images_ids = set(allow_images_ids)

-             self.dataset_samples = [sample for sample in self.dataset_samples
-                                     if sample[0] in allow_images_ids]
+             self.dataset_samples = [
+                 sample
+                 for sample in self.dataset_samples
+                 if sample[0] in allow_images_ids
+             ]

      def get_sample(self, index) -> DSample:
          image_id, sample = self.dataset_samples[index]
-         image_path = self._images_path / f'{image_id}.jpg'
+         image_path = self._images_path / f"{image_id}.jpg"

          image = cv2.imread(str(image_path))
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

-         packed_masks_path = self._masks_path / f'{image_id}.pickle'
-         with open(packed_masks_path, 'rb') as f:
+         packed_masks_path = self._masks_path / f"{image_id}.pickle"
+         with open(packed_masks_path, "rb") as f:
              encoded_layers, objs_mapping = pickle.load(f)
          layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers]
          layers = np.stack(layers, axis=2)

-         instances_info = deepcopy(sample['hierarchy'])
+         instances_info = deepcopy(sample["hierarchy"])
          for inst_id, inst_info in list(instances_info.items()):
              if inst_info is None:
-                 inst_info = {'children': [], 'parent': None, 'node_level': 0}
+                 inst_info = {"children": [], "parent": None, "node_level": 0}
                  instances_info[inst_id] = inst_info
-             inst_info['mapping'] = objs_mapping[inst_id]
+             inst_info["mapping"] = objs_mapping[inst_id]

          if self.stuff_prob > 0 and random.random() < self.stuff_prob:
-             for inst_id in range(sample['num_instance_masks'], len(objs_mapping)):
+             for inst_id in range(sample["num_instance_masks"], len(objs_mapping)):
                  instances_info[inst_id] = {
-                     'mapping': objs_mapping[inst_id],
-                     'parent': None,
-                     'children': []
+                     "mapping": objs_mapping[inst_id],
+                     "parent": None,
+                     "children": [],
                  }
          else:
-             for inst_id in range(sample['num_instance_masks'], len(objs_mapping)):
+             for inst_id in range(sample["num_instance_masks"], len(objs_mapping)):
                  layer_indx, mask_id = objs_mapping[inst_id]
                  layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0

isegm/data/datasets/davis.py CHANGED
@@ -8,22 +8,22 @@ from isegm.data.sample import DSample


  class DavisDataset(ISDataset):
-     def __init__(self, dataset_path,
-                  images_dir_name='img', masks_dir_name='gt',
-                  **kwargs):
+     def __init__(
+         self, dataset_path, images_dir_name="img", masks_dir_name="gt", **kwargs
+     ):
          super(DavisDataset, self).__init__(**kwargs)

          self.dataset_path = Path(dataset_path)
          self._images_path = self.dataset_path / images_dir_name
          self._insts_path = self.dataset_path / masks_dir_name

-         self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
-         self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}
+         self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))]
+         self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")}

      def get_sample(self, index) -> DSample:
          image_name = self.dataset_samples[index]
          image_path = str(self._images_path / image_name)
-         mask_path = str(self._masks_paths[image_name.split('.')[0]])
+         mask_path = str(self._masks_paths[image_name.split(".")[0]])

          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
isegm/data/datasets/grabcut.py CHANGED
@@ -8,22 +8,26 @@ from isegm.data.sample import DSample


  class GrabCutDataset(ISDataset):
-     def __init__(self, dataset_path,
-                  images_dir_name='data_GT', masks_dir_name='boundary_GT',
-                  **kwargs):
+     def __init__(
+         self,
+         dataset_path,
+         images_dir_name="data_GT",
+         masks_dir_name="boundary_GT",
+         **kwargs
+     ):
          super(GrabCutDataset, self).__init__(**kwargs)

          self.dataset_path = Path(dataset_path)
          self._images_path = self.dataset_path / images_dir_name
          self._insts_path = self.dataset_path / masks_dir_name

-         self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
-         self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}
+         self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))]
+         self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")}

      def get_sample(self, index) -> DSample:
          image_name = self.dataset_samples[index]
          image_path = str(self._images_path / image_name)
-         mask_path = str(self._masks_paths[image_name.split('.')[0]])
+         mask_path = str(self._masks_paths[image_name.split(".")[0]])

          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -31,4 +35,6 @@ class GrabCutDataset(ISDataset):
          instances_mask[instances_mask == 128] = -1
          instances_mask[instances_mask > 128] = 1

-         return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index)
+         return DSample(
+             image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index
+         )
isegm/data/datasets/images_dir.py CHANGED
@@ -1,49 +1,50 @@
+ from pathlib import Path
+
  import cv2
  import numpy as np
- from pathlib import Path

  from isegm.data.base import ISDataset
  from isegm.data.sample import DSample


  class ImagesDirDataset(ISDataset):
-     def __init__(self, dataset_path,
-                  images_dir_name='images', masks_dir_name='masks',
-                  **kwargs):
+     def __init__(
+         self, dataset_path, images_dir_name="images", masks_dir_name="masks", **kwargs
+     ):
          super(ImagesDirDataset, self).__init__(**kwargs)

          self.dataset_path = Path(dataset_path)
          self._images_path = self.dataset_path / images_dir_name
          self._insts_path = self.dataset_path / masks_dir_name

-         images_list = [x for x in sorted(self._images_path.glob('*.*'))]
+         images_list = [x for x in sorted(self._images_path.glob("*.*"))]

-         samples = {x.stem: {'image': x, 'masks': []} for x in images_list}
-         for mask_path in self._insts_path.glob('*.*'):
+         samples = {x.stem: {"image": x, "masks": []} for x in images_list}
+         for mask_path in self._insts_path.glob("*.*"):
              mask_name = mask_path.stem
              if mask_name in samples:
-                 samples[mask_name]['masks'].append(mask_path)
+                 samples[mask_name]["masks"].append(mask_path)
                  continue

-             mask_name_split = mask_name.split('_')
+             mask_name_split = mask_name.split("_")
              if mask_name_split[-1].isdigit():
-                 mask_name = '_'.join(mask_name_split[:-1])
+                 mask_name = "_".join(mask_name_split[:-1])
                  assert mask_name in samples
-                 samples[mask_name]['masks'].append(mask_path)
+                 samples[mask_name]["masks"].append(mask_path)

          for x in samples.values():
-             assert len(x['masks']) > 0, x['image']
+             assert len(x["masks"]) > 0, x["image"]

          self.dataset_samples = [v for k, v in sorted(samples.items())]

      def get_sample(self, index) -> DSample:
          sample = self.dataset_samples[index]
-         image_path = str(sample['image'])
+         image_path = str(sample["image"])

          objects = []
          ignored_regions = []
          masks = []
-         for indx, mask_path in enumerate(sample['masks']):
+         for indx, mask_path in enumerate(sample["masks"]):
              gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32)
              instances_mask = np.zeros_like(gt_mask)
              instances_mask[gt_mask == 128] = 2
@@ -55,5 +56,10 @@ class ImagesDirDataset(ISDataset):
          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

-         return DSample(image, np.stack(masks, axis=2),
-                        objects_ids=objects, ignore_ids=ignored_regions, sample_id=index)
+         return DSample(
+             image,
+             np.stack(masks, axis=2),
+             objects_ids=objects,
+             ignore_ids=ignored_regions,
+             sample_id=index,
+         )
isegm/data/datasets/lvis.py CHANGED
@@ -11,42 +11,41 @@ from isegm.data.sample import DSample


  class LvisDataset(ISDataset):
-     def __init__(self, dataset_path, split='train',
-                  max_overlap_ratio=0.5,
-                  **kwargs):
+     def __init__(self, dataset_path, split="train", max_overlap_ratio=0.5, **kwargs):
          super(LvisDataset, self).__init__(**kwargs)
          dataset_path = Path(dataset_path)
-         train_categories_path = dataset_path / 'train_categories.json'
-         self._train_path = dataset_path / 'train'
-         self._val_path = dataset_path / 'val'
+         train_categories_path = dataset_path / "train_categories.json"
+         self._train_path = dataset_path / "train"
+         self._val_path = dataset_path / "val"

          self.split = split
          self.max_overlap_ratio = max_overlap_ratio

-         with open( dataset_path / split / f'lvis_{self.split}.json', 'r') as f:
+         with open(dataset_path / split / f"lvis_{self.split}.json", "r") as f:
              json_annotation = json.loads(f.read())

          self.annotations = defaultdict(list)
-         for x in json_annotation['annotations']:
-             self.annotations[x['image_id']].append(x)
+         for x in json_annotation["annotations"]:
+             self.annotations[x["image_id"]].append(x)

          if not train_categories_path.exists():
              self.generate_train_categories(dataset_path, train_categories_path)
-         self.dataset_samples = [x for x in json_annotation['images']
-                                 if len(self.annotations[x['id']]) > 0]
+         self.dataset_samples = [
+             x for x in json_annotation["images"] if len(self.annotations[x["id"]]) > 0
+         ]

      def get_sample(self, index) -> DSample:
          image_info = self.dataset_samples[index]
-         image_id, image_url = image_info['id'], image_info['coco_url']
-         image_filename = image_url.split('/')[-1]
+         image_id, image_url = image_info["id"], image_info["coco_url"]
+         image_filename = image_url.split("/")[-1]
          image_annotations = self.annotations[image_id]
          random.shuffle(image_annotations)

          # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017)
-         if 'train2017' in image_url:
-             image_path = self._train_path / 'images' / image_filename
+         if "train2017" in image_url:
+             image_path = self._train_path / "images" / image_filename
          else:
-             image_path = self._val_path / 'images' / image_filename
+             image_path = self._val_path / "images" / image_filename
          image = cv2.imread(str(image_path))
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

@@ -62,9 +61,14 @@ class LvisDataset(ISDataset):
          instances_mask = np.zeros_like(object_mask, dtype=np.int32)

          overlap_ids = np.bincount(instances_mask[object_mask].flatten())
-         overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids)
-                          if overlap_area > 0 and inst_id > 0]
-         overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area
+         overlap_areas = [
+             overlap_area / instances_area[inst_id]
+             for inst_id, overlap_area in enumerate(overlap_ids)
+             if overlap_area > 0 and inst_id > 0
+         ]
+         overlap_ratio = (
+             np.logical_and(object_mask, instances_mask > 0).sum() / object_area
+         )
          if overlap_areas:
              overlap_ratio = max(overlap_ratio, max(overlap_areas))
          if overlap_ratio > self.max_overlap_ratio:
@@ -77,11 +81,10 @@ class LvisDataset(ISDataset):

          return DSample(image, instances_mask, objects_ids=objects_ids)

-
      @staticmethod
      def get_mask_from_polygon(annotation, image):
          mask = np.zeros(image.shape[:2], dtype=np.int32)
-         for contour_points in annotation['segmentation']:
+         for contour_points in annotation["segmentation"]:
              contour_points = np.array(contour_points).reshape((-1, 2))
              contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :]
              cv2.fillPoly(mask, contour_points, 1)
@@ -90,8 +93,8 @@ class LvisDataset(ISDataset):

      @staticmethod
      def generate_train_categories(dataset_path, train_categories_path):
-         with open(dataset_path / 'train/lvis_train.json', 'r') as f:
+         with open(dataset_path / "train/lvis_train.json", "r") as f:
              annotation = json.load(f)

-         with open(train_categories_path, 'w') as f:
-             json.dump(annotation['categories'], f, indent=1)
+         with open(train_categories_path, "w") as f:
+             json.dump(annotation["categories"], f, indent=1)
isegm/data/datasets/openimages.py CHANGED
@@ -1,6 +1,6 @@
  import os
- import random
  import pickle as pkl
+ import random
  from pathlib import Path

  import cv2
@@ -11,29 +11,31 @@ from isegm.data.sample import DSample


  class OpenImagesDataset(ISDataset):
-     def __init__(self, dataset_path, split='train', **kwargs):
+     def __init__(self, dataset_path, split="train", **kwargs):
          super().__init__(**kwargs)
-         assert split in {'train', 'val', 'test'}
+         assert split in {"train", "val", "test"}

          self.dataset_path = Path(dataset_path)
          self._split_path = self.dataset_path / split
-         self._images_path = self._split_path / 'images'
-         self._masks_path = self._split_path / 'masks'
+         self._images_path = self._split_path / "images"
+         self._masks_path = self._split_path / "masks"
          self.dataset_split = split

-         clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl'
+         clean_anno_path = (
+             self._split_path / f"{split}-annotations-object-segmentation_clean.pkl"
+         )
          if os.path.exists(clean_anno_path):
-             with clean_anno_path.open('rb') as f:
+             with clean_anno_path.open("rb") as f:
                  annotations = pkl.load(f)
          else:
              raise RuntimeError(f"Can't find annotations at {clean_anno_path}")
-         self.image_id_to_masks = annotations['image_id_to_masks']
-         self.dataset_samples = annotations['dataset_samples']
+         self.image_id_to_masks = annotations["image_id_to_masks"]
+         self.dataset_samples = annotations["dataset_samples"]

      def get_sample(self, index) -> DSample:
          image_id = self.dataset_samples[index]

-         image_path = str(self._images_path / f'{image_id}.jpg')
+         image_path = str(self._images_path / f"{image_id}.jpg")
          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

@@ -49,9 +51,16 @@ class OpenImagesDataset(ISDataset):
          min_height = min(image.shape[0], instances_mask.shape[0])

          if image.shape[0] != min_height or image.shape[1] != min_width:
-             image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR)
-         if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width:
-             instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST)
+             image = cv2.resize(
+                 image, (min_width, min_height), interpolation=cv2.INTER_LINEAR
+             )
+         if (
+             instances_mask.shape[0] != min_height
+             or instances_mask.shape[1] != min_width
+         ):
+             instances_mask = cv2.resize(
+                 instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST
+             )

          object_ids = [1] if instances_mask.sum() > 0 else []

isegm/data/datasets/pascalvoc.py CHANGED
@@ -9,32 +9,38 @@ from isegm.data.sample import DSample


  class PascalVocDataset(ISDataset):
-     def __init__(self, dataset_path, split='train', **kwargs):
+     def __init__(self, dataset_path, split="train", **kwargs):
          super().__init__(**kwargs)
-         assert split in {'train', 'val', 'trainval', 'test'}
+         assert split in {"train", "val", "trainval", "test"}

          self.dataset_path = Path(dataset_path)
          self._images_path = self.dataset_path / "JPEGImages"
          self._insts_path = self.dataset_path / "SegmentationObject"
          self.dataset_split = split

-         if split == 'test':
-             with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f:
+         if split == "test":
+             with open(
+                 self.dataset_path / f"ImageSets/Segmentation/test.pickle", "rb"
+             ) as f:
                  self.dataset_samples, self.instance_ids = pkl.load(f)
          else:
-             with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f:
+             with open(
+                 self.dataset_path / f"ImageSets/Segmentation/{split}.txt", "r"
+             ) as f:
                  self.dataset_samples = [name.strip() for name in f.readlines()]

      def get_sample(self, index) -> DSample:
          sample_id = self.dataset_samples[index]
-         image_path = str(self._images_path / f'{sample_id}.jpg')
-         mask_path = str(self._insts_path / f'{sample_id}.png')
+         image_path = str(self._images_path / f"{sample_id}.jpg")
+         mask_path = str(self._insts_path / f"{sample_id}.png")

          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
          instances_mask = cv2.imread(mask_path)
-         instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
-         if self.dataset_split == 'test':
+         instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(
+             np.int32
+         )
+         if self.dataset_split == "test":
              instance_id = self.instance_ids[index]
              mask = np.zeros_like(instances_mask)
              mask[instances_mask == 220] = 220  # ignored area
@@ -45,4 +51,10 @@ class PascalVocDataset(ISDataset):
          objects_ids = np.unique(instances_mask)
          objects_ids = [x for x in objects_ids if x != 0 and x != 220]

-         return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index)
+         return DSample(
+             image,
+             instances_mask,
+             objects_ids=objects_ids,
+             ignore_ids=[220],
+             sample_id=index,
+         )
isegm/data/datasets/sbd.py CHANGED
@@ -5,38 +5,42 @@ import cv2
  import numpy as np
  from scipy.io import loadmat

- from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
  from isegm.data.base import ISDataset
  from isegm.data.sample import DSample
+ from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes


  class SBDDataset(ISDataset):
-     def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs):
+     def __init__(self, dataset_path, split="train", buggy_mask_thresh=0.08, **kwargs):
          super(SBDDataset, self).__init__(**kwargs)
-         assert split in {'train', 'val'}
+         assert split in {"train", "val"}

          self.dataset_path = Path(dataset_path)
          self.dataset_split = split
-         self._images_path = self.dataset_path / 'img'
-         self._insts_path = self.dataset_path / 'inst'
+         self._images_path = self.dataset_path / "img"
+         self._insts_path = self.dataset_path / "inst"
          self._buggy_objects = dict()
          self._buggy_mask_thresh = buggy_mask_thresh

-         with open(self.dataset_path / f'{split}.txt', 'r') as f:
+         with open(self.dataset_path / f"{split}.txt", "r") as f:
              self.dataset_samples = [x.strip() for x in f.readlines()]

      def get_sample(self, index):
          image_name = self.dataset_samples[index]
-         image_path = str(self._images_path / f'{image_name}.jpg')
-         inst_info_path = str(self._insts_path / f'{image_name}.mat')
+         image_path = str(self._images_path / f"{image_name}.jpg")
+         inst_info_path = str(self._insts_path / f"{image_name}.mat")

          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-         instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
+         instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(
+             np.int32
+         )
          instances_mask = self.remove_buggy_masks(index, instances_mask)
          instances_ids, _ = get_labels_with_sizes(instances_mask)

-         return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index)
+         return DSample(
+             image, instances_mask, objects_ids=instances_ids, sample_id=index
+         )

      def remove_buggy_masks(self, index, instances_mask):
          if self._buggy_mask_thresh > 0.0:
@@ -61,51 +65,55 @@ class SBDDataset(ISDataset):


  class SBDEvaluationDataset(ISDataset):
-     def __init__(self, dataset_path, split='val', **kwargs):
+     def __init__(self, dataset_path, split="val", **kwargs):
          super(SBDEvaluationDataset, self).__init__(**kwargs)
-         assert split in {'train', 'val'}
+         assert split in {"train", "val"}

          self.dataset_path = Path(dataset_path)
          self.dataset_split = split
-         self._images_path = self.dataset_path / 'img'
-         self._insts_path = self.dataset_path / 'inst'
+         self._images_path = self.dataset_path / "img"
+         self._insts_path = self.dataset_path / "inst"

-         with open(self.dataset_path / f'{split}.txt', 'r') as f:
+         with open(self.dataset_path / f"{split}.txt", "r") as f:
              self.dataset_samples = [x.strip() for x in f.readlines()]

          self.dataset_samples = self.get_sbd_images_and_ids_list()

      def get_sample(self, index) -> DSample:
          image_name, instance_id = self.dataset_samples[index]
-         image_path = str(self._images_path / f'{image_name}.jpg')
-         inst_info_path = str(self._insts_path / f'{image_name}.mat')
+         image_path = str(self._images_path / f"{image_name}.jpg")
+         inst_info_path = str(self._insts_path / f"{image_name}.mat")

          image = cv2.imread(image_path)
          image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-         instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
+         instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(
+             np.int32
+         )
          instances_mask[instances_mask != instance_id] = 0
          instances_mask[instances_mask > 0] = 1

          return DSample(image, instances_mask, objects_ids=[1], sample_id=index)

      def get_sbd_images_and_ids_list(self):
-         pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl'
+         pkl_path = self.dataset_path / f"{self.dataset_split}_images_and_ids_list.pkl"

          if pkl_path.exists():
-             with open(str(pkl_path), 'rb') as fp:
+             with open(str(pkl_path), "rb") as fp:
                  images_and_ids_list = pkl.load(fp)
          else:
              images_and_ids_list = []

              for sample in self.dataset_samples:
-                 inst_info_path = str(self._insts_path / f'{sample}.mat')
-                 instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
                  instances_ids, _ = get_labels_with_sizes(instances_mask)

                  for instances_id in instances_ids:
                      images_and_ids_list.append((sample, instances_id))

-             with open(str(pkl_path), 'wb') as fp:
                  pkl.dump(images_and_ids_list, fp)

          return images_and_ids_list
105
 
106
  for sample in self.dataset_samples:
107
+ inst_info_path = str(self._insts_path / f"{sample}.mat")
108
+ instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(
109
+ np.int32
110
+ )
111
  instances_ids, _ = get_labels_with_sizes(instances_mask)
112
 
113
  for instances_id in instances_ids:
114
  images_and_ids_list.append((sample, instances_id))
115
 
116
+ with open(str(pkl_path), "wb") as fp:
117
  pkl.dump(images_and_ids_list, fp)
118
 
119
  return images_and_ids_list
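For context, a minimal usage sketch of the two SBD dataset classes refactored above. The local path and the points_sampler keyword (forwarded to ISDataset through **kwargs) are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch; paths and sampler settings are assumptions.
from isegm.data.datasets.sbd import SBDDataset, SBDEvaluationDataset
from isegm.data.points_sampler import MultiPointSampler

trainset = SBDDataset(
    "./datasets/SBD",                  # assumed layout: img/, inst/, train.txt, val.txt
    split="train",
    points_sampler=MultiPointSampler(max_num_points=24),
)
sample = trainset.get_sample(0)        # DSample with one id per instance in the .mat file

valset = SBDEvaluationDataset("./datasets/SBD", split="val")
print(len(valset.dataset_samples))     # one (image_name, instance_id) pair per object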
isegm/data/points_sampler.py CHANGED
@@ -1,8 +1,10 @@
1
- import cv2
2
  import math
3
  import random
4
- import numpy as np
5
  from functools import lru_cache
 
 
 
 
6
  from .sample import DSample
7
 
8
 
@@ -28,13 +30,25 @@ class BasePointSampler:
28
 
29
 
30
  class MultiPointSampler(BasePointSampler):
31
- def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1,
32
- positive_erode_prob=0.9, positive_erode_iters=3,
33
- negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5,
34
- merge_objects_prob=0.0, max_num_merged_objects=2,
35
- use_hierarchy=False, soft_targets=False,
36
- first_click_center=False, only_one_first_click=False,
37
- sfc_inner_k=1.7, sfc_full_inner_prob=0.0):
38
  super().__init__()
39
  self.max_num_points = max_num_points
40
  self.expand_ratio = expand_ratio
@@ -52,8 +66,12 @@ class MultiPointSampler(BasePointSampler):
52
  max_num_merged_objects = max_num_points
53
  self.max_num_merged_objects = max_num_merged_objects
54
 
55
- self.neg_strategies = ['bg', 'other', 'border']
56
- self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob]
 
 
 
 
57
  assert math.isclose(sum(self.neg_strategies_prob), 1.0)
58
 
59
  self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
@@ -66,7 +84,7 @@ class MultiPointSampler(BasePointSampler):
66
  self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
67
  self._selected_masks = [[]]
68
  self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
69
- self._neg_masks['required'] = []
70
  return
71
 
72
  gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
@@ -80,14 +98,16 @@ class MultiPointSampler(BasePointSampler):
80
  if len(sample) <= len(self._selected_masks):
81
  neg_mask_other = neg_mask_bg
82
  else:
83
- neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()),
84
- np.logical_not(binary_gt_mask))
 
 
85
 
86
  self._neg_masks = {
87
- 'bg': neg_mask_bg,
88
- 'other': neg_mask_other,
89
- 'border': neg_mask_border,
90
- 'required': neg_masks
91
  }
92
 
93
  def _sample_mask(self, sample: DSample):
@@ -104,7 +124,11 @@ class MultiPointSampler(BasePointSampler):
104
  pos_segments = []
105
  neg_segments = []
106
  for obj_id in random_ids:
107
- obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample)
 
 
 
 
108
  if gt_mask is None:
109
  gt_mask = obj_gt_mask
110
  else:
@@ -123,35 +147,45 @@ class MultiPointSampler(BasePointSampler):
123
 
124
  if not self.use_hierarchy:
125
  node_mask = sample.get_object_mask(obj_id)
126
- gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
 
 
127
  return gt_mask, [node_mask], []
128
 
129
  def _select_node(node_id):
130
  node_info = objs_tree[node_id]
131
- if not node_info['children'] or random.random() < 0.5:
132
  return node_id
133
- return _select_node(random.choice(node_info['children']))
134
 
135
  selected_node = _select_node(obj_id)
136
  node_info = objs_tree[selected_node]
137
  node_mask = sample.get_object_mask(selected_node)
138
- gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask
 
 
 
 
139
  pos_mask = node_mask.copy()
140
 
141
  negative_segments = []
142
- if node_info['parent'] is not None and node_info['parent'] in objs_tree:
143
- parent_mask = sample.get_object_mask(node_info['parent'])
144
- negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask)))
145
-
146
- for child_id in node_info['children']:
147
- if objs_tree[child_id]['area'] / node_info['area'] < 0.10:
 
 
148
  child_mask = sample.get_object_mask(child_id)
149
  pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
150
 
151
- if node_info['children']:
152
- max_disabled_children = min(len(node_info['children']), 3)
153
  num_disabled_children = np.random.randint(0, max_disabled_children + 1)
154
- disabled_children = random.sample(node_info['children'], num_disabled_children)
 
 
155
 
156
  for child_id in disabled_children:
157
  child_mask = sample.get_object_mask(child_id)
@@ -167,24 +201,32 @@ class MultiPointSampler(BasePointSampler):
167
 
168
  def sample_points(self):
169
  assert self._selected_mask is not None
170
- pos_points = self._multi_mask_sample_points(self._selected_masks,
171
- is_negative=[False] * len(self._selected_masks),
172
- with_first_click=self.first_click_center)
173
-
174
- neg_strategy = [(self._neg_masks[k], prob)
175
- for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)]
176
- neg_masks = self._neg_masks['required'] + [neg_strategy]
177
- neg_points = self._multi_mask_sample_points(neg_masks,
178
- is_negative=[False] * len(self._neg_masks['required']) + [True])
 
 
 
 
 
179
 
180
  return pos_points + neg_points
181
 
182
- def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False):
183
- selected_masks = selected_masks[:self.max_num_points]
 
 
184
 
185
  each_obj_points = [
186
- self._sample_points(mask, is_negative=is_negative[i],
187
- with_first_click=with_first_click)
 
188
  for i, mask in enumerate(selected_masks)
189
  ]
190
  each_obj_points = [x for x in each_obj_points if len(x) > 0]
@@ -200,17 +242,27 @@ class MultiPointSampler(BasePointSampler):
200
 
201
  aggregated_masks_with_prob = []
202
  for indx, x in enumerate(selected_masks):
203
- if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)):
 
 
 
 
204
  for t, prob in x:
205
- aggregated_masks_with_prob.append((t, prob / len(selected_masks)))
 
 
206
  else:
207
  aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
208
 
209
- other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True)
 
 
210
  if len(other_points_union) + len(points) <= self.max_num_points:
211
  points.extend(other_points_union)
212
  else:
213
- points.extend(random.sample(other_points_union, self.max_num_points - len(points)))
 
 
214
 
215
  if len(points) < self.max_num_points:
216
  points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
@@ -219,9 +271,13 @@ class MultiPointSampler(BasePointSampler):
219
 
220
  def _sample_points(self, mask, is_negative=False, with_first_click=False):
221
  if is_negative:
222
- num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs)
 
 
223
  else:
224
- num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs)
 
 
225
 
226
  indices_probs = None
227
  if isinstance(mask, (list, tuple)):
@@ -237,9 +293,13 @@ class MultiPointSampler(BasePointSampler):
237
  first_click = with_first_click and j == 0 and indices_probs is None
238
 
239
  if first_click:
240
- point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob)
 
 
241
  elif indices_probs:
242
- point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs)
 
 
243
  point_indices = indices[point_indices_indx][0]
244
  else:
245
  point_indices = indices
@@ -247,7 +307,9 @@ class MultiPointSampler(BasePointSampler):
247
  num_indices = len(point_indices)
248
  if num_indices > 0:
249
  point_indx = 0 if first_click else 100
250
- click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx]
 
 
251
  points.append(click)
252
 
253
  return points
@@ -257,8 +319,9 @@ class MultiPointSampler(BasePointSampler):
257
  return mask
258
 
259
  kernel = np.ones((3, 3), np.uint8)
260
- eroded_mask = cv2.erode(mask.astype(np.uint8),
261
- kernel, iterations=self.positive_erode_iters).astype(np.bool)
 
262
 
263
  if eroded_mask.sum() > 10:
264
  return eroded_mask
@@ -291,7 +354,7 @@ def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
291
  if full_prob > 0 and random.random() < full_prob:
292
  return obj_mask
293
 
294
- padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant')
295
 
296
  dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
297
  if k > 0:
 
 
1
  import math
2
  import random
 
3
  from functools import lru_cache
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
  from .sample import DSample
9
 
10
 
 
30
 
31
 
32
  class MultiPointSampler(BasePointSampler):
33
+ def __init__(
34
+ self,
35
+ max_num_points,
36
+ prob_gamma=0.7,
37
+ expand_ratio=0.1,
38
+ positive_erode_prob=0.9,
39
+ positive_erode_iters=3,
40
+ negative_bg_prob=0.1,
41
+ negative_other_prob=0.4,
42
+ negative_border_prob=0.5,
43
+ merge_objects_prob=0.0,
44
+ max_num_merged_objects=2,
45
+ use_hierarchy=False,
46
+ soft_targets=False,
47
+ first_click_center=False,
48
+ only_one_first_click=False,
49
+ sfc_inner_k=1.7,
50
+ sfc_full_inner_prob=0.0,
51
+ ):
52
  super().__init__()
53
  self.max_num_points = max_num_points
54
  self.expand_ratio = expand_ratio
 
66
  max_num_merged_objects = max_num_points
67
  self.max_num_merged_objects = max_num_merged_objects
68
 
69
+ self.neg_strategies = ["bg", "other", "border"]
70
+ self.neg_strategies_prob = [
71
+ negative_bg_prob,
72
+ negative_other_prob,
73
+ negative_border_prob,
74
+ ]
75
  assert math.isclose(sum(self.neg_strategies_prob), 1.0)
76
 
77
  self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
 
84
  self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
85
  self._selected_masks = [[]]
86
  self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
87
+ self._neg_masks["required"] = []
88
  return
89
 
90
  gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
 
98
  if len(sample) <= len(self._selected_masks):
99
  neg_mask_other = neg_mask_bg
100
  else:
101
+ neg_mask_other = np.logical_and(
102
+ np.logical_not(sample.get_background_mask()),
103
+ np.logical_not(binary_gt_mask),
104
+ )
105
 
106
  self._neg_masks = {
107
+ "bg": neg_mask_bg,
108
+ "other": neg_mask_other,
109
+ "border": neg_mask_border,
110
+ "required": neg_masks,
111
  }
112
 
113
  def _sample_mask(self, sample: DSample):
 
124
  pos_segments = []
125
  neg_segments = []
126
  for obj_id in random_ids:
127
+ (
128
+ obj_gt_mask,
129
+ obj_pos_segments,
130
+ obj_neg_segments,
131
+ ) = self._sample_from_masks_layer(obj_id, sample)
132
  if gt_mask is None:
133
  gt_mask = obj_gt_mask
134
  else:
 
147
 
148
  if not self.use_hierarchy:
149
  node_mask = sample.get_object_mask(obj_id)
150
+ gt_mask = (
151
+ sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
152
+ )
153
  return gt_mask, [node_mask], []
154
 
155
  def _select_node(node_id):
156
  node_info = objs_tree[node_id]
157
+ if not node_info["children"] or random.random() < 0.5:
158
  return node_id
159
+ return _select_node(random.choice(node_info["children"]))
160
 
161
  selected_node = _select_node(obj_id)
162
  node_info = objs_tree[selected_node]
163
  node_mask = sample.get_object_mask(selected_node)
164
+ gt_mask = (
165
+ sample.get_soft_object_mask(selected_node)
166
+ if self.soft_targets
167
+ else node_mask
168
+ )
169
  pos_mask = node_mask.copy()
170
 
171
  negative_segments = []
172
+ if node_info["parent"] is not None and node_info["parent"] in objs_tree:
173
+ parent_mask = sample.get_object_mask(node_info["parent"])
174
+ negative_segments.append(
175
+ np.logical_and(parent_mask, np.logical_not(node_mask))
176
+ )
177
+
178
+ for child_id in node_info["children"]:
179
+ if objs_tree[child_id]["area"] / node_info["area"] < 0.10:
180
  child_mask = sample.get_object_mask(child_id)
181
  pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
182
 
183
+ if node_info["children"]:
184
+ max_disabled_children = min(len(node_info["children"]), 3)
185
  num_disabled_children = np.random.randint(0, max_disabled_children + 1)
186
+ disabled_children = random.sample(
187
+ node_info["children"], num_disabled_children
188
+ )
189
 
190
  for child_id in disabled_children:
191
  child_mask = sample.get_object_mask(child_id)
 
201
 
202
  def sample_points(self):
203
  assert self._selected_mask is not None
204
+ pos_points = self._multi_mask_sample_points(
205
+ self._selected_masks,
206
+ is_negative=[False] * len(self._selected_masks),
207
+ with_first_click=self.first_click_center,
208
+ )
209
+
210
+ neg_strategy = [
211
+ (self._neg_masks[k], prob)
212
+ for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)
213
+ ]
214
+ neg_masks = self._neg_masks["required"] + [neg_strategy]
215
+ neg_points = self._multi_mask_sample_points(
216
+ neg_masks, is_negative=[False] * len(self._neg_masks["required"]) + [True]
217
+ )
218
 
219
  return pos_points + neg_points
220
 
221
+ def _multi_mask_sample_points(
222
+ self, selected_masks, is_negative, with_first_click=False
223
+ ):
224
+ selected_masks = selected_masks[: self.max_num_points]
225
 
226
  each_obj_points = [
227
+ self._sample_points(
228
+ mask, is_negative=is_negative[i], with_first_click=with_first_click
229
+ )
230
  for i, mask in enumerate(selected_masks)
231
  ]
232
  each_obj_points = [x for x in each_obj_points if len(x) > 0]
 
242
 
243
  aggregated_masks_with_prob = []
244
  for indx, x in enumerate(selected_masks):
245
+ if (
246
+ isinstance(x, (list, tuple))
247
+ and x
248
+ and isinstance(x[0], (list, tuple))
249
+ ):
250
  for t, prob in x:
251
+ aggregated_masks_with_prob.append(
252
+ (t, prob / len(selected_masks))
253
+ )
254
  else:
255
  aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
256
 
257
+ other_points_union = self._sample_points(
258
+ aggregated_masks_with_prob, is_negative=True
259
+ )
260
  if len(other_points_union) + len(points) <= self.max_num_points:
261
  points.extend(other_points_union)
262
  else:
263
+ points.extend(
264
+ random.sample(other_points_union, self.max_num_points - len(points))
265
+ )
266
 
267
  if len(points) < self.max_num_points:
268
  points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
 
271
 
272
  def _sample_points(self, mask, is_negative=False, with_first_click=False):
273
  if is_negative:
274
+ num_points = np.random.choice(
275
+ np.arange(self.max_num_points + 1), p=self._neg_probs
276
+ )
277
  else:
278
+ num_points = 1 + np.random.choice(
279
+ np.arange(self.max_num_points), p=self._pos_probs
280
+ )
281
 
282
  indices_probs = None
283
  if isinstance(mask, (list, tuple)):
 
293
  first_click = with_first_click and j == 0 and indices_probs is None
294
 
295
  if first_click:
296
+ point_indices = get_point_candidates(
297
+ mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob
298
+ )
299
  elif indices_probs:
300
+ point_indices_indx = np.random.choice(
301
+ np.arange(len(indices)), p=indices_probs
302
+ )
303
  point_indices = indices[point_indices_indx][0]
304
  else:
305
  point_indices = indices
 
307
  num_indices = len(point_indices)
308
  if num_indices > 0:
309
  point_indx = 0 if first_click else 100
310
+ click = point_indices[np.random.randint(0, num_indices)].tolist() + [
311
+ point_indx
312
+ ]
313
  points.append(click)
314
 
315
  return points
 
319
  return mask
320
 
321
  kernel = np.ones((3, 3), np.uint8)
322
+ eroded_mask = cv2.erode(
323
+ mask.astype(np.uint8), kernel, iterations=self.positive_erode_iters
324
+ ).astype(bool)  # np.bool alias removed in newer NumPy
325
 
326
  if eroded_mask.sum() > 10:
327
  return eroded_mask
 
354
  if full_prob > 0 and random.random() < full_prob:
355
  return obj_mask
356
 
357
+ padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), "constant")
358
 
359
  dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
360
  if k > 0:
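To make the click-sampling contract above concrete, a small sketch that feeds the sampler a toy DSample; the sample_object call name follows the BasePointSampler interface in this repo and should be treated as an assumption here.

# Sketch of the sampler contract; the toy sample mirrors what the datasets build.
import numpy as np
from isegm.data.points_sampler import MultiPointSampler
from isegm.data.sample import DSample

mask = np.zeros((64, 64), dtype=np.int32)
mask[16:48, 16:48] = 1                         # one square object with id 1
dsample = DSample(np.zeros((64, 64, 3), np.uint8), mask, objects_ids=[1], sample_id=0)

sampler = MultiPointSampler(max_num_points=24, first_click_center=True)
sampler.sample_object(dsample)                 # picks object(s) and negative regions
clicks = sampler.sample_points()               # [row, col, order] triples, positives then
                                               # negatives, padded with (-1, -1, -1)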
isegm/data/sample.py CHANGED
@@ -1,13 +1,22 @@
1
- import numpy as np
2
  from copy import deepcopy
3
- from isegm.utils.misc import get_labels_with_sizes
4
- from isegm.data.transforms import remove_image_only_transforms
5
  from albumentations import ReplayCompose
6
 
 
 
 
7
 
8
  class DSample:
9
- def __init__(self, image, encoded_masks, objects=None,
10
- objects_ids=None, ignore_ids=None, sample_id=None):
 
 
 
 
 
 
 
11
  self.image = image
12
  self.sample_id = sample_id
13
 
@@ -24,9 +33,9 @@ class DSample:
24
  self._objects = dict()
25
  for indx, obj_mapping in enumerate(objects_ids):
26
  self._objects[indx] = {
27
- 'parent': None,
28
- 'mapping': obj_mapping,
29
- 'children': []
30
  }
31
 
32
  if ignore_ids:
@@ -44,10 +53,10 @@ class DSample:
44
  def augment(self, augmentator):
45
  self.reset_augmentation()
46
  aug_output = augmentator(image=self.image, mask=self._encoded_masks)
47
- self.image = aug_output['image']
48
- self._encoded_masks = aug_output['mask']
49
 
50
- aug_replay = aug_output.get('replay', None)
51
  if aug_replay:
52
  assert len(self._ignored_regions) == 0
53
  mask_replay = remove_image_only_transforms(aug_replay)
@@ -69,15 +78,15 @@ class DSample:
69
  self._soft_mask_aug = None
70
 
71
  def remove_small_objects(self, min_area):
72
- if self._objects and not 'area' in list(self._objects.values())[0]:
73
  self._compute_objects_areas()
74
 
75
  for obj_id, obj_info in list(self._objects.items()):
76
- if obj_info['area'] < min_area:
77
  self._remove_object(obj_id)
78
 
79
  def get_object_mask(self, obj_id):
80
- layer_indx, mask_id = self._objects[obj_id]['mapping']
81
  obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
82
  if self._ignored_regions:
83
  for layer_indx, mask_id in self._ignored_regions:
@@ -89,9 +98,13 @@ class DSample:
89
  def get_soft_object_mask(self, obj_id):
90
  assert self._soft_mask_aug is not None
91
  original_encoded_masks = self._original_data[1]
92
- layer_indx, mask_id = self._objects[obj_id]['mapping']
93
- obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32)
94
- obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image']
 
 
 
 
95
  return np.clip(obj_mask, 0, 1)
96
 
97
  def get_background_mask(self):
@@ -108,20 +121,28 @@ class DSample:
108
 
109
  @property
110
  def root_objects(self):
111
- return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None]
 
 
 
 
112
 
113
  def _compute_objects_areas(self):
114
- inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()}
 
 
115
  ignored_regions_keys = set(self._ignored_regions)
116
 
117
  for layer_indx in range(self._encoded_masks.shape[2]):
118
- objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx])
 
 
119
  for obj_id, obj_area in zip(objects_ids, objects_areas):
120
  inv_key = (layer_indx, obj_id)
121
  if inv_key in ignored_regions_keys:
122
  continue
123
  try:
124
- self._objects[inverse_index[inv_key]]['area'] = obj_area
125
  del inverse_index[inv_key]
126
  except KeyError:
127
  layer = self._encoded_masks[:, :, layer_indx]
@@ -129,18 +150,20 @@ class DSample:
129
  self._encoded_masks[:, :, layer_indx] = layer
130
 
131
  for obj_id in inverse_index.values():
132
- self._objects[obj_id]['area'] = 0
133
 
134
  def _remove_object(self, obj_id):
135
  obj_info = self._objects[obj_id]
136
- obj_parent = obj_info['parent']
137
- for child_id in obj_info['children']:
138
- self._objects[child_id]['parent'] = obj_parent
139
 
140
  if obj_parent is not None:
141
- parent_children = self._objects[obj_parent]['children']
142
  parent_children = [x for x in parent_children if x != obj_id]
143
- self._objects[obj_parent]['children'] = parent_children + obj_info['children']
 
 
144
 
145
  del self._objects[obj_id]
146
 
 
 
1
  from copy import deepcopy
2
+
3
+ import numpy as np
4
  from albumentations import ReplayCompose
5
 
6
+ from isegm.data.transforms import remove_image_only_transforms
7
+ from isegm.utils.misc import get_labels_with_sizes
8
+
9
 
10
  class DSample:
11
+ def __init__(
12
+ self,
13
+ image,
14
+ encoded_masks,
15
+ objects=None,
16
+ objects_ids=None,
17
+ ignore_ids=None,
18
+ sample_id=None,
19
+ ):
20
  self.image = image
21
  self.sample_id = sample_id
22
 
 
33
  self._objects = dict()
34
  for indx, obj_mapping in enumerate(objects_ids):
35
  self._objects[indx] = {
36
+ "parent": None,
37
+ "mapping": obj_mapping,
38
+ "children": [],
39
  }
40
 
41
  if ignore_ids:
 
53
  def augment(self, augmentator):
54
  self.reset_augmentation()
55
  aug_output = augmentator(image=self.image, mask=self._encoded_masks)
56
+ self.image = aug_output["image"]
57
+ self._encoded_masks = aug_output["mask"]
58
 
59
+ aug_replay = aug_output.get("replay", None)
60
  if aug_replay:
61
  assert len(self._ignored_regions) == 0
62
  mask_replay = remove_image_only_transforms(aug_replay)
 
78
  self._soft_mask_aug = None
79
 
80
  def remove_small_objects(self, min_area):
81
+ if self._objects and "area" not in list(self._objects.values())[0]:
82
  self._compute_objects_areas()
83
 
84
  for obj_id, obj_info in list(self._objects.items()):
85
+ if obj_info["area"] < min_area:
86
  self._remove_object(obj_id)
87
 
88
  def get_object_mask(self, obj_id):
89
+ layer_indx, mask_id = self._objects[obj_id]["mapping"]
90
  obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
91
  if self._ignored_regions:
92
  for layer_indx, mask_id in self._ignored_regions:
 
98
  def get_soft_object_mask(self, obj_id):
99
  assert self._soft_mask_aug is not None
100
  original_encoded_masks = self._original_data[1]
101
+ layer_indx, mask_id = self._objects[obj_id]["mapping"]
102
+ obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
103
+ np.float32
104
+ )
105
+ obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
106
+ "image"
107
+ ]
108
  return np.clip(obj_mask, 0, 1)
109
 
110
  def get_background_mask(self):
 
121
 
122
  @property
123
  def root_objects(self):
124
+ return [
125
+ obj_id
126
+ for obj_id, obj_info in self._objects.items()
127
+ if obj_info["parent"] is None
128
+ ]
129
 
130
  def _compute_objects_areas(self):
131
+ inverse_index = {
132
+ node["mapping"]: node_id for node_id, node in self._objects.items()
133
+ }
134
  ignored_regions_keys = set(self._ignored_regions)
135
 
136
  for layer_indx in range(self._encoded_masks.shape[2]):
137
+ objects_ids, objects_areas = get_labels_with_sizes(
138
+ self._encoded_masks[:, :, layer_indx]
139
+ )
140
  for obj_id, obj_area in zip(objects_ids, objects_areas):
141
  inv_key = (layer_indx, obj_id)
142
  if inv_key in ignored_regions_keys:
143
  continue
144
  try:
145
+ self._objects[inverse_index[inv_key]]["area"] = obj_area
146
  del inverse_index[inv_key]
147
  except KeyError:
148
  layer = self._encoded_masks[:, :, layer_indx]
 
150
  self._encoded_masks[:, :, layer_indx] = layer
151
 
152
  for obj_id in inverse_index.values():
153
+ self._objects[obj_id]["area"] = 0
154
 
155
  def _remove_object(self, obj_id):
156
  obj_info = self._objects[obj_id]
157
+ obj_parent = obj_info["parent"]
158
+ for child_id in obj_info["children"]:
159
+ self._objects[child_id]["parent"] = obj_parent
160
 
161
  if obj_parent is not None:
162
+ parent_children = self._objects[obj_parent]["children"]
163
  parent_children = [x for x in parent_children if x != obj_id]
164
+ self._objects[obj_parent]["children"] = (
165
+ parent_children + obj_info["children"]
166
+ )
167
 
168
  del self._objects[obj_id]
169
 
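The augment/soft-mask plumbing above relies on albumentations' ReplayCompose; a small sketch of that round trip follows (the toy arrays and the HorizontalFlip choice are illustrative only).

# Sketch: ReplayCompose records the applied transforms so augment() can later replay
# the mask-only subset for per-object soft masks (see remove_image_only_transforms).
import numpy as np
from albumentations import HorizontalFlip, ReplayCompose
from isegm.data.sample import DSample

masks = np.zeros((32, 32, 1), dtype=np.int32)
masks[8:30, 4:20, 0] = 1
sample = DSample(np.zeros((32, 32, 3), np.uint8), masks, objects_ids=[1])

sample.augment(ReplayCompose([HorizontalFlip(p=1.0)]))
flipped = sample.get_object_mask(0)            # same area, mirrored position
print(flipped.sum())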
isegm/data/transforms.py CHANGED
@@ -1,28 +1,40 @@
1
- import cv2
2
  import random
3
- import numpy as np
4
 
 
 
 
 
5
  from albumentations.core.serialization import SERIALIZABLE_REGISTRY
6
- from albumentations import ImageOnlyTransform, DualTransform
7
  from albumentations.core.transforms_interface import to_tuple
8
- from albumentations.augmentations import functional as F
9
- from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes
 
10
 
11
 
12
  class UniformRandomResize(DualTransform):
13
- def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
 
 
 
 
 
 
14
  super().__init__(always_apply, p)
15
  self.scale_range = scale_range
16
  self.interpolation = interpolation
17
 
18
  def get_params_dependent_on_targets(self, params):
19
  scale = random.uniform(*self.scale_range)
20
- height = int(round(params['image'].shape[0] * scale))
21
- width = int(round(params['image'].shape[1] * scale))
22
- return {'new_height': height, 'new_width': width}
23
 
24
- def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params):
25
- return F.resize(img, height=new_height, width=new_width, interpolation=interpolation)
 
 
 
 
26
 
27
  def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
28
  scale_x = new_width / params["cols"]
@@ -39,16 +51,16 @@ class UniformRandomResize(DualTransform):
39
 
40
  class ZoomIn(DualTransform):
41
  def __init__(
42
- self,
43
- height,
44
- width,
45
- bbox_jitter=0.1,
46
- expansion_ratio=1.4,
47
- min_crop_size=200,
48
- min_area=100,
49
- always_resize=False,
50
- always_apply=False,
51
- p=0.5,
52
  ):
53
  super(ZoomIn, self).__init__(always_apply, p)
54
  self.height = height
@@ -66,7 +78,7 @@ class ZoomIn(DualTransform):
66
  return img
67
 
68
  rmin, rmax, cmin, cmax = bbox
69
- img = img[rmin:rmax + 1, cmin:cmax + 1]
70
  img = F.resize(img, height=self.height, width=self.width)
71
 
72
  return img
@@ -74,12 +86,16 @@ class ZoomIn(DualTransform):
74
  def apply_to_mask(self, mask, selected_object, bbox, **params):
75
  if selected_object is None:
76
  if self.always_resize:
77
- mask = F.resize(mask, height=self.height, width=self.width,
78
- interpolation=cv2.INTER_NEAREST)
 
 
 
 
79
  return mask
80
 
81
  rmin, rmax, cmin, cmax = bbox
82
- mask = mask[rmin:rmax + 1, cmin:cmax + 1]
83
  if isinstance(selected_object, tuple):
84
  layer_indx, mask_id = selected_object
85
  obj_mask = mask[:, :, layer_indx] == mask_id
@@ -90,25 +106,34 @@ class ZoomIn(DualTransform):
90
  new_mask = mask.copy()
91
  new_mask[np.logical_not(obj_mask)] = 0
92
 
93
- new_mask = F.resize(new_mask, height=self.height, width=self.width,
94
- interpolation=cv2.INTER_NEAREST)
 
 
 
 
95
  return new_mask
96
 
97
  def get_params_dependent_on_targets(self, params):
98
- instances = params['mask']
99
 
100
  is_mask_layer = len(instances.shape) > 2
101
  candidates = []
102
  if is_mask_layer:
103
  for layer_indx in range(instances.shape[2]):
104
  labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
105
- candidates.extend([(layer_indx, obj_id)
106
- for obj_id, area in zip(labels, areas)
107
- if area > self.min_area])
 
 
 
 
108
  else:
109
  labels, areas = get_labels_with_sizes(instances)
110
- candidates = [obj_id for obj_id, area in zip(labels, areas)
111
- if area > self.min_area]
 
112
 
113
  selected_object = None
114
  bbox = None
@@ -131,10 +156,7 @@ class ZoomIn(DualTransform):
131
  bbox = self._jitter_bbox(bbox)
132
  bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
133
 
134
- return {
135
- 'selected_object': selected_object,
136
- 'bbox': bbox
137
- }
138
 
139
  def _jitter_bbox(self, bbox):
140
  rmin, rmax, cmin, cmax = bbox
@@ -158,21 +180,28 @@ class ZoomIn(DualTransform):
158
  return ["mask"]
159
 
160
  def get_transform_init_args_names(self):
161
- return ("height", "width", "bbox_jitter",
162
- "expansion_ratio", "min_crop_size", "min_area", "always_resize")
 
 
 
 
 
 
 
163
 
164
 
165
  def remove_image_only_transforms(sdict):
166
- if not 'transforms' in sdict:
167
  return sdict
168
 
169
  keep_transforms = []
170
- for tdict in sdict['transforms']:
171
- cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']]
172
- if 'transforms' in tdict:
173
  keep_transforms.append(remove_image_only_transforms(tdict))
174
  elif not issubclass(cls, ImageOnlyTransform):
175
  keep_transforms.append(tdict)
176
- sdict['transforms'] = keep_transforms
177
 
178
  return sdict
 
 
1
  import random
 
2
 
3
+ import cv2
4
+ import numpy as np
5
+ from albumentations import DualTransform, ImageOnlyTransform
6
+ from albumentations.augmentations import functional as F
7
  from albumentations.core.serialization import SERIALIZABLE_REGISTRY
 
8
  from albumentations.core.transforms_interface import to_tuple
9
+
10
+ from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask,
11
+ get_labels_with_sizes)
12
 
13
 
14
  class UniformRandomResize(DualTransform):
15
+ def __init__(
16
+ self,
17
+ scale_range=(0.9, 1.1),
18
+ interpolation=cv2.INTER_LINEAR,
19
+ always_apply=False,
20
+ p=1,
21
+ ):
22
  super().__init__(always_apply, p)
23
  self.scale_range = scale_range
24
  self.interpolation = interpolation
25
 
26
  def get_params_dependent_on_targets(self, params):
27
  scale = random.uniform(*self.scale_range)
28
+ height = int(round(params["image"].shape[0] * scale))
29
+ width = int(round(params["image"].shape[1] * scale))
30
+ return {"new_height": height, "new_width": width}
31
 
32
+ def apply(
33
+ self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params
34
+ ):
35
+ return F.resize(
36
+ img, height=new_height, width=new_width, interpolation=interpolation
37
+ )
38
 
39
  def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
40
  scale_x = new_width / params["cols"]
 
51
 
52
  class ZoomIn(DualTransform):
53
  def __init__(
54
+ self,
55
+ height,
56
+ width,
57
+ bbox_jitter=0.1,
58
+ expansion_ratio=1.4,
59
+ min_crop_size=200,
60
+ min_area=100,
61
+ always_resize=False,
62
+ always_apply=False,
63
+ p=0.5,
64
  ):
65
  super(ZoomIn, self).__init__(always_apply, p)
66
  self.height = height
 
78
  return img
79
 
80
  rmin, rmax, cmin, cmax = bbox
81
+ img = img[rmin : rmax + 1, cmin : cmax + 1]
82
  img = F.resize(img, height=self.height, width=self.width)
83
 
84
  return img
 
86
  def apply_to_mask(self, mask, selected_object, bbox, **params):
87
  if selected_object is None:
88
  if self.always_resize:
89
+ mask = F.resize(
90
+ mask,
91
+ height=self.height,
92
+ width=self.width,
93
+ interpolation=cv2.INTER_NEAREST,
94
+ )
95
  return mask
96
 
97
  rmin, rmax, cmin, cmax = bbox
98
+ mask = mask[rmin : rmax + 1, cmin : cmax + 1]
99
  if isinstance(selected_object, tuple):
100
  layer_indx, mask_id = selected_object
101
  obj_mask = mask[:, :, layer_indx] == mask_id
 
106
  new_mask = mask.copy()
107
  new_mask[np.logical_not(obj_mask)] = 0
108
 
109
+ new_mask = F.resize(
110
+ new_mask,
111
+ height=self.height,
112
+ width=self.width,
113
+ interpolation=cv2.INTER_NEAREST,
114
+ )
115
  return new_mask
116
 
117
  def get_params_dependent_on_targets(self, params):
118
+ instances = params["mask"]
119
 
120
  is_mask_layer = len(instances.shape) > 2
121
  candidates = []
122
  if is_mask_layer:
123
  for layer_indx in range(instances.shape[2]):
124
  labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
125
+ candidates.extend(
126
+ [
127
+ (layer_indx, obj_id)
128
+ for obj_id, area in zip(labels, areas)
129
+ if area > self.min_area
130
+ ]
131
+ )
132
  else:
133
  labels, areas = get_labels_with_sizes(instances)
134
+ candidates = [
135
+ obj_id for obj_id, area in zip(labels, areas) if area > self.min_area
136
+ ]
137
 
138
  selected_object = None
139
  bbox = None
 
156
  bbox = self._jitter_bbox(bbox)
157
  bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
158
 
159
+ return {"selected_object": selected_object, "bbox": bbox}
 
 
 
160
 
161
  def _jitter_bbox(self, bbox):
162
  rmin, rmax, cmin, cmax = bbox
 
180
  return ["mask"]
181
 
182
  def get_transform_init_args_names(self):
183
+ return (
184
+ "height",
185
+ "width",
186
+ "bbox_jitter",
187
+ "expansion_ratio",
188
+ "min_crop_size",
189
+ "min_area",
190
+ "always_resize",
191
+ )
192
 
193
 
194
  def remove_image_only_transforms(sdict):
195
+ if "transforms" not in sdict:
196
  return sdict
197
 
198
  keep_transforms = []
199
+ for tdict in sdict["transforms"]:
200
+ cls = SERIALIZABLE_REGISTRY[tdict["__class_fullname__"]]
201
+ if "transforms" in tdict:
202
  keep_transforms.append(remove_image_only_transforms(tdict))
203
  elif not issubclass(cls, ImageOnlyTransform):
204
  keep_transforms.append(tdict)
205
+ sdict["transforms"] = keep_transforms
206
 
207
  return sdict
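A small composition sketch for the two custom transforms above; the target size, probabilities, and toy arrays are illustrative, and the exact resize behaviour follows whatever albumentations version the repo pins.

# Sketch: resize by a random factor, then (sometimes) zoom in on one object.
import numpy as np
from albumentations import ReplayCompose
from isegm.data.transforms import UniformRandomResize, ZoomIn

aug = ReplayCompose([
    UniformRandomResize(scale_range=(0.75, 1.25)),
    ZoomIn(height=256, width=256, min_area=100, p=0.5),
])

image = np.zeros((300, 400, 3), dtype=np.uint8)
mask = np.zeros((300, 400, 1), dtype=np.int32)
mask[50:200, 100:300, 0] = 1

out = aug(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)   # 256x256 when ZoomIn fires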
isegm/engine/optimizer.py CHANGED
@@ -1,27 +1,29 @@
1
- import torch
2
  import math
 
 
 
3
  from isegm.utils.log import logger
4
 
5
 
6
  def get_optimizer(model, opt_name, opt_kwargs):
7
  params = []
8
- base_lr = opt_kwargs['lr']
9
  for name, param in model.named_parameters():
10
- param_group = {'params': [param]}
11
  if not param.requires_grad:
12
  params.append(param_group)
13
  continue
14
 
15
- if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
16
  logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
17
- param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult
18
 
19
  params.append(param_group)
20
 
21
  optimizer = {
22
- 'sgd': torch.optim.SGD,
23
- 'adam': torch.optim.Adam,
24
- 'adamw': torch.optim.AdamW
25
  }[opt_name.lower()](params, **opt_kwargs)
26
 
27
  return optimizer
 
 
1
  import math
2
+
3
+ import torch
4
+
5
  from isegm.utils.log import logger
6
 
7
 
8
  def get_optimizer(model, opt_name, opt_kwargs):
9
  params = []
10
+ base_lr = opt_kwargs["lr"]
11
  for name, param in model.named_parameters():
12
+ param_group = {"params": [param]}
13
  if not param.requires_grad:
14
  params.append(param_group)
15
  continue
16
 
17
+ if not math.isclose(getattr(param, "lr_mult", 1.0), 1.0):
18
  logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
19
+ param_group["lr"] = param_group.get("lr", base_lr) * param.lr_mult
20
 
21
  params.append(param_group)
22
 
23
  optimizer = {
24
+ "sgd": torch.optim.SGD,
25
+ "adam": torch.optim.Adam,
26
+ "adamw": torch.optim.AdamW,
27
  }[opt_name.lower()](params, **opt_kwargs)
28
 
29
  return optimizer
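The per-parameter lr_mult hook in get_optimizer above is driven by setting an attribute on individual parameters; a minimal sketch with a toy model (the layer split and the values are illustrative).

# Sketch: give the first conv a 10x smaller learning rate via the lr_mult hook.
import torch
from isegm.engine.optimizer import get_optimizer

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 1, 3))
for param in model[0].parameters():
    param.lr_mult = 0.1

optim = get_optimizer(model, "adam", {"lr": 5e-4})
print([group["lr"] for group in optim.param_groups])  # first conv params get lr * 0.1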
isegm/engine/trainer.py CHANGED
@@ -1,40 +1,48 @@
 
1
  import os
2
  import random
3
- import logging
4
- from copy import deepcopy
5
  from collections import defaultdict
 
6
 
7
  import cv2
8
- import torch
9
  import numpy as np
10
- from tqdm import tqdm
11
  from torch.utils.data import DataLoader
 
12
 
13
- from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg
14
- from isegm.utils.vis import draw_probmap, draw_points
 
15
  from isegm.utils.misc import save_checkpoint
16
  from isegm.utils.serialization import get_config_repr
17
- from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict
 
18
  from .optimizer import get_optimizer
19
 
20
 
21
  class ISTrainer(object):
22
- def __init__(self, model, cfg, model_cfg, loss_cfg,
23
- trainset, valset,
24
- optimizer='adam',
25
- optimizer_params=None,
26
- image_dump_interval=200,
27
- checkpoint_interval=10,
28
- tb_dump_period=25,
29
- max_interactive_points=0,
30
- lr_scheduler=None,
31
- metrics=None,
32
- additional_val_metrics=None,
33
- net_inputs=('images', 'points'),
34
- max_num_next_clicks=0,
35
- click_models=None,
36
- prev_mask_drop_prob=0.0,
37
- ):
 
 
 
 
 
 
38
  self.cfg = cfg
39
  self.model_cfg = model_cfg
40
  self.max_interactive_points = max_interactive_points
@@ -60,35 +68,44 @@ class ISTrainer(object):
60
 
61
  self.checkpoint_interval = checkpoint_interval
62
  self.image_dump_interval = image_dump_interval
63
- self.task_prefix = ''
64
  self.sw = None
65
 
66
  self.trainset = trainset
67
  self.valset = valset
68
 
69
- logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.')
70
- logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.')
 
 
 
 
71
 
72
  self.train_data = DataLoader(
73
- trainset, cfg.batch_size,
 
74
  sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
75
- drop_last=True, pin_memory=True,
76
- num_workers=cfg.workers
 
77
  )
78
 
79
  self.val_data = DataLoader(
80
- valset, cfg.val_batch_size,
 
81
  sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
82
- drop_last=True, pin_memory=True,
83
- num_workers=cfg.workers
 
84
  )
85
 
86
  self.optim = get_optimizer(model, optimizer, optimizer_params)
87
  model = self._load_weights(model)
88
 
89
  if cfg.multi_gpu:
90
- model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids,
91
- output_device=cfg.gpu_ids[0])
 
92
 
93
  if self.is_master:
94
  logger.info(model)
@@ -96,7 +113,7 @@ class ISTrainer(object):
96
 
97
  self.device = cfg.device
98
  self.net = model.to(self.device)
99
- self.lr = optimizer_params['lr']
100
 
101
  if lr_scheduler is not None:
102
  self.lr_scheduler = lr_scheduler(optimizer=self.optim)
@@ -117,8 +134,8 @@ class ISTrainer(object):
117
  if start_epoch is None:
118
  start_epoch = self.cfg.start_epoch
119
 
120
- logger.info(f'Starting Epoch: {start_epoch}')
121
- logger.info(f'Total Epochs: {num_epochs}')
122
  for epoch in range(start_epoch, num_epochs):
123
  self.training(epoch)
124
  if validation:
@@ -126,15 +143,21 @@ class ISTrainer(object):
126
 
127
  def training(self, epoch):
128
  if self.sw is None and self.is_master:
129
- self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
130
- flush_secs=10, dump_period=self.tb_dump_period)
 
 
 
131
 
132
  if self.cfg.distributed:
133
  self.train_data.sampler.set_epoch(epoch)
134
 
135
- log_prefix = 'Train' + self.task_prefix.capitalize()
136
- tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\
137
- if self.is_master else self.train_data
 
 
 
138
 
139
  for metric in self.train_metrics:
140
  metric.reset_epoch_stats()
@@ -144,67 +167,109 @@ class ISTrainer(object):
144
  for i, batch_data in enumerate(tbar):
145
  global_step = epoch * len(self.train_data) + i
146
 
147
- loss, losses_logging, splitted_batch_data, outputs = \
148
- self.batch_forward(batch_data)
 
149
 
150
  self.optim.zero_grad()
151
  loss.backward()
152
  self.optim.step()
153
 
154
- losses_logging['overall'] = loss
155
  reduce_loss_dict(losses_logging)
156
 
157
- train_loss += losses_logging['overall'].item()
158
 
159
  if self.is_master:
160
  for loss_name, loss_value in losses_logging.items():
161
- self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}',
162
- value=loss_value.item(),
163
- global_step=global_step)
 
 
164
 
165
  for k, v in self.loss_cfg.items():
166
- if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0:
167
- v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step)
168
-
169
- if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0:
170
- self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train')
171
-
172
- self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate',
173
- value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1],
174
- global_step=global_step)
175
-
176
- tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  for metric in self.train_metrics:
178
- metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
 
 
179
 
180
  if self.is_master:
181
  for metric in self.train_metrics:
182
- self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}',
183
- value=metric.get_epoch_value(),
184
- global_step=epoch, disable_avg=True)
185
-
186
- save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
187
- epoch=None, multi_gpu=self.cfg.multi_gpu)
 
 
 
 
 
 
 
 
188
 
189
  if isinstance(self.checkpoint_interval, (list, tuple)):
190
- checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1]
 
 
191
  else:
192
  checkpoint_interval = self.checkpoint_interval
193
 
194
  if epoch % checkpoint_interval == 0:
195
- save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
196
- epoch=epoch, multi_gpu=self.cfg.multi_gpu)
197
-
198
- if hasattr(self, 'lr_scheduler'):
 
 
 
 
 
199
  self.lr_scheduler.step()
200
 
201
  def validation(self, epoch):
202
  if self.sw is None and self.is_master:
203
- self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
204
- flush_secs=10, dump_period=self.tb_dump_period)
205
-
206
- log_prefix = 'Val' + self.task_prefix.capitalize()
207
- tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data
 
 
 
 
 
 
 
208
 
209
  for metric in self.val_metrics:
210
  metric.reset_epoch_stats()
@@ -215,29 +280,45 @@ class ISTrainer(object):
215
  self.net.eval()
216
  for i, batch_data in enumerate(tbar):
217
  global_step = epoch * len(self.val_data) + i
218
- loss, batch_losses_logging, splitted_batch_data, outputs = \
219
- self.batch_forward(batch_data, validation=True)
220
-
221
- batch_losses_logging['overall'] = loss
 
 
 
 
222
  reduce_loss_dict(batch_losses_logging)
223
  for loss_name, loss_value in batch_losses_logging.items():
224
  losses_logging[loss_name].append(loss_value.item())
225
 
226
- val_loss += batch_losses_logging['overall'].item()
227
 
228
  if self.is_master:
229
- tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}')
 
 
230
  for metric in self.val_metrics:
231
- metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
 
 
232
 
233
  if self.is_master:
234
  for loss_name, loss_values in losses_logging.items():
235
- self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(),
236
- global_step=epoch, disable_avg=True)
 
 
 
 
237
 
238
  for metric in self.val_metrics:
239
- self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(),
240
- global_step=epoch, disable_avg=True)
 
 
 
 
241
 
242
  def batch_forward(self, batch_data, validation=False):
243
  metrics = self.val_metrics if validation else self.train_metrics
@@ -245,8 +326,16 @@ class ISTrainer(object):
245
 
246
  with torch.set_grad_enabled(not validation):
247
  batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
248
- image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points']
249
- orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone()
 
 
 
 
 
 
 
 
250
 
251
  prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]
252
 
@@ -261,44 +350,79 @@ class ISTrainer(object):
261
  if not validation:
262
  self.net.eval()
263
 
264
- if self.click_models is None or click_indx >= len(self.click_models):
 
 
265
  eval_model = self.net
266
  else:
267
  eval_model = self.click_models[click_indx]
268
 
269
- net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
270
- prev_output = torch.sigmoid(eval_model(net_input, points)['instances'])
 
 
 
 
 
 
271
 
272
- points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1)
 
 
273
 
274
  if not validation:
275
  self.net.train()
276
 
277
- if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None:
278
- zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob
 
 
 
 
 
 
 
279
  prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])
280
 
281
- batch_data['points'] = points
282
 
283
- net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
 
 
 
 
284
  output = self.net(net_input, points)
285
 
286
  loss = 0.0
287
- loss = self.add_loss('instance_loss', loss, losses_logging, validation,
288
- lambda: (output['instances'], batch_data['instances']))
289
- loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation,
290
- lambda: (output['instances_aux'], batch_data['instances']))
 
 
 
 
 
 
 
 
 
 
291
 
292
  if self.is_master:
293
  with torch.no_grad():
294
  for m in metrics:
295
- m.update(*(output.get(x) for x in m.pred_outputs),
296
- *(batch_data[x] for x in m.gt_outputs))
 
 
297
  return loss, losses_logging, batch_data, output
298
 
299
- def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs):
 
 
300
  loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
301
- loss_weight = loss_cfg.get(loss_name + '_weight', 0.0)
302
  if loss_weight > 0.0:
303
  loss_criterion = loss_cfg.get(loss_name)
304
  loss = loss_criterion(*lambda_loss_inputs())
@@ -316,18 +440,23 @@ class ISTrainer(object):
316
 
317
  if not output_images_path.exists():
318
  output_images_path.mkdir(parents=True)
319
- image_name_prefix = f'{global_step:06d}'
320
 
321
  def _save_image(suffix, image):
322
- cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'),
323
- image, [cv2.IMWRITE_JPEG_QUALITY, 85])
 
 
 
324
 
325
- images = splitted_batch_data['images']
326
- points = splitted_batch_data['points']
327
- instance_masks = splitted_batch_data['instances']
328
 
329
  gt_instance_masks = instance_masks.cpu().numpy()
330
- predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy()
 
 
331
  points = points.detach().cpu().numpy()
332
 
333
  image_blob, points = images[0], points[0]
@@ -337,15 +466,21 @@ class ISTrainer(object):
337
  image = image_blob.cpu().numpy() * 255
338
  image = image.transpose((1, 2, 0))
339
 
340
- image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0))
341
- image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255))
 
 
 
 
342
 
343
  gt_mask[gt_mask < 0] = 0.25
344
  gt_mask = draw_probmap(gt_mask)
345
  predicted_mask = draw_probmap(predicted_mask)
346
- viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8)
 
 
347
 
348
- _save_image('instance_segmentation', viz_image[:, :, ::-1])
349
 
350
  def _load_weights(self, net):
351
  if self.cfg.weights is not None:
@@ -355,11 +490,13 @@ class ISTrainer(object):
355
  else:
356
  raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
357
  elif self.cfg.resume_exp is not None:
358
- checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth'))
 
 
359
  assert len(checkpoints) == 1
360
 
361
  checkpoint_path = checkpoints[0]
362
- logger.info(f'Load checkpoint from path: {checkpoint_path}')
363
  load_weights(net, str(checkpoint_path))
364
  return net
365
 
@@ -376,8 +513,8 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
376
  fn_mask = np.logical_and(gt, pred < pred_thresh)
377
  fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
378
 
379
- fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
380
- fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
381
  num_points = points.size(1) // 2
382
  points = points.clone()
383
 
@@ -408,6 +545,6 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
408
 
409
  def load_weights(model, path_to_weights):
410
  current_state_dict = model.state_dict()
411
- new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict']
412
  current_state_dict.update(new_state_dict)
413
  model.load_state_dict(current_state_dict)
 
1
+ import logging
2
  import os
3
  import random
 
 
4
  from collections import defaultdict
5
+ from copy import deepcopy
6
 
7
  import cv2
 
8
  import numpy as np
9
+ import torch
10
  from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
 
13
+ from isegm.utils.distributed import (get_dp_wrapper, get_sampler,
14
+ reduce_loss_dict)
15
+ from isegm.utils.log import SummaryWriterAvg, TqdmToLogger, logger
16
  from isegm.utils.misc import save_checkpoint
17
  from isegm.utils.serialization import get_config_repr
18
+ from isegm.utils.vis import draw_points, draw_probmap
19
+
20
  from .optimizer import get_optimizer
21
 
22
 
23
  class ISTrainer(object):
24
+ def __init__(
25
+ self,
26
+ model,
27
+ cfg,
28
+ model_cfg,
29
+ loss_cfg,
30
+ trainset,
31
+ valset,
32
+ optimizer="adam",
33
+ optimizer_params=None,
34
+ image_dump_interval=200,
35
+ checkpoint_interval=10,
36
+ tb_dump_period=25,
37
+ max_interactive_points=0,
38
+ lr_scheduler=None,
39
+ metrics=None,
40
+ additional_val_metrics=None,
41
+ net_inputs=("images", "points"),
42
+ max_num_next_clicks=0,
43
+ click_models=None,
44
+ prev_mask_drop_prob=0.0,
45
+ ):
46
  self.cfg = cfg
47
  self.model_cfg = model_cfg
48
  self.max_interactive_points = max_interactive_points
 
68
 
69
  self.checkpoint_interval = checkpoint_interval
70
  self.image_dump_interval = image_dump_interval
71
+ self.task_prefix = ""
72
  self.sw = None
73
 
74
  self.trainset = trainset
75
  self.valset = valset
76
 
77
+ logger.info(
78
+ f"Dataset of {trainset.get_samples_number()} samples was loaded for training."
79
+ )
80
+ logger.info(
81
+ f"Dataset of {valset.get_samples_number()} samples was loaded for validation."
82
+ )
83
 
84
  self.train_data = DataLoader(
85
+ trainset,
86
+ cfg.batch_size,
87
  sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
88
+ drop_last=True,
89
+ pin_memory=True,
90
+ num_workers=cfg.workers,
91
  )
92
 
93
  self.val_data = DataLoader(
94
+ valset,
95
+ cfg.val_batch_size,
96
  sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
97
+ drop_last=True,
98
+ pin_memory=True,
99
+ num_workers=cfg.workers,
100
  )
101
 
102
  self.optim = get_optimizer(model, optimizer, optimizer_params)
103
  model = self._load_weights(model)
104
 
105
  if cfg.multi_gpu:
106
+ model = get_dp_wrapper(cfg.distributed)(
107
+ model, device_ids=cfg.gpu_ids, output_device=cfg.gpu_ids[0]
108
+ )
109
 
110
  if self.is_master:
111
  logger.info(model)
 
113
 
114
  self.device = cfg.device
115
  self.net = model.to(self.device)
116
+ self.lr = optimizer_params["lr"]
117
 
118
  if lr_scheduler is not None:
119
  self.lr_scheduler = lr_scheduler(optimizer=self.optim)
 
134
  if start_epoch is None:
135
  start_epoch = self.cfg.start_epoch
136
 
137
+ logger.info(f"Starting Epoch: {start_epoch}")
138
+ logger.info(f"Total Epochs: {num_epochs}")
139
  for epoch in range(start_epoch, num_epochs):
140
  self.training(epoch)
141
  if validation:
 
143
 
144
  def training(self, epoch):
145
  if self.sw is None and self.is_master:
146
+ self.sw = SummaryWriterAvg(
147
+ log_dir=str(self.cfg.LOGS_PATH),
148
+ flush_secs=10,
149
+ dump_period=self.tb_dump_period,
150
+ )
151
 
152
  if self.cfg.distributed:
153
  self.train_data.sampler.set_epoch(epoch)
154
 
155
+ log_prefix = "Train" + self.task_prefix.capitalize()
156
+ tbar = (
157
+ tqdm(self.train_data, file=self.tqdm_out, ncols=100)
158
+ if self.is_master
159
+ else self.train_data
160
+ )
161
 
162
  for metric in self.train_metrics:
163
  metric.reset_epoch_stats()
 
167
  for i, batch_data in enumerate(tbar):
168
  global_step = epoch * len(self.train_data) + i
169
 
170
+ loss, losses_logging, splitted_batch_data, outputs = self.batch_forward(
171
+ batch_data
172
+ )
173
 
174
  self.optim.zero_grad()
175
  loss.backward()
176
  self.optim.step()
177
 
178
+ losses_logging["overall"] = loss
179
  reduce_loss_dict(losses_logging)
180
 
181
+ train_loss += losses_logging["overall"].item()
182
 
183
  if self.is_master:
184
  for loss_name, loss_value in losses_logging.items():
185
+ self.sw.add_scalar(
186
+ tag=f"{log_prefix}Losses/{loss_name}",
187
+ value=loss_value.item(),
188
+ global_step=global_step,
189
+ )
190
 
191
  for k, v in self.loss_cfg.items():
192
+ if (
193
+ "_loss" in k
194
+ and hasattr(v, "log_states")
195
+ and self.loss_cfg.get(k + "_weight", 0.0) > 0
196
+ ):
197
+ v.log_states(self.sw, f"{log_prefix}Losses/{k}", global_step)
198
+
199
+ if (
200
+ self.image_dump_interval > 0
201
+ and global_step % self.image_dump_interval == 0
202
+ ):
203
+ self.save_visualization(
204
+ splitted_batch_data, outputs, global_step, prefix="train"
205
+ )
206
+
207
+ self.sw.add_scalar(
208
+ tag=f"{log_prefix}States/learning_rate",
209
+ value=self.lr
210
+ if not hasattr(self, "lr_scheduler")
211
+ else self.lr_scheduler.get_lr()[-1],
212
+ global_step=global_step,
213
+ )
214
+
215
+ tbar.set_description(
216
+ f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}"
217
+ )
218
  for metric in self.train_metrics:
219
+ metric.log_states(
220
+ self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
221
+ )
222
 
223
  if self.is_master:
224
  for metric in self.train_metrics:
225
+ self.sw.add_scalar(
226
+ tag=f"{log_prefix}Metrics/{metric.name}",
227
+ value=metric.get_epoch_value(),
228
+ global_step=epoch,
229
+ disable_avg=True,
230
+ )
231
+
232
+ save_checkpoint(
233
+ self.net,
234
+ self.cfg.CHECKPOINTS_PATH,
235
+ prefix=self.task_prefix,
236
+ epoch=None,
237
+ multi_gpu=self.cfg.multi_gpu,
238
+ )
239
 
240
  if isinstance(self.checkpoint_interval, (list, tuple)):
241
+ checkpoint_interval = [
242
+ x for x in self.checkpoint_interval if x[0] <= epoch
243
+ ][-1][1]
244
  else:
245
  checkpoint_interval = self.checkpoint_interval
246
 
247
  if epoch % checkpoint_interval == 0:
248
+ save_checkpoint(
249
+ self.net,
250
+ self.cfg.CHECKPOINTS_PATH,
251
+ prefix=self.task_prefix,
252
+ epoch=epoch,
253
+ multi_gpu=self.cfg.multi_gpu,
254
+ )
255
+
256
+ if hasattr(self, "lr_scheduler"):
257
  self.lr_scheduler.step()
258
 
259
  def validation(self, epoch):
260
  if self.sw is None and self.is_master:
261
+ self.sw = SummaryWriterAvg(
262
+ log_dir=str(self.cfg.LOGS_PATH),
263
+ flush_secs=10,
264
+ dump_period=self.tb_dump_period,
265
+ )
266
+
267
+ log_prefix = "Val" + self.task_prefix.capitalize()
268
+ tbar = (
269
+ tqdm(self.val_data, file=self.tqdm_out, ncols=100)
270
+ if self.is_master
271
+ else self.val_data
272
+ )
273
 
274
  for metric in self.val_metrics:
275
  metric.reset_epoch_stats()
 
280
  self.net.eval()
281
  for i, batch_data in enumerate(tbar):
282
  global_step = epoch * len(self.val_data) + i
283
+ (
284
+ loss,
285
+ batch_losses_logging,
286
+ splitted_batch_data,
287
+ outputs,
288
+ ) = self.batch_forward(batch_data, validation=True)
289
+
290
+ batch_losses_logging["overall"] = loss
291
  reduce_loss_dict(batch_losses_logging)
292
  for loss_name, loss_value in batch_losses_logging.items():
293
  losses_logging[loss_name].append(loss_value.item())
294
 
295
+ val_loss += batch_losses_logging["overall"].item()
296
 
297
  if self.is_master:
298
+ tbar.set_description(
299
+ f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}"
300
+ )
301
  for metric in self.val_metrics:
302
+ metric.log_states(
303
+ self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
304
+ )
305
 
306
  if self.is_master:
307
  for loss_name, loss_values in losses_logging.items():
308
+ self.sw.add_scalar(
309
+ tag=f"{log_prefix}Losses/{loss_name}",
310
+ value=np.array(loss_values).mean(),
311
+ global_step=epoch,
312
+ disable_avg=True,
313
+ )
314
 
315
  for metric in self.val_metrics:
316
+ self.sw.add_scalar(
317
+ tag=f"{log_prefix}Metrics/{metric.name}",
318
+ value=metric.get_epoch_value(),
319
+ global_step=epoch,
320
+ disable_avg=True,
321
+ )
322
 
323
  def batch_forward(self, batch_data, validation=False):
324
  metrics = self.val_metrics if validation else self.train_metrics
 
326
 
327
  with torch.set_grad_enabled(not validation):
328
  batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
329
+ image, gt_mask, points = (
330
+ batch_data["images"],
331
+ batch_data["instances"],
332
+ batch_data["points"],
333
+ )
334
+ orig_image, orig_gt_mask, orig_points = (
335
+ image.clone(),
336
+ gt_mask.clone(),
337
+ points.clone(),
338
+ )
339
 
340
  prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]
341
 
 
350
  if not validation:
351
  self.net.eval()
352
 
353
+ if self.click_models is None or click_indx >= len(
354
+ self.click_models
355
+ ):
356
  eval_model = self.net
357
  else:
358
  eval_model = self.click_models[click_indx]
359
 
360
+ net_input = (
361
+ torch.cat((image, prev_output), dim=1)
362
+ if self.net.with_prev_mask
363
+ else image
364
+ )
365
+ prev_output = torch.sigmoid(
366
+ eval_model(net_input, points)["instances"]
367
+ )
368
 
369
+ points = get_next_points(
370
+ prev_output, orig_gt_mask, points, click_indx + 1
371
+ )
372
 
373
  if not validation:
374
  self.net.train()
375
 
376
+ if (
377
+ self.net.with_prev_mask
378
+ and self.prev_mask_drop_prob > 0
379
+ and last_click_indx is not None
380
+ ):
381
+ zero_mask = (
382
+ np.random.random(size=prev_output.size(0))
383
+ < self.prev_mask_drop_prob
384
+ )
385
  prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])
386
 
387
+ batch_data["points"] = points
388
 
389
+ net_input = (
390
+ torch.cat((image, prev_output), dim=1)
391
+ if self.net.with_prev_mask
392
+ else image
393
+ )
394
  output = self.net(net_input, points)
395
 
396
  loss = 0.0
397
+ loss = self.add_loss(
398
+ "instance_loss",
399
+ loss,
400
+ losses_logging,
401
+ validation,
402
+ lambda: (output["instances"], batch_data["instances"]),
403
+ )
404
+ loss = self.add_loss(
405
+ "instance_aux_loss",
406
+ loss,
407
+ losses_logging,
408
+ validation,
409
+ lambda: (output["instances_aux"], batch_data["instances"]),
410
+ )
411
 
412
  if self.is_master:
413
  with torch.no_grad():
414
  for m in metrics:
415
+ m.update(
416
+ *(output.get(x) for x in m.pred_outputs),
417
+ *(batch_data[x] for x in m.gt_outputs),
418
+ )
419
  return loss, losses_logging, batch_data, output
420
 
421
+ def add_loss(
422
+ self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs
423
+ ):
424
  loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
425
+ loss_weight = loss_cfg.get(loss_name + "_weight", 0.0)
426
  if loss_weight > 0.0:
427
  loss_criterion = loss_cfg.get(loss_name)
428
  loss = loss_criterion(*lambda_loss_inputs())
 
440
 
441
  if not output_images_path.exists():
442
  output_images_path.mkdir(parents=True)
443
+ image_name_prefix = f"{global_step:06d}"
444
 
445
  def _save_image(suffix, image):
446
+ cv2.imwrite(
447
+ str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"),
448
+ image,
449
+ [cv2.IMWRITE_JPEG_QUALITY, 85],
450
+ )
451
 
452
+ images = splitted_batch_data["images"]
453
+ points = splitted_batch_data["points"]
454
+ instance_masks = splitted_batch_data["instances"]
455
 
456
  gt_instance_masks = instance_masks.cpu().numpy()
457
+ predicted_instance_masks = (
458
+ torch.sigmoid(outputs["instances"]).detach().cpu().numpy()
459
+ )
460
  points = points.detach().cpu().numpy()
461
 
462
  image_blob, points = images[0], points[0]
 
466
  image = image_blob.cpu().numpy() * 255
467
  image = image.transpose((1, 2, 0))
468
 
469
+ image_with_points = draw_points(
470
+ image, points[: self.max_interactive_points], (0, 255, 0)
471
+ )
472
+ image_with_points = draw_points(
473
+ image_with_points, points[self.max_interactive_points :], (0, 0, 255)
474
+ )
475
 
476
  gt_mask[gt_mask < 0] = 0.25
477
  gt_mask = draw_probmap(gt_mask)
478
  predicted_mask = draw_probmap(predicted_mask)
479
+ viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(
480
+ np.uint8
481
+ )
482
 
483
+ _save_image("instance_segmentation", viz_image[:, :, ::-1])
484
 
485
  def _load_weights(self, net):
486
  if self.cfg.weights is not None:
 
490
  else:
491
  raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
492
  elif self.cfg.resume_exp is not None:
493
+ checkpoints = list(
494
+ self.cfg.CHECKPOINTS_PATH.glob(f"{self.cfg.resume_prefix}*.pth")
495
+ )
496
  assert len(checkpoints) == 1
497
 
498
  checkpoint_path = checkpoints[0]
499
+ logger.info(f"Load checkpoint from path: {checkpoint_path}")
500
  load_weights(net, str(checkpoint_path))
501
  return net
502
 
 
513
  fn_mask = np.logical_and(gt, pred < pred_thresh)
514
  fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
515
 
516
+ fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)
517
+ fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)
518
  num_points = points.size(1) // 2
519
  points = points.clone()
520
 
 
545
 
546
  def load_weights(model, path_to_weights):
547
  current_state_dict = model.state_dict()
548
+ new_state_dict = torch.load(path_to_weights, map_location="cpu")["state_dict"]
549
  current_state_dict.update(new_state_dict)
550
  model.load_state_dict(current_state_dict)
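
Note: the reformatted training() loop keeps the flexible checkpoint schedule, where checkpoint_interval may be either a plain integer or a list of (start_epoch, interval) pairs; the trainer picks the last pair whose start_epoch does not exceed the current epoch. A small sketch of that resolution rule (the schedule values below are hypothetical):

    def resolve_interval(schedule, epoch):
        # mirrors the list comprehension in training() above
        if isinstance(schedule, (list, tuple)):
            return [x for x in schedule if x[0] <= epoch][-1][1]
        return schedule

    schedule = [(0, 20), (200, 10), (230, 5)]   # hypothetical schedule
    assert resolve_interval(schedule, 50) == 20
    assert resolve_interval(schedule, 215) == 10
    assert resolve_interval(schedule, 231) == 5
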
isegm/inference/clicker.py CHANGED
@@ -1,10 +1,13 @@
1
- import numpy as np
2
  from copy import deepcopy
 
3
  import cv2
 
4
 
5
 
6
  class Clicker(object):
7
- def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
 
 
8
  self.click_indx_offset = click_indx_offset
9
  if gt_mask is not None:
10
  self.gt_mask = gt_mask == 1
@@ -27,12 +30,18 @@ class Clicker(object):
27
  return self.clicks_list[:clicks_limit]
28
 
29
  def _get_next_click(self, pred_mask, padding=True):
30
- fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
31
- fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
 
 
 
 
 
 
32
 
33
  if padding:
34
- fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
35
- fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
36
 
37
  fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
38
  fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
 
 
1
  from copy import deepcopy
2
+
3
  import cv2
4
+ import numpy as np
5
 
6
 
7
  class Clicker(object):
8
+ def __init__(
9
+ self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0
10
+ ):
11
  self.click_indx_offset = click_indx_offset
12
  if gt_mask is not None:
13
  self.gt_mask = gt_mask == 1
 
30
  return self.clicks_list[:clicks_limit]
31
 
32
  def _get_next_click(self, pred_mask, padding=True):
33
+ fn_mask = np.logical_and(
34
+ np.logical_and(self.gt_mask, np.logical_not(pred_mask)),
35
+ self.not_ignore_mask,
36
+ )
37
+ fp_mask = np.logical_and(
38
+ np.logical_and(np.logical_not(self.gt_mask), pred_mask),
39
+ self.not_ignore_mask,
40
+ )
41
 
42
  if padding:
43
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
44
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
45
 
46
  fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
47
  fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
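
Note: Clicker simulates a user by placing the next click at the point deepest inside the largest error region, using distance transforms over the false-negative and false-positive masks computed above. A usage sketch; make_next_click is assumed from the rest of the RITM codebase (this hunk only shows the internal _get_next_click), and the ground-truth mask is hypothetical:

    import numpy as np
    from isegm.inference.clicker import Clicker

    gt_mask = np.zeros((256, 256), dtype=np.int32)   # hypothetical ground truth
    gt_mask[64:192, 64:192] = 1

    clicker = Clicker(gt_mask=gt_mask)
    empty_pred = np.zeros((256, 256), dtype=bool)
    clicker.make_next_click(empty_pred)   # first click is positive, near the object center
    print(clicker.get_clicks())
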
isegm/inference/evaluation.py CHANGED
@@ -20,8 +20,9 @@ def evaluate_dataset(dataset, predictor, **kwargs):
20
  for index in tqdm(range(len(dataset)), leave=False):
21
  sample = dataset.get_sample(index)
22
 
23
- _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, predictor,
24
- sample_id=index, **kwargs)
 
25
  all_ious.append(sample_ious)
26
  end_time = time()
27
  elapsed_time = end_time - start_time
@@ -29,9 +30,17 @@ def evaluate_dataset(dataset, predictor, **kwargs):
29
  return all_ious, elapsed_time
30
 
31
 
32
- def evaluate_sample(image, gt_mask, predictor, max_iou_thr,
33
- pred_thr=0.49, min_clicks=1, max_clicks=20,
34
- sample_id=None, callback=None):
 
 
 
 
 
 
 
 
35
  clicker = Clicker(gt_mask=gt_mask)
36
  pred_mask = np.zeros_like(gt_mask)
37
  ious_list = []
@@ -45,7 +54,14 @@ def evaluate_sample(image, gt_mask, predictor, max_iou_thr,
45
  pred_mask = pred_probs > pred_thr
46
 
47
  if callback is not None:
48
- callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)
 
 
 
 
 
 
 
49
 
50
  iou = utils.get_iou(gt_mask, pred_mask)
51
  ious_list.append(iou)
 
20
  for index in tqdm(range(len(dataset)), leave=False):
21
  sample = dataset.get_sample(index)
22
 
23
+ _, sample_ious, _ = evaluate_sample(
24
+ sample.image, sample.gt_mask, predictor, sample_id=index, **kwargs
25
+ )
26
  all_ious.append(sample_ious)
27
  end_time = time()
28
  elapsed_time = end_time - start_time
 
30
  return all_ious, elapsed_time
31
 
32
 
33
+ def evaluate_sample(
34
+ image,
35
+ gt_mask,
36
+ predictor,
37
+ max_iou_thr,
38
+ pred_thr=0.49,
39
+ min_clicks=1,
40
+ max_clicks=20,
41
+ sample_id=None,
42
+ callback=None,
43
+ ):
44
  clicker = Clicker(gt_mask=gt_mask)
45
  pred_mask = np.zeros_like(gt_mask)
46
  ious_list = []
 
54
  pred_mask = pred_probs > pred_thr
55
 
56
  if callback is not None:
57
+ callback(
58
+ image,
59
+ gt_mask,
60
+ pred_probs,
61
+ sample_id,
62
+ click_indx,
63
+ clicker.clicks_list,
64
+ )
65
 
66
  iou = utils.get_iou(gt_mask, pred_mask)
67
  ious_list.append(iou)
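
Note: evaluate_dataset forwards its keyword arguments to evaluate_sample, so max_iou_thr must be supplied by the caller. A call sketch, assuming dataset and predictor have been constructed elsewhere (e.g. a GrabCut dataset and a NoBRS predictor):

    from isegm.inference.evaluation import evaluate_dataset

    all_ious, elapsed_time = evaluate_dataset(
        dataset, predictor, max_iou_thr=0.90, pred_thr=0.49, max_clicks=20
    )
    # each per-sample IoU list holds one entry per simulated click, so its length
    # approximates the number of clicks needed to reach max_iou_thr
    mean_clicks = sum(len(ious) for ious in all_ious) / len(all_ious)
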
isegm/inference/predictors/__init__.py CHANGED
@@ -1,27 +1,31 @@
1
- from .base import BasePredictor
2
- from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
3
- from .brs_functors import InputOptimizer, ScaleBiasOptimizer
4
  from isegm.inference.transforms import ZoomIn
5
  from isegm.model.is_hrnet_model import HRNetModel
6
 
 
 
 
 
 
7
 
8
- def get_predictor(net, brs_mode, device,
9
- prob_thresh=0.49,
10
- with_flip=True,
11
- zoom_in_params=dict(),
12
- predictor_params=None,
13
- brs_opt_func_params=None,
14
- lbfgs_params=None):
 
 
 
 
15
  lbfgs_params_ = {
16
- 'm': 20,
17
- 'factr': 0,
18
- 'pgtol': 1e-8,
19
- 'maxfun': 20,
20
  }
21
 
22
- predictor_params_ = {
23
- 'optimize_after_n_clicks': 1
24
- }
25
 
26
  if zoom_in_params is not None:
27
  zoom_in = ZoomIn(**zoom_in_params)
@@ -30,68 +34,86 @@ def get_predictor(net, brs_mode, device,
30
 
31
  if lbfgs_params is not None:
32
  lbfgs_params_.update(lbfgs_params)
33
- lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
34
 
35
  if brs_opt_func_params is None:
36
  brs_opt_func_params = dict()
37
 
38
  if isinstance(net, (list, tuple)):
39
- assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode."
40
 
41
- if brs_mode == 'NoBRS':
42
  if predictor_params is not None:
43
  predictor_params_.update(predictor_params)
44
- predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
45
- elif brs_mode.startswith('f-BRS'):
46
- predictor_params_.update({
47
- 'net_clicks_limit': 8,
48
- })
 
 
 
 
49
  if predictor_params is not None:
50
  predictor_params_.update(predictor_params)
51
 
52
  insertion_mode = {
53
- 'f-BRS-A': 'after_c4',
54
- 'f-BRS-B': 'after_aspp',
55
- 'f-BRS-C': 'after_deeplab'
56
  }[brs_mode]
57
 
58
- opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
59
- with_flip=with_flip,
60
- optimizer_params=lbfgs_params_,
61
- **brs_opt_func_params)
 
 
62
 
63
  if isinstance(net, HRNetModel):
64
  FeaturePredictor = HRNetFeatureBRSPredictor
65
- insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
 
 
66
  else:
67
  FeaturePredictor = FeatureBRSPredictor
68
 
69
- predictor = FeaturePredictor(net, device,
70
- opt_functor=opt_functor,
71
- with_flip=with_flip,
72
- insertion_mode=insertion_mode,
73
- zoom_in=zoom_in,
74
- **predictor_params_)
75
- elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
76
- use_dmaps = brs_mode == 'DistMap-BRS'
77
-
78
- predictor_params_.update({
79
- 'net_clicks_limit': 5,
80
- })
 
 
 
 
 
81
  if predictor_params is not None:
82
  predictor_params_.update(predictor_params)
83
 
84
- opt_functor = InputOptimizer(prob_thresh=prob_thresh,
85
- with_flip=with_flip,
86
- optimizer_params=lbfgs_params_,
87
- **brs_opt_func_params)
88
-
89
- predictor = InputBRSPredictor(net, device,
90
- optimize_target='dmaps' if use_dmaps else 'rgb',
91
- opt_functor=opt_functor,
92
- with_flip=with_flip,
93
- zoom_in=zoom_in,
94
- **predictor_params_)
 
 
 
 
 
95
  else:
96
  raise NotImplementedError
97
 
 
 
 
 
1
  from isegm.inference.transforms import ZoomIn
2
  from isegm.model.is_hrnet_model import HRNetModel
3
 
4
+ from .base import BasePredictor
5
+ from .brs import (FeatureBRSPredictor, HRNetFeatureBRSPredictor,
6
+ InputBRSPredictor)
7
+ from .brs_functors import InputOptimizer, ScaleBiasOptimizer
8
+
9
 
10
+ def get_predictor(
11
+ net,
12
+ brs_mode,
13
+ device,
14
+ prob_thresh=0.49,
15
+ with_flip=True,
16
+ zoom_in_params=dict(),
17
+ predictor_params=None,
18
+ brs_opt_func_params=None,
19
+ lbfgs_params=None,
20
+ ):
21
  lbfgs_params_ = {
22
+ "m": 20,
23
+ "factr": 0,
24
+ "pgtol": 1e-8,
25
+ "maxfun": 20,
26
  }
27
 
28
+ predictor_params_ = {"optimize_after_n_clicks": 1}
 
 
29
 
30
  if zoom_in_params is not None:
31
  zoom_in = ZoomIn(**zoom_in_params)
 
34
 
35
  if lbfgs_params is not None:
36
  lbfgs_params_.update(lbfgs_params)
37
+ lbfgs_params_["maxiter"] = 2 * lbfgs_params_["maxfun"]
38
 
39
  if brs_opt_func_params is None:
40
  brs_opt_func_params = dict()
41
 
42
  if isinstance(net, (list, tuple)):
43
+ assert brs_mode == "NoBRS", "Multi-stage models support only NoBRS mode."
44
 
45
+ if brs_mode == "NoBRS":
46
  if predictor_params is not None:
47
  predictor_params_.update(predictor_params)
48
+ predictor = BasePredictor(
49
+ net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_
50
+ )
51
+ elif brs_mode.startswith("f-BRS"):
52
+ predictor_params_.update(
53
+ {
54
+ "net_clicks_limit": 8,
55
+ }
56
+ )
57
  if predictor_params is not None:
58
  predictor_params_.update(predictor_params)
59
 
60
  insertion_mode = {
61
+ "f-BRS-A": "after_c4",
62
+ "f-BRS-B": "after_aspp",
63
+ "f-BRS-C": "after_deeplab",
64
  }[brs_mode]
65
 
66
+ opt_functor = ScaleBiasOptimizer(
67
+ prob_thresh=prob_thresh,
68
+ with_flip=with_flip,
69
+ optimizer_params=lbfgs_params_,
70
+ **brs_opt_func_params
71
+ )
72
 
73
  if isinstance(net, HRNetModel):
74
  FeaturePredictor = HRNetFeatureBRSPredictor
75
+ insertion_mode = {"after_c4": "A", "after_aspp": "A", "after_deeplab": "C"}[
76
+ insertion_mode
77
+ ]
78
  else:
79
  FeaturePredictor = FeatureBRSPredictor
80
 
81
+ predictor = FeaturePredictor(
82
+ net,
83
+ device,
84
+ opt_functor=opt_functor,
85
+ with_flip=with_flip,
86
+ insertion_mode=insertion_mode,
87
+ zoom_in=zoom_in,
88
+ **predictor_params_
89
+ )
90
+ elif brs_mode == "RGB-BRS" or brs_mode == "DistMap-BRS":
91
+ use_dmaps = brs_mode == "DistMap-BRS"
92
+
93
+ predictor_params_.update(
94
+ {
95
+ "net_clicks_limit": 5,
96
+ }
97
+ )
98
  if predictor_params is not None:
99
  predictor_params_.update(predictor_params)
100
 
101
+ opt_functor = InputOptimizer(
102
+ prob_thresh=prob_thresh,
103
+ with_flip=with_flip,
104
+ optimizer_params=lbfgs_params_,
105
+ **brs_opt_func_params
106
+ )
107
+
108
+ predictor = InputBRSPredictor(
109
+ net,
110
+ device,
111
+ optimize_target="dmaps" if use_dmaps else "rgb",
112
+ opt_functor=opt_functor,
113
+ with_flip=with_flip,
114
+ zoom_in=zoom_in,
115
+ **predictor_params_
116
+ )
117
  else:
118
  raise NotImplementedError
119
 
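
Note: get_predictor now takes a keyword-friendly signature and dispatches on brs_mode: "NoBRS" builds a plain BasePredictor, the "f-BRS-*" modes attach a ScaleBiasOptimizer at different insertion points (f-BRS-B maps to after_aspp), and "RGB-BRS"/"DistMap-BRS" optimize the network input via InputOptimizer. A construction sketch, assuming net is an already-loaded DeepLab or HRNet interactive model:

    import torch
    from isegm.inference.predictors import get_predictor

    device = torch.device("cpu")   # or a CUDA device
    predictor = get_predictor(
        net,
        brs_mode="f-BRS-B",
        device=device,
        prob_thresh=0.49,
        zoom_in_params={"target_size": 480, "skip_clicks": 1},  # forwarded to ZoomIn
        lbfgs_params={"maxfun": 20},  # merged into the L-BFGS-B defaults above
    )
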
isegm/inference/predictors/base.py CHANGED
@@ -1,16 +1,22 @@
1
  import torch
2
  import torch.nn.functional as F
3
  from torchvision import transforms
4
- from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
 
 
5
 
6
 
7
  class BasePredictor(object):
8
- def __init__(self, model, device,
9
- net_clicks_limit=None,
10
- with_flip=False,
11
- zoom_in=None,
12
- max_size=None,
13
- **kwargs):
 
 
 
 
14
  self.with_flip = with_flip
15
  self.net_clicks_limit = net_clicks_limit
16
  self.original_image = None
@@ -48,7 +54,12 @@ class BasePredictor(object):
48
  clicks_list = clicker.get_clicks()
49
 
50
  if self.click_models is not None:
51
- model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1
 
 
 
 
 
52
  if model_indx != self.model_indx:
53
  self.model_indx = model_indx
54
  self.net = self.click_models[model_indx]
@@ -56,15 +67,16 @@ class BasePredictor(object):
56
  input_image = self.original_image
57
  if prev_mask is None:
58
  prev_mask = self.prev_prediction
59
- if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask:
60
  input_image = torch.cat((input_image, prev_mask), dim=1)
61
  image_nd, clicks_lists, is_image_changed = self.apply_transforms(
62
  input_image, [clicks_list]
63
  )
64
 
65
  pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
66
- prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
67
- size=image_nd.size()[2:])
 
68
 
69
  for t in reversed(self.transforms):
70
  prediction = t.inv_transform(prediction)
@@ -77,7 +89,7 @@ class BasePredictor(object):
77
 
78
  def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
79
  points_nd = self.get_points_nd(clicks_lists)
80
- return self.net(image_nd, points_nd)['instances']
81
 
82
  def _get_transform_states(self):
83
  return [x.get_state() for x in self.transforms]
@@ -97,30 +109,43 @@ class BasePredictor(object):
97
 
98
  def get_points_nd(self, clicks_lists):
99
  total_clicks = []
100
- num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
101
- num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
 
 
 
 
 
102
  num_max_points = max(num_pos_clicks + num_neg_clicks)
103
  if self.net_clicks_limit is not None:
104
  num_max_points = min(self.net_clicks_limit, num_max_points)
105
  num_max_points = max(1, num_max_points)
106
 
107
  for clicks_list in clicks_lists:
108
- clicks_list = clicks_list[:self.net_clicks_limit]
109
- pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
110
- pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
111
-
112
- neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
113
- neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
 
 
 
 
 
 
 
 
114
  total_clicks.append(pos_clicks + neg_clicks)
115
 
116
  return torch.tensor(total_clicks, device=self.device)
117
 
118
  def get_states(self):
119
  return {
120
- 'transform_states': self._get_transform_states(),
121
- 'prev_prediction': self.prev_prediction.clone()
122
  }
123
 
124
  def set_states(self, states):
125
- self._set_transform_states(states['transform_states'])
126
- self.prev_prediction = states['prev_prediction']
 
1
  import torch
2
  import torch.nn.functional as F
3
  from torchvision import transforms
4
+
5
+ from isegm.inference.transforms import (AddHorizontalFlip, LimitLongestSide,
6
+ SigmoidForPred)
7
 
8
 
9
  class BasePredictor(object):
10
+ def __init__(
11
+ self,
12
+ model,
13
+ device,
14
+ net_clicks_limit=None,
15
+ with_flip=False,
16
+ zoom_in=None,
17
+ max_size=None,
18
+ **kwargs
19
+ ):
20
  self.with_flip = with_flip
21
  self.net_clicks_limit = net_clicks_limit
22
  self.original_image = None
 
54
  clicks_list = clicker.get_clicks()
55
 
56
  if self.click_models is not None:
57
+ model_indx = (
58
+ min(
59
+ clicker.click_indx_offset + len(clicks_list), len(self.click_models)
60
+ )
61
+ - 1
62
+ )
63
  if model_indx != self.model_indx:
64
  self.model_indx = model_indx
65
  self.net = self.click_models[model_indx]
 
67
  input_image = self.original_image
68
  if prev_mask is None:
69
  prev_mask = self.prev_prediction
70
+ if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask:
71
  input_image = torch.cat((input_image, prev_mask), dim=1)
72
  image_nd, clicks_lists, is_image_changed = self.apply_transforms(
73
  input_image, [clicks_list]
74
  )
75
 
76
  pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
77
+ prediction = F.interpolate(
78
+ pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
79
+ )
80
 
81
  for t in reversed(self.transforms):
82
  prediction = t.inv_transform(prediction)
 
89
 
90
  def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
91
  points_nd = self.get_points_nd(clicks_lists)
92
+ return self.net(image_nd, points_nd)["instances"]
93
 
94
  def _get_transform_states(self):
95
  return [x.get_state() for x in self.transforms]
 
109
 
110
  def get_points_nd(self, clicks_lists):
111
  total_clicks = []
112
+ num_pos_clicks = [
113
+ sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
114
+ ]
115
+ num_neg_clicks = [
116
+ len(clicks_list) - num_pos
117
+ for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
118
+ ]
119
  num_max_points = max(num_pos_clicks + num_neg_clicks)
120
  if self.net_clicks_limit is not None:
121
  num_max_points = min(self.net_clicks_limit, num_max_points)
122
  num_max_points = max(1, num_max_points)
123
 
124
  for clicks_list in clicks_lists:
125
+ clicks_list = clicks_list[: self.net_clicks_limit]
126
+ pos_clicks = [
127
+ click.coords_and_indx for click in clicks_list if click.is_positive
128
+ ]
129
+ pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
130
+ (-1, -1, -1)
131
+ ]
132
+
133
+ neg_clicks = [
134
+ click.coords_and_indx for click in clicks_list if not click.is_positive
135
+ ]
136
+ neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
137
+ (-1, -1, -1)
138
+ ]
139
  total_clicks.append(pos_clicks + neg_clicks)
140
 
141
  return torch.tensor(total_clicks, device=self.device)
142
 
143
  def get_states(self):
144
  return {
145
+ "transform_states": self._get_transform_states(),
146
+ "prev_prediction": self.prev_prediction.clone(),
147
  }
148
 
149
  def set_states(self, states):
150
+ self._set_transform_states(states["transform_states"])
151
+ self.prev_prediction = states["prev_prediction"]
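
Note: BasePredictor pads click lists with (-1, -1, -1) sentinels so positive and negative clicks can be stacked into one fixed-size tensor, and it feeds the previous prediction back as an extra input channel when the model supports it. A minimal interactive-loop sketch; set_input_image is assumed from the rest of the codebase (only get_prediction and its helpers appear in this hunk), and image/clicker are assumed to be an HxWx3 uint8 array and a Clicker instance:

    import numpy as np
    import torch

    with torch.no_grad():
        predictor.set_input_image(image)   # assumed API, not shown in this hunk
        clicker.make_next_click(np.zeros(image.shape[:2], dtype=bool))
        pred_probs = predictor.get_prediction(clicker)   # probability map
        pred_mask = pred_probs > 0.49
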
isegm/inference/predictors/brs.py CHANGED
@@ -1,6 +1,6 @@
 
1
  import torch
2
  import torch.nn.functional as F
3
- import numpy as np
4
  from scipy.optimize import fmin_l_bfgs_b
5
 
6
  from .base import BasePredictor
@@ -21,8 +21,12 @@ class BRSBasePredictor(BasePredictor):
21
  self.input_data = None
22
 
23
  def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
24
- pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
25
- neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
 
 
 
 
26
 
27
  for list_indx, clicks_list in enumerate(clicks_lists):
28
  for click in clicks_list:
@@ -43,24 +47,29 @@ class BRSBasePredictor(BasePredictor):
43
  return pos_clicks_map, neg_clicks_map
44
 
45
  def get_states(self):
46
- return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}
 
 
 
47
 
48
  def set_states(self, states):
49
- self._set_transform_states(states['transform_states'])
50
- self.opt_data = states['opt_data']
51
 
52
 
53
  class FeatureBRSPredictor(BRSBasePredictor):
54
- def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
 
 
55
  super().__init__(model, device, opt_functor=opt_functor, **kwargs)
56
  self.insertion_mode = insertion_mode
57
  self._c1_features = None
58
 
59
- if self.insertion_mode == 'after_deeplab':
60
  self.num_channels = model.feature_extractor.ch
61
- elif self.insertion_mode == 'after_c4':
62
  self.num_channels = model.feature_extractor.aspp_in_channels
63
- elif self.insertion_mode == 'after_aspp':
64
  self.num_channels = model.feature_extractor.ch + 32
65
  else:
66
  raise NotImplementedError
@@ -72,10 +81,17 @@ class FeatureBRSPredictor(BRSBasePredictor):
72
  num_clicks = len(clicks_lists[0])
73
  bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
74
 
75
- if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
 
 
 
76
  self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
77
 
78
- if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
 
 
 
 
79
  self.input_data = self._get_head_input(image_nd, points_nd)
80
 
81
  def get_prediction_logits(scale, bias):
@@ -87,24 +103,39 @@ class FeatureBRSPredictor(BRSBasePredictor):
87
 
88
  scaled_backbone_features = self.input_data * scale
89
  scaled_backbone_features = scaled_backbone_features + bias
90
- if self.insertion_mode == 'after_c4':
91
  x = self.net.feature_extractor.aspp(scaled_backbone_features)
92
- x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
93
- align_corners=True)
 
 
 
 
94
  x = torch.cat((x, self._c1_features), dim=1)
95
  scaled_backbone_features = self.net.feature_extractor.head(x)
96
- elif self.insertion_mode == 'after_aspp':
97
- scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)
 
 
98
 
99
  pred_logits = self.net.head(scaled_backbone_features)
100
- pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
101
- align_corners=True)
 
 
 
 
102
  return pred_logits
103
 
104
- self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
 
 
105
  if num_clicks > self.optimize_after_n_clicks:
106
- opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
107
- **self.opt_functor.optimizer_params)
 
 
 
108
  self.opt_data = opt_result[0]
109
 
110
  with torch.no_grad():
@@ -125,37 +156,45 @@ class FeatureBRSPredictor(BRSBasePredictor):
125
  if self.net.rgb_conv is not None:
126
  x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
127
  additional_features = None
128
- elif hasattr(self.net, 'maps_transform'):
129
  x = image_nd
130
  additional_features = self.net.maps_transform(coord_features)
131
 
132
- if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
133
- c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features)
 
 
134
  c1 = self.net.feature_extractor.skip_project(c1)
135
 
136
- if self.insertion_mode == 'after_aspp':
137
  x = self.net.feature_extractor.aspp(c4)
138
- x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
 
 
139
  x = torch.cat((x, c1), dim=1)
140
  backbone_features = x
141
  else:
142
  backbone_features = c4
143
  self._c1_features = c1
144
  else:
145
- backbone_features = self.net.feature_extractor(x, additional_features)[0]
 
 
146
 
147
  return backbone_features
148
 
149
 
150
  class HRNetFeatureBRSPredictor(BRSBasePredictor):
151
- def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
152
  super().__init__(model, device, opt_functor=opt_functor, **kwargs)
153
  self.insertion_mode = insertion_mode
154
  self._c1_features = None
155
 
156
- if self.insertion_mode == 'A':
157
- self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
158
- elif self.insertion_mode == 'C':
 
 
159
  self.num_channels = 2 * model.feature_extractor.ocr_width
160
  else:
161
  raise NotImplementedError
@@ -166,10 +205,17 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor):
166
  num_clicks = len(clicks_lists[0])
167
  bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
168
 
169
- if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
 
 
 
170
  self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
171
 
172
- if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
 
 
 
 
173
  self.input_data = self._get_head_input(image_nd, points_nd)
174
 
175
  def get_prediction_logits(scale, bias):
@@ -181,29 +227,44 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor):
181
 
182
  scaled_backbone_features = self.input_data * scale
183
  scaled_backbone_features = scaled_backbone_features + bias
184
- if self.insertion_mode == 'A':
185
  if self.net.feature_extractor.ocr_width > 0:
186
- out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
187
- feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)
 
 
 
 
188
 
189
  context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
190
  feats = self.net.feature_extractor.ocr_distri_head(feats, context)
191
  else:
192
  feats = scaled_backbone_features
193
  pred_logits = self.net.feature_extractor.cls_head(feats)
194
- elif self.insertion_mode == 'C':
195
- pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
 
 
196
  else:
197
  raise NotImplementedError
198
 
199
- pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
200
- align_corners=True)
 
 
 
 
201
  return pred_logits
202
 
203
- self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
 
 
204
  if num_clicks > self.optimize_after_n_clicks:
205
- opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
206
- **self.opt_functor.optimizer_params)
 
 
 
207
  self.opt_data = opt_result[0]
208
 
209
  with torch.no_grad():
@@ -224,20 +285,24 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor):
224
  if self.net.rgb_conv is not None:
225
  x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
226
  additional_features = None
227
- elif hasattr(self.net, 'maps_transform'):
228
  x = image_nd
229
  additional_features = self.net.maps_transform(coord_features)
230
 
231
- feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features)
 
 
232
 
233
- if self.insertion_mode == 'A':
234
  backbone_features = feats
235
- elif self.insertion_mode == 'C':
236
  out_aux = self.net.feature_extractor.aux_head(feats)
237
  feats = self.net.feature_extractor.conv3x3_ocr(feats)
238
 
239
  context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
240
- backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
 
 
241
  else:
242
  raise NotImplementedError
243
 
@@ -245,7 +310,7 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor):
245
 
246
 
247
  class InputBRSPredictor(BRSBasePredictor):
248
- def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
249
  super().__init__(model, device, opt_functor=opt_functor, **kwargs)
250
  self.optimize_target = optimize_target
251
 
@@ -255,21 +320,28 @@ class InputBRSPredictor(BRSBasePredictor):
255
  num_clicks = len(clicks_lists[0])
256
 
257
  if self.opt_data is None or is_image_changed:
258
- if self.optimize_target == 'dmaps':
259
- opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch
 
 
 
 
260
  else:
261
  opt_channels = 3
262
  bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
263
- self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
264
- device=self.device, dtype=torch.float32)
 
 
 
265
 
266
  def get_prediction_logits(opt_bias):
267
  input_image, prev_mask = self.net.prepare_input(image_nd)
268
  dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)
269
 
270
- if self.optimize_target == 'rgb':
271
  input_image = input_image + opt_bias
272
- elif self.optimize_target == 'dmaps':
273
  if self.net.with_prev_mask:
274
  dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
275
  else:
@@ -277,25 +349,44 @@ class InputBRSPredictor(BRSBasePredictor):
277
 
278
  if self.net.rgb_conv is not None:
279
  x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
280
- if self.optimize_target == 'all':
281
  x = x + opt_bias
282
  coord_features = None
283
- elif hasattr(self.net, 'maps_transform'):
284
  x = input_image
285
  coord_features = self.net.maps_transform(dmaps)
286
 
287
- pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances']
288
- pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)
 
 
 
 
 
 
 
289
 
290
  return pred_logits
291
 
292
- self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
293
- shape=self.opt_data.shape)
 
 
 
 
 
294
  if num_clicks > self.optimize_after_n_clicks:
295
- opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
296
- **self.opt_functor.optimizer_params)
297
-
298
- self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)
 
 
 
 
 
 
 
299
 
300
  with torch.no_grad():
301
  if self.opt_functor.best_prediction is not None:
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn.functional as F
 
4
  from scipy.optimize import fmin_l_bfgs_b
5
 
6
  from .base import BasePredictor
 
21
  self.input_data = None
22
 
23
  def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
24
+ pos_clicks_map = np.zeros(
25
+ (len(clicks_lists), 1) + image_shape, dtype=np.float32
26
+ )
27
+ neg_clicks_map = np.zeros(
28
+ (len(clicks_lists), 1) + image_shape, dtype=np.float32
29
+ )
30
 
31
  for list_indx, clicks_list in enumerate(clicks_lists):
32
  for click in clicks_list:
 
47
  return pos_clicks_map, neg_clicks_map
48
 
49
  def get_states(self):
50
+ return {
51
+ "transform_states": self._get_transform_states(),
52
+ "opt_data": self.opt_data,
53
+ }
54
 
55
  def set_states(self, states):
56
+ self._set_transform_states(states["transform_states"])
57
+ self.opt_data = states["opt_data"]
58
 
59
 
60
  class FeatureBRSPredictor(BRSBasePredictor):
61
+ def __init__(
62
+ self, model, device, opt_functor, insertion_mode="after_deeplab", **kwargs
63
+ ):
64
  super().__init__(model, device, opt_functor=opt_functor, **kwargs)
65
  self.insertion_mode = insertion_mode
66
  self._c1_features = None
67
 
68
+ if self.insertion_mode == "after_deeplab":
69
  self.num_channels = model.feature_extractor.ch
70
+ elif self.insertion_mode == "after_c4":
71
  self.num_channels = model.feature_extractor.aspp_in_channels
72
+ elif self.insertion_mode == "after_aspp":
73
  self.num_channels = model.feature_extractor.ch + 32
74
  else:
75
  raise NotImplementedError
 
81
  num_clicks = len(clicks_lists[0])
82
  bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
83
 
84
+ if (
85
+ self.opt_data is None
86
+ or self.opt_data.shape[0] // (2 * self.num_channels) != bs
87
+ ):
88
  self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
89
 
90
+ if (
91
+ num_clicks <= self.net_clicks_limit
92
+ or is_image_changed
93
+ or self.input_data is None
94
+ ):
95
  self.input_data = self._get_head_input(image_nd, points_nd)
96
 
97
  def get_prediction_logits(scale, bias):
 
103
 
104
  scaled_backbone_features = self.input_data * scale
105
  scaled_backbone_features = scaled_backbone_features + bias
106
+ if self.insertion_mode == "after_c4":
107
  x = self.net.feature_extractor.aspp(scaled_backbone_features)
108
+ x = F.interpolate(
109
+ x,
110
+ mode="bilinear",
111
+ size=self._c1_features.size()[2:],
112
+ align_corners=True,
113
+ )
114
  x = torch.cat((x, self._c1_features), dim=1)
115
  scaled_backbone_features = self.net.feature_extractor.head(x)
116
+ elif self.insertion_mode == "after_aspp":
117
+ scaled_backbone_features = self.net.feature_extractor.head(
118
+ scaled_backbone_features
119
+ )
120
 
121
  pred_logits = self.net.head(scaled_backbone_features)
122
+ pred_logits = F.interpolate(
123
+ pred_logits,
124
+ size=image_nd.size()[2:],
125
+ mode="bilinear",
126
+ align_corners=True,
127
+ )
128
  return pred_logits
129
 
130
+ self.opt_functor.init_click(
131
+ get_prediction_logits, pos_mask, neg_mask, self.device
132
+ )
133
  if num_clicks > self.optimize_after_n_clicks:
134
+ opt_result = fmin_l_bfgs_b(
135
+ func=self.opt_functor,
136
+ x0=self.opt_data,
137
+ **self.opt_functor.optimizer_params
138
+ )
139
  self.opt_data = opt_result[0]
140
 
141
  with torch.no_grad():
 
156
  if self.net.rgb_conv is not None:
157
  x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
158
  additional_features = None
159
+ elif hasattr(self.net, "maps_transform"):
160
  x = image_nd
161
  additional_features = self.net.maps_transform(coord_features)
162
 
163
+ if self.insertion_mode == "after_c4" or self.insertion_mode == "after_aspp":
164
+ c1, _, c3, c4 = self.net.feature_extractor.backbone(
165
+ x, additional_features
166
+ )
167
  c1 = self.net.feature_extractor.skip_project(c1)
168
 
169
+ if self.insertion_mode == "after_aspp":
170
  x = self.net.feature_extractor.aspp(c4)
171
+ x = F.interpolate(
172
+ x, size=c1.size()[2:], mode="bilinear", align_corners=True
173
+ )
174
  x = torch.cat((x, c1), dim=1)
175
  backbone_features = x
176
  else:
177
  backbone_features = c4
178
  self._c1_features = c1
179
  else:
180
+ backbone_features = self.net.feature_extractor(x, additional_features)[
181
+ 0
182
+ ]
183
 
184
  return backbone_features
185
 
186
 
187
  class HRNetFeatureBRSPredictor(BRSBasePredictor):
188
+ def __init__(self, model, device, opt_functor, insertion_mode="A", **kwargs):
189
  super().__init__(model, device, opt_functor=opt_functor, **kwargs)
190
  self.insertion_mode = insertion_mode
191
  self._c1_features = None
192
 
193
+ if self.insertion_mode == "A":
194
+ self.num_channels = sum(
195
+ k * model.feature_extractor.width for k in [1, 2, 4, 8]
196
+ )
197
+ elif self.insertion_mode == "C":
198
  self.num_channels = 2 * model.feature_extractor.ocr_width
199
  else:
200
  raise NotImplementedError
 
205
  num_clicks = len(clicks_lists[0])
206
  bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
207
 
208
+ if (
209
+ self.opt_data is None
210
+ or self.opt_data.shape[0] // (2 * self.num_channels) != bs
211
+ ):
212
  self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
213
 
214
+ if (
215
+ num_clicks <= self.net_clicks_limit
216
+ or is_image_changed
217
+ or self.input_data is None
218
+ ):
219
  self.input_data = self._get_head_input(image_nd, points_nd)
220
 
221
  def get_prediction_logits(scale, bias):
 
227
 
228
  scaled_backbone_features = self.input_data * scale
229
  scaled_backbone_features = scaled_backbone_features + bias
230
+ if self.insertion_mode == "A":
231
  if self.net.feature_extractor.ocr_width > 0:
232
+ out_aux = self.net.feature_extractor.aux_head(
233
+ scaled_backbone_features
234
+ )
235
+ feats = self.net.feature_extractor.conv3x3_ocr(
236
+ scaled_backbone_features
237
+ )
238
 
239
  context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
240
  feats = self.net.feature_extractor.ocr_distri_head(feats, context)
241
  else:
242
  feats = scaled_backbone_features
243
  pred_logits = self.net.feature_extractor.cls_head(feats)
244
+ elif self.insertion_mode == "C":
245
+ pred_logits = self.net.feature_extractor.cls_head(
246
+ scaled_backbone_features
247
+ )
248
  else:
249
  raise NotImplementedError
250
 
251
+ pred_logits = F.interpolate(
252
+ pred_logits,
253
+ size=image_nd.size()[2:],
254
+ mode="bilinear",
255
+ align_corners=True,
256
+ )
257
  return pred_logits
258
 
259
+ self.opt_functor.init_click(
260
+ get_prediction_logits, pos_mask, neg_mask, self.device
261
+ )
262
  if num_clicks > self.optimize_after_n_clicks:
263
+ opt_result = fmin_l_bfgs_b(
264
+ func=self.opt_functor,
265
+ x0=self.opt_data,
266
+ **self.opt_functor.optimizer_params
267
+ )
268
  self.opt_data = opt_result[0]
269
 
270
  with torch.no_grad():
 
285
  if self.net.rgb_conv is not None:
286
  x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
287
  additional_features = None
288
+ elif hasattr(self.net, "maps_transform"):
289
  x = image_nd
290
  additional_features = self.net.maps_transform(coord_features)
291
 
292
+ feats = self.net.feature_extractor.compute_hrnet_feats(
293
+ x, additional_features
294
+ )
295
 
296
+ if self.insertion_mode == "A":
297
  backbone_features = feats
298
+ elif self.insertion_mode == "C":
299
  out_aux = self.net.feature_extractor.aux_head(feats)
300
  feats = self.net.feature_extractor.conv3x3_ocr(feats)
301
 
302
  context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
303
+ backbone_features = self.net.feature_extractor.ocr_distri_head(
304
+ feats, context
305
+ )
306
  else:
307
  raise NotImplementedError
308
 
 
310
 
311
 
312
  class InputBRSPredictor(BRSBasePredictor):
313
+ def __init__(self, model, device, opt_functor, optimize_target="rgb", **kwargs):
314
  super().__init__(model, device, opt_functor=opt_functor, **kwargs)
315
  self.optimize_target = optimize_target
316
 
 
320
  num_clicks = len(clicks_lists[0])
321
 
322
  if self.opt_data is None or is_image_changed:
323
+ if self.optimize_target == "dmaps":
324
+ opt_channels = (
325
+ self.net.coord_feature_ch - 1
326
+ if self.net.with_prev_mask
327
+ else self.net.coord_feature_ch
328
+ )
329
  else:
330
  opt_channels = 3
331
  bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
332
+ self.opt_data = torch.zeros(
333
+ (bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
334
+ device=self.device,
335
+ dtype=torch.float32,
336
+ )
337
 
338
  def get_prediction_logits(opt_bias):
339
  input_image, prev_mask = self.net.prepare_input(image_nd)
340
  dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)
341
 
342
+ if self.optimize_target == "rgb":
343
  input_image = input_image + opt_bias
344
+ elif self.optimize_target == "dmaps":
345
  if self.net.with_prev_mask:
346
  dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
347
  else:
 
349
 
350
  if self.net.rgb_conv is not None:
351
  x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
352
+ if self.optimize_target == "all":
353
  x = x + opt_bias
354
  coord_features = None
355
+ elif hasattr(self.net, "maps_transform"):
356
  x = input_image
357
  coord_features = self.net.maps_transform(dmaps)
358
 
359
+ pred_logits = self.net.backbone_forward(x, coord_features=coord_features)[
360
+ "instances"
361
+ ]
362
+ pred_logits = F.interpolate(
363
+ pred_logits,
364
+ size=image_nd.size()[2:],
365
+ mode="bilinear",
366
+ align_corners=True,
367
+ )
368
 
369
  return pred_logits
370
 
371
+ self.opt_functor.init_click(
372
+ get_prediction_logits,
373
+ pos_mask,
374
+ neg_mask,
375
+ self.device,
376
+ shape=self.opt_data.shape,
377
+ )
378
  if num_clicks > self.optimize_after_n_clicks:
379
+ opt_result = fmin_l_bfgs_b(
380
+ func=self.opt_functor,
381
+ x0=self.opt_data.cpu().numpy().ravel(),
382
+ **self.opt_functor.optimizer_params
383
+ )
384
+
385
+ self.opt_data = (
386
+ torch.from_numpy(opt_result[0])
387
+ .view(self.opt_data.shape)
388
+ .to(self.device)
389
+ )
390
 
391
  with torch.no_grad():
392
  if self.opt_functor.best_prediction is not None:
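
Note: the feature-space predictors (FeatureBRSPredictor, HRNetFeatureBRSPredictor) cache the backbone output in _get_head_input and re-evaluate only the light head inside the L-BFGS-B objective; the flat optimization variable allocated above packs one scale and one bias per feature channel for every image in the (possibly flipped) batch, and is later split into halves by the optimizer functor. A sketch of that layout with hypothetical sizes:

    import numpy as np

    bs, num_channels = 1, 256   # hypothetical batch size and channel count
    opt_data = np.zeros(bs * 2 * num_channels, dtype=np.float32)
    scale_flat, bias_flat = np.split(opt_data, 2)   # first half scale, second half bias
    assert scale_flat.size == bs * num_channels
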
isegm/inference/predictors/brs_functors.py CHANGED
@@ -1,19 +1,23 @@
1
- import torch
2
  import numpy as np
 
3
 
4
  from isegm.model.metrics import _compute_iou
 
5
  from .brs_losses import BRSMaskLoss
6
 
7
 
8
  class BaseOptimizer:
9
- def __init__(self, optimizer_params,
10
- prob_thresh=0.49,
11
- reg_weight=1e-3,
12
- min_iou_diff=0.01,
13
- brs_loss=BRSMaskLoss(),
14
- with_flip=False,
15
- flip_average=False,
16
- **kwargs):
 
 
 
17
  self.brs_loss = brs_loss
18
  self.optimizer_params = optimizer_params
19
  self.prob_thresh = prob_thresh
@@ -51,7 +55,10 @@ class BaseOptimizer:
51
  if self.with_flip and self.flip_average:
52
  result, result_flipped = torch.chunk(result, 2, dim=0)
53
  result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
54
- pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]
 
 
 
55
 
56
  loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
57
  loss = loss + reg_loss
@@ -99,11 +106,13 @@ class ScaleBiasOptimizer(BaseOptimizer):
99
 
100
  def unpack_opt_params(self, opt_params):
101
  scale, bias = torch.chunk(opt_params, 2, dim=0)
102
- reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
 
 
103
 
104
- if self.scale_act == 'tanh':
105
  scale = torch.tanh(scale)
106
- elif self.scale_act == 'sin':
107
  scale = torch.sin(scale)
108
 
109
  return (1 + scale, bias), reg_loss
 
 
1
  import numpy as np
2
+ import torch
3
 
4
  from isegm.model.metrics import _compute_iou
5
+
6
  from .brs_losses import BRSMaskLoss
7
 
8
 
9
  class BaseOptimizer:
10
+ def __init__(
11
+ self,
12
+ optimizer_params,
13
+ prob_thresh=0.49,
14
+ reg_weight=1e-3,
15
+ min_iou_diff=0.01,
16
+ brs_loss=BRSMaskLoss(),
17
+ with_flip=False,
18
+ flip_average=False,
19
+ **kwargs
20
+ ):
21
  self.brs_loss = brs_loss
22
  self.optimizer_params = optimizer_params
23
  self.prob_thresh = prob_thresh
 
55
  if self.with_flip and self.flip_average:
56
  result, result_flipped = torch.chunk(result, 2, dim=0)
57
  result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
58
+ pos_mask, neg_mask = (
59
+ pos_mask[: result.shape[0]],
60
+ neg_mask[: result.shape[0]],
61
+ )
62
 
63
  loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
64
  loss = loss + reg_loss
 
106
 
107
  def unpack_opt_params(self, opt_params):
108
  scale, bias = torch.chunk(opt_params, 2, dim=0)
109
+ reg_loss = self.reg_weight * (
110
+ torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)
111
+ )
112
 
113
+ if self.scale_act == "tanh":
114
  scale = torch.tanh(scale)
115
+ elif self.scale_act == "sin":
116
  scale = torch.sin(scale)
117
 
118
  return (1 + scale, bias), reg_loss
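
Note: unpack_opt_params parameterizes the per-channel transform as (1 + scale_act(scale), bias), so the all-zero vector handed to L-BFGS-B corresponds to an identity transform of the cached features. A toy check of that property:

    import torch

    opt_params = torch.zeros(8)
    scale, bias = torch.chunk(opt_params, 2, dim=0)
    scale = torch.tanh(scale)   # the 'tanh' scale activation
    scale, bias = 1 + scale, bias
    assert torch.all(scale == 1) and torch.all(bias == 0)
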
isegm/inference/predictors/brs_losses.py CHANGED
@@ -10,13 +10,13 @@ class BRSMaskLoss(torch.nn.Module):
10
 
11
  def forward(self, result, pos_mask, neg_mask):
12
  pos_diff = (1 - result) * pos_mask
13
- pos_target = torch.sum(pos_diff ** 2)
14
  pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
15
 
16
  neg_diff = result * neg_mask
17
- neg_target = torch.sum(neg_diff ** 2)
18
  neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
19
-
20
  loss = pos_target + neg_target
21
 
22
  with torch.no_grad():
@@ -42,8 +42,10 @@ class OracleMaskLoss(torch.nn.Module):
42
  gt_mask = self.gt_mask.to(result.device)
43
  if self.predictor.object_roi is not None:
44
  r1, r2, c1, c2 = self.predictor.object_roi[:4]
45
- gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
46
- gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
 
 
47
 
48
  if result.shape[0] == 2:
49
  gt_mask_flipped = torch.flip(gt_mask, dims=[3])
 
10
 
11
  def forward(self, result, pos_mask, neg_mask):
12
  pos_diff = (1 - result) * pos_mask
13
+ pos_target = torch.sum(pos_diff**2)
14
  pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
15
 
16
  neg_diff = result * neg_mask
17
+ neg_target = torch.sum(neg_diff**2)
18
  neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
19
+
20
  loss = pos_target + neg_target
21
 
22
  with torch.no_grad():
 
42
  gt_mask = self.gt_mask.to(result.device)
43
  if self.predictor.object_roi is not None:
44
  r1, r2, c1, c2 = self.predictor.object_roi[:4]
45
+ gt_mask = gt_mask[:, :, r1 : r2 + 1, c1 : c2 + 1]
46
+ gt_mask = torch.nn.functional.interpolate(
47
+ gt_mask, result.size()[2:], mode="bilinear", align_corners=True
48
+ )
49
 
50
  if result.shape[0] == 2:
51
  gt_mask_flipped = torch.flip(gt_mask, dims=[3])
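
Note: BRSMaskLoss penalizes disagreement with the clicked pixels only: the squared deficit under positive clicks plus the squared response under negative clicks, each normalized by the click-map mass. A toy check in which a prediction that satisfies both clicks yields (near-)zero loss:

    import torch
    from isegm.inference.predictors.brs_losses import BRSMaskLoss

    loss_fn = BRSMaskLoss()
    result = torch.tensor([[[[1.0, 0.0]]]])     # predicted probabilities
    pos_mask = torch.tensor([[[[1.0, 0.0]]]])   # positive-click map
    neg_mask = torch.tensor([[[[0.0, 1.0]]]])   # negative-click map
    loss, f_max_pos, f_max_neg = loss_fn(result, pos_mask, neg_mask)
    assert loss.item() < 1e-6
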
isegm/inference/transforms/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .base import SigmoidForPred
 
2
  from .flip import AddHorizontalFlip
3
- from .zoom_in import ZoomIn
4
  from .limit_longest_side import LimitLongestSide
5
- from .crops import Crops
 
1
  from .base import SigmoidForPred
2
+ from .crops import Crops
3
  from .flip import AddHorizontalFlip
 
4
  from .limit_longest_side import LimitLongestSide
5
+ from .zoom_in import ZoomIn
isegm/inference/transforms/crops.py CHANGED
@@ -1,10 +1,11 @@
1
  import math
 
2
 
3
- import torch
4
  import numpy as np
5
- from typing import List
6
 
7
  from isegm.inference.clicker import Click
 
8
  from .base import BaseTransform
9
 
10
 
@@ -33,17 +34,24 @@ class Crops(BaseTransform):
33
  image_crops = []
34
  for dy in self.y_offsets:
35
  for dx in self.x_offsets:
36
- self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
37
- image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
 
 
38
  image_crops.append(image_crop)
39
  image_crops = torch.cat(image_crops, dim=0)
40
- self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
 
 
41
 
42
  clicks_list = clicks_lists[0]
43
  clicks_lists = []
44
  for dy in self.y_offsets:
45
  for dx in self.x_offsets:
46
- crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list]
 
 
 
47
  clicks_lists.append(crop_clicks)
48
 
49
  return image_crops, clicks_lists
@@ -52,13 +60,16 @@ class Crops(BaseTransform):
52
  if self._counts is None:
53
  return prob_map
54
 
55
- new_prob_map = torch.zeros((1, 1, *self._counts.shape),
56
- dtype=prob_map.dtype, device=prob_map.device)
 
57
 
58
  crop_indx = 0
59
  for dy in self.y_offsets:
60
  for dx in self.x_offsets:
61
- new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
 
 
62
  crop_indx += 1
63
  new_prob_map = torch.div(new_prob_map, self._counts)
64
 
 
1
  import math
2
+ from typing import List
3
 
 
4
  import numpy as np
5
+ import torch
6
 
7
  from isegm.inference.clicker import Click
8
+
9
  from .base import BaseTransform
10
 
11
 
 
34
  image_crops = []
35
  for dy in self.y_offsets:
36
  for dx in self.x_offsets:
37
+ self._counts[dy : dy + self.crop_height, dx : dx + self.crop_width] += 1
38
+ image_crop = image_nd[
39
+ :, :, dy : dy + self.crop_height, dx : dx + self.crop_width
40
+ ]
41
  image_crops.append(image_crop)
42
  image_crops = torch.cat(image_crops, dim=0)
43
+ self._counts = torch.tensor(
44
+ self._counts, device=image_nd.device, dtype=torch.float32
45
+ )
46
 
47
  clicks_list = clicks_lists[0]
48
  clicks_lists = []
49
  for dy in self.y_offsets:
50
  for dx in self.x_offsets:
51
+ crop_clicks = [
52
+ x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx))
53
+ for x in clicks_list
54
+ ]
55
  clicks_lists.append(crop_clicks)
56
 
57
  return image_crops, clicks_lists
 
60
  if self._counts is None:
61
  return prob_map
62
 
63
+ new_prob_map = torch.zeros(
64
+ (1, 1, *self._counts.shape), dtype=prob_map.dtype, device=prob_map.device
65
+ )
66
 
67
  crop_indx = 0
68
  for dy in self.y_offsets:
69
  for dx in self.x_offsets:
70
+ new_prob_map[
71
+ 0, 0, dy : dy + self.crop_height, dx : dx + self.crop_width
72
+ ] += prob_map[crop_indx, 0]
73
  crop_indx += 1
74
  new_prob_map = torch.div(new_prob_map, self._counts)
75
 
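
Note: Crops tiles the input into overlapping fixed-size crops, predicts each crop independently, and in inv_transform averages the overlapping probability maps by dividing the accumulated sum by a per-pixel coverage count. A toy illustration of that averaging step (all values are hypothetical):

    import torch

    counts = torch.tensor([[1.0, 2.0],
                           [1.0, 2.0]])   # how many crops covered each pixel
    summed = torch.tensor([[0.6, 1.4],
                           [0.2, 1.0]])   # accumulated crop predictions
    averaged = torch.div(summed, counts)  # per-pixel mean over covering crops
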
isegm/inference/transforms/flip.py CHANGED
@@ -1,7 +1,9 @@
 
 
1
  import torch
2
 
3
- from typing import List
4
  from isegm.inference.clicker import Click
 
5
  from .base import BaseTransform
6
 
7
 
@@ -13,8 +15,10 @@ class AddHorizontalFlip(BaseTransform):
13
  image_width = image_nd.shape[3]
14
  clicks_lists_flipped = []
15
  for clicks_list in clicks_lists:
16
- clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
17
- for click in clicks_list]
 
 
18
  clicks_lists_flipped.append(clicks_list_flipped)
19
  clicks_lists = clicks_lists + clicks_lists_flipped
20
 
 
1
+ from typing import List
2
+
3
  import torch
4
 
 
5
  from isegm.inference.clicker import Click
6
+
7
  from .base import BaseTransform
8
 
9
 
 
15
  image_width = image_nd.shape[3]
16
  clicks_lists_flipped = []
17
  for clicks_list in clicks_lists:
18
+ clicks_list_flipped = [
19
+ click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
20
+ for click in clicks_list
21
+ ]
22
  clicks_lists_flipped.append(clicks_list_flipped)
23
  clicks_lists = clicks_lists + clicks_lists_flipped
24
 
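
Note: AddHorizontalFlip doubles the batch with a mirrored copy, so each click must be mirrored as well; clicks store (row, col) coordinates, hence only the column is reflected. A one-line check of that convention (hypothetical image width and click):

    image_width = 640
    coords = (120, 30)                                   # (row, col)
    flipped = (coords[0], image_width - coords[1] - 1)   # -> (120, 609)
    assert flipped == (120, 609)
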
isegm/inference/transforms/zoom_in.py CHANGED
@@ -1,19 +1,24 @@
 
 
1
  import torch
2
 
3
- from typing import List
4
  from isegm.inference.clicker import Click
5
- from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
 
 
6
  from .base import BaseTransform
7
 
8
 
9
  class ZoomIn(BaseTransform):
10
- def __init__(self,
11
- target_size=400,
12
- skip_clicks=1,
13
- expansion_ratio=1.4,
14
- min_crop_size=200,
15
- recompute_thresh_iou=0.5,
16
- prob_thresh=0.50):
 
 
17
  super().__init__()
18
  self.target_size = target_size
19
  self.min_crop_size = min_crop_size
@@ -41,8 +46,12 @@ class ZoomIn(BaseTransform):
41
  if self._prev_probs is not None:
42
  current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
43
  if current_pred_mask.sum() > 0:
44
- current_object_roi = get_object_roi(current_pred_mask, clicks_list,
45
- self.expansion_ratio, self.min_crop_size)
 
 
 
 
46
 
47
  if current_object_roi is None:
48
  if self.skip_clicks >= 0:
@@ -55,7 +64,10 @@ class ZoomIn(BaseTransform):
55
  update_object_roi = True
56
  elif not check_object_roi(self._object_roi, clicks_list):
57
  update_object_roi = True
58
- elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
 
 
 
59
  update_object_roi = True
60
 
61
  if update_object_roi:
@@ -73,12 +85,18 @@ class ZoomIn(BaseTransform):
73
 
74
  assert prob_map.shape[0] == 1
75
  rmin, rmax, cmin, cmax = self._object_roi
76
- prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
77
- mode='bilinear', align_corners=True)
 
 
 
 
78
 
79
  if self._prev_probs is not None:
80
- new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
81
- new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
 
 
82
  else:
83
  new_prob_map = prob_map
84
 
@@ -87,24 +105,46 @@ class ZoomIn(BaseTransform):
87
  return new_prob_map
88
 
89
  def check_possible_recalculation(self):
90
- if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
 
 
 
 
91
  return False
92
 
93
  pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
94
  if pred_mask.sum() > 0:
95
- possible_object_roi = get_object_roi(pred_mask, [],
96
- self.expansion_ratio, self.min_crop_size)
97
- image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
 
 
 
 
 
 
98
  if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
99
  return True
100
  return False
101
 
102
  def get_state(self):
103
  roi_image = self._roi_image.cpu() if self._roi_image is not None else None
104
- return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
105
 
106
  def set_state(self, state):
107
- self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
 
 
 
 
 
 
108
 
109
  def reset(self):
110
  self._input_image_shape = None
@@ -157,9 +197,13 @@ def get_roi_image_nd(image_nd, object_roi, target_size):
157
  new_width = int(round(width * scale))
158
 
159
  with torch.no_grad():
160
- roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
161
- roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
162
- mode='bilinear', align_corners=True)
 
 
 
 
163
 
164
  return roi_image_nd
165
 
 
1
+ from typing import List
2
+
3
  import torch
4
 
 
5
  from isegm.inference.clicker import Click
6
+ from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask,
7
+ get_bbox_iou)
8
+
9
  from .base import BaseTransform
10
 
11
 
12
  class ZoomIn(BaseTransform):
13
+ def __init__(
14
+ self,
15
+ target_size=400,
16
+ skip_clicks=1,
17
+ expansion_ratio=1.4,
18
+ min_crop_size=200,
19
+ recompute_thresh_iou=0.5,
20
+ prob_thresh=0.50,
21
+ ):
22
  super().__init__()
23
  self.target_size = target_size
24
  self.min_crop_size = min_crop_size
 
46
  if self._prev_probs is not None:
47
  current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
48
  if current_pred_mask.sum() > 0:
49
+ current_object_roi = get_object_roi(
50
+ current_pred_mask,
51
+ clicks_list,
52
+ self.expansion_ratio,
53
+ self.min_crop_size,
54
+ )
55
 
56
  if current_object_roi is None:
57
  if self.skip_clicks >= 0:
 
64
  update_object_roi = True
65
  elif not check_object_roi(self._object_roi, clicks_list):
66
  update_object_roi = True
67
+ elif (
68
+ get_bbox_iou(current_object_roi, self._object_roi)
69
+ < self.recompute_thresh_iou
70
+ ):
71
  update_object_roi = True
72
 
73
  if update_object_roi:
 
85
 
86
  assert prob_map.shape[0] == 1
87
  rmin, rmax, cmin, cmax = self._object_roi
88
+ prob_map = torch.nn.functional.interpolate(
89
+ prob_map,
90
+ size=(rmax - rmin + 1, cmax - cmin + 1),
91
+ mode="bilinear",
92
+ align_corners=True,
93
+ )
94
 
95
  if self._prev_probs is not None:
96
+ new_prob_map = torch.zeros(
97
+ *self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype
98
+ )
99
+ new_prob_map[:, :, rmin : rmax + 1, cmin : cmax + 1] = prob_map
100
  else:
101
  new_prob_map = prob_map
102
 
 
105
  return new_prob_map
106
 
107
  def check_possible_recalculation(self):
108
+ if (
109
+ self._prev_probs is None
110
+ or self._object_roi is not None
111
+ or self.skip_clicks > 0
112
+ ):
113
  return False
114
 
115
  pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
116
  if pred_mask.sum() > 0:
117
+ possible_object_roi = get_object_roi(
118
+ pred_mask, [], self.expansion_ratio, self.min_crop_size
119
+ )
120
+ image_roi = (
121
+ 0,
122
+ self._input_image_shape[2] - 1,
123
+ 0,
124
+ self._input_image_shape[3] - 1,
125
+ )
126
  if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
127
  return True
128
  return False
129
 
130
  def get_state(self):
131
  roi_image = self._roi_image.cpu() if self._roi_image is not None else None
132
+ return (
133
+ self._input_image_shape,
134
+ self._object_roi,
135
+ self._prev_probs,
136
+ roi_image,
137
+ self.image_changed,
138
+ )
139
 
140
  def set_state(self, state):
141
+ (
142
+ self._input_image_shape,
143
+ self._object_roi,
144
+ self._prev_probs,
145
+ self._roi_image,
146
+ self.image_changed,
147
+ ) = state
148
 
149
  def reset(self):
150
  self._input_image_shape = None
 
197
  new_width = int(round(width * scale))
198
 
199
  with torch.no_grad():
200
+ roi_image_nd = image_nd[:, :, rmin : rmax + 1, cmin : cmax + 1]
201
+ roi_image_nd = torch.nn.functional.interpolate(
202
+ roi_image_nd,
203
+ size=(new_height, new_width),
204
+ mode="bilinear",
205
+ align_corners=True,
206
+ )
207
 
208
  return roi_image_nd
209
 
isegm/inference/utils.py CHANGED
@@ -1,10 +1,11 @@
1
  from datetime import timedelta
2
  from pathlib import Path
3
 
4
- import torch
5
  import numpy as np
 
6
 
7
- from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset
 
8
  from isegm.utils.serialization import load_model
9
 
10
 
@@ -20,7 +21,7 @@ def get_time_metrics(all_ious, elapsed_time):
20
 
21
  def load_is_model(checkpoint, device, **kwargs):
22
  if isinstance(checkpoint, (str, Path)):
23
- state_dict = torch.load(checkpoint, map_location='cpu')
24
  else:
25
  state_dict = checkpoint
26
 
@@ -34,8 +35,8 @@ def load_is_model(checkpoint, device, **kwargs):
34
 
35
 
36
  def load_single_is_model(state_dict, device, **kwargs):
37
- model = load_model(state_dict['config'], **kwargs)
38
- model.load_state_dict(state_dict['state_dict'], strict=False)
39
 
40
  for param in model.parameters():
41
  param.requires_grad = False
@@ -46,19 +47,19 @@ def load_single_is_model(state_dict, device, **kwargs):
46
 
47
 
48
  def get_dataset(dataset_name, cfg):
49
- if dataset_name == 'GrabCut':
50
  dataset = GrabCutDataset(cfg.GRABCUT_PATH)
51
- elif dataset_name == 'Berkeley':
52
  dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
53
- elif dataset_name == 'DAVIS':
54
  dataset = DavisDataset(cfg.DAVIS_PATH)
55
- elif dataset_name == 'SBD':
56
  dataset = SBDEvaluationDataset(cfg.SBD_PATH)
57
- elif dataset_name == 'SBD_Train':
58
- dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train')
59
- elif dataset_name == 'PascalVOC':
60
- dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split='test')
61
- elif dataset_name == 'COCO_MVal':
62
  dataset = DavisDataset(cfg.COCO_MVAL_PATH)
63
  else:
64
  dataset = None
@@ -70,8 +71,12 @@ def get_iou(gt_mask, pred_mask, ignore_label=-1):
70
  ignore_gt_mask_inv = gt_mask != ignore_label
71
  obj_gt_mask = gt_mask == 1
72
 
73
- intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
74
- union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
 
 
 
 
75
 
76
  return intersection / union
77
 
@@ -84,8 +89,9 @@ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
84
  noc_list = []
85
  over_max_list = []
86
  for iou_thr in iou_thrs:
87
- scores_arr = np.array([_get_noc(iou_arr, iou_thr)
88
- for iou_arr in all_ious], dtype=np.int)
 
89
 
90
  score = scores_arr.mean()
91
  over_max = (scores_arr == max_clicks).sum()
@@ -98,46 +104,58 @@ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
98
 
99
  def find_checkpoint(weights_folder, checkpoint_name):
100
  weights_folder = Path(weights_folder)
101
- if ':' in checkpoint_name:
102
- model_name, checkpoint_name = checkpoint_name.split(':')
103
- models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
 
 
104
  assert len(models_candidates) == 1
105
  model_folder = models_candidates[0]
106
  else:
107
  model_folder = weights_folder
108
 
109
- if checkpoint_name.endswith('.pth'):
110
  if Path(checkpoint_name).exists():
111
  checkpoint_path = checkpoint_name
112
  else:
113
  checkpoint_path = weights_folder / checkpoint_name
114
  else:
115
- model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
116
  assert len(model_checkpoints) == 1
117
  checkpoint_path = model_checkpoints[0]
118
 
119
  return str(checkpoint_path)
120
 
121
 
122
- def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
123
- n_clicks=20, model_name=None):
124
- table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
125
- f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
126
- f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
127
- f'{"SPC,s":^7}|{"Time":^9}|')
 
 
 
 
 
 
 
 
 
 
128
  row_width = len(table_header)
129
 
130
- header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
131
- header += '-' * row_width + '\n'
132
- header += table_header + '\n' + '-' * row_width
133
 
134
  eval_time = str(timedelta(seconds=int(elapsed_time)))
135
- table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
136
- table_row += f'{noc_list[0]:^9.2f}|'
137
- table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
138
- table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
139
- table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
140
- table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
141
- table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
142
-
143
- return header, table_row
 
1
  from datetime import timedelta
2
  from pathlib import Path
3
 
 
4
  import numpy as np
5
+ import torch
6
 
7
+ from isegm.data.datasets import (BerkeleyDataset, DavisDataset, GrabCutDataset,
8
+ PascalVocDataset, SBDEvaluationDataset)
9
  from isegm.utils.serialization import load_model
10
 
11
 
 
21
 
22
  def load_is_model(checkpoint, device, **kwargs):
23
  if isinstance(checkpoint, (str, Path)):
24
+ state_dict = torch.load(checkpoint, map_location="cpu")
25
  else:
26
  state_dict = checkpoint
27
 
 
35
 
36
 
37
  def load_single_is_model(state_dict, device, **kwargs):
38
+ model = load_model(state_dict["config"], **kwargs)
39
+ model.load_state_dict(state_dict["state_dict"], strict=False)
40
 
41
  for param in model.parameters():
42
  param.requires_grad = False
 
47
 
48
 
49
  def get_dataset(dataset_name, cfg):
50
+ if dataset_name == "GrabCut":
51
  dataset = GrabCutDataset(cfg.GRABCUT_PATH)
52
+ elif dataset_name == "Berkeley":
53
  dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
54
+ elif dataset_name == "DAVIS":
55
  dataset = DavisDataset(cfg.DAVIS_PATH)
56
+ elif dataset_name == "SBD":
57
  dataset = SBDEvaluationDataset(cfg.SBD_PATH)
58
+ elif dataset_name == "SBD_Train":
59
+ dataset = SBDEvaluationDataset(cfg.SBD_PATH, split="train")
60
+ elif dataset_name == "PascalVOC":
61
+ dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split="test")
62
+ elif dataset_name == "COCO_MVal":
63
  dataset = DavisDataset(cfg.COCO_MVAL_PATH)
64
  else:
65
  dataset = None
 
71
  ignore_gt_mask_inv = gt_mask != ignore_label
72
  obj_gt_mask = gt_mask == 1
73
 
74
+ intersection = np.logical_and(
75
+ np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv
76
+ ).sum()
77
+ union = np.logical_and(
78
+ np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv
79
+ ).sum()
80
 
81
  return intersection / union
82
 
 
89
  noc_list = []
90
  over_max_list = []
91
  for iou_thr in iou_thrs:
92
+ scores_arr = np.array(
93
+ [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=np.int
94
+ )
95
 
96
  score = scores_arr.mean()
97
  over_max = (scores_arr == max_clicks).sum()
 
104
 
105
  def find_checkpoint(weights_folder, checkpoint_name):
106
  weights_folder = Path(weights_folder)
107
+ if ":" in checkpoint_name:
108
+ model_name, checkpoint_name = checkpoint_name.split(":")
109
+ models_candidates = [
110
+ x for x in weights_folder.glob(f"{model_name}*") if x.is_dir()
111
+ ]
112
  assert len(models_candidates) == 1
113
  model_folder = models_candidates[0]
114
  else:
115
  model_folder = weights_folder
116
 
117
+ if checkpoint_name.endswith(".pth"):
118
  if Path(checkpoint_name).exists():
119
  checkpoint_path = checkpoint_name
120
  else:
121
  checkpoint_path = weights_folder / checkpoint_name
122
  else:
123
+ model_checkpoints = list(model_folder.rglob(f"{checkpoint_name}*.pth"))
124
  assert len(model_checkpoints) == 1
125
  checkpoint_path = model_checkpoints[0]
126
 
127
  return str(checkpoint_path)
128
 
129
 
130
+ def get_results_table(
131
+ noc_list,
132
+ over_max_list,
133
+ brs_type,
134
+ dataset_name,
135
+ mean_spc,
136
+ elapsed_time,
137
+ n_clicks=20,
138
+ model_name=None,
139
+ ):
140
+ table_header = (
141
+ f'|{"BRS Type":^13}|{"Dataset":^11}|'
142
+ f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
143
+ f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
144
+ f'{"SPC,s":^7}|{"Time":^9}|'
145
+ )
146
  row_width = len(table_header)
147
 
148
+ header = f"Eval results for model: {model_name}\n" if model_name is not None else ""
149
+ header += "-" * row_width + "\n"
150
+ header += table_header + "\n" + "-" * row_width
151
 
152
  eval_time = str(timedelta(seconds=int(elapsed_time)))
153
+ table_row = f"|{brs_type:^13}|{dataset_name:^11}|"
154
+ table_row += f"{noc_list[0]:^9.2f}|"
155
+ table_row += f"{noc_list[1]:^9.2f}|" if len(noc_list) > 1 else f'{"?":^9}|'
156
+ table_row += f"{noc_list[2]:^9.2f}|" if len(noc_list) > 2 else f'{"?":^9}|'
157
+ table_row += f"{over_max_list[1]:^9}|" if len(noc_list) > 1 else f'{"?":^9}|'
158
+ table_row += f"{over_max_list[2]:^9}|" if len(noc_list) > 2 else f'{"?":^9}|'
159
+ table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|"
160
+
161
+ return header, table_row
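For reference, the IoU reformatted above masks an ignore label out of both intersection and union; a toy, self-contained version with invented arrays is below. As a side note, `dtype=np.int` in `compute_noc_metric` is deprecated in recent NumPy releases, and plain `int` behaves the same.

```python
import numpy as np

def masked_iou(gt_mask, pred_mask, ignore_label=-1):
    valid = gt_mask != ignore_label            # pixels that count
    obj = gt_mask == 1                         # ground-truth object pixels
    intersection = np.logical_and(np.logical_and(pred_mask, obj), valid).sum()
    union = np.logical_and(np.logical_or(pred_mask, obj), valid).sum()
    return intersection / union

gt = np.array([[1, 1, 0], [0, -1, 0]])
pred = np.array([[1, 0, 0], [0, 1, 0]], dtype=bool)
print(masked_iou(gt, pred))  # 0.5
```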
isegm/model/initializer.py CHANGED
@@ -1,6 +1,6 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import numpy as np
4
 
5
 
6
  class Initializer(object):
@@ -9,24 +9,37 @@ class Initializer(object):
9
  self.gamma = gamma
10
 
11
  def __call__(self, m):
12
- if getattr(m, '__initialized', False):
13
  return
14
 
15
- if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
16
- nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
17
- nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
  if m.weight is not None:
19
  self._init_gamma(m.weight.data)
20
  if m.bias is not None:
21
  self._init_beta(m.bias.data)
22
  else:
23
- if getattr(m, 'weight', None) is not None:
24
  self._init_weight(m.weight.data)
25
- if getattr(m, 'bias', None) is not None:
26
  self._init_bias(m.bias.data)
27
 
28
  if self.local_init:
29
- object.__setattr__(m, '__initialized', True)
30
 
31
  def _init_weight(self, data):
32
  nn.init.uniform_(data, -0.07, 0.07)
@@ -71,13 +84,15 @@ class Bilinear(Initializer):
71
  center = scale - 0.5 * (1 + kernel_size % 2)
72
 
73
  og = np.ogrid[:kernel_size, :kernel_size]
74
- kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
 
 
75
 
76
  return torch.tensor(kernel, dtype=torch.float32)
77
 
78
 
79
  class XavierGluon(Initializer):
80
- def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
81
  super().__init__(**kwargs)
82
 
83
  self.rnd_type = rnd_type
@@ -87,19 +102,19 @@ class XavierGluon(Initializer):
87
  def _init_weight(self, arr):
88
  fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
89
 
90
- if self.factor_type == 'avg':
91
  factor = (fan_in + fan_out) / 2.0
92
- elif self.factor_type == 'in':
93
  factor = fan_in
94
- elif self.factor_type == 'out':
95
  factor = fan_out
96
  else:
97
- raise ValueError('Incorrect factor type')
98
  scale = np.sqrt(self.magnitude / factor)
99
 
100
- if self.rnd_type == 'uniform':
101
  nn.init.uniform_(arr, -scale, scale)
102
- elif self.rnd_type == 'gaussian':
103
  nn.init.normal_(arr, 0, scale)
104
  else:
105
- raise ValueError('Unknown random type')
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn as nn
 
4
 
5
 
6
  class Initializer(object):
 
9
  self.gamma = gamma
10
 
11
  def __call__(self, m):
12
+ if getattr(m, "__initialized", False):
13
  return
14
 
15
+ if (
16
+ isinstance(
17
+ m,
18
+ (
19
+ nn.BatchNorm1d,
20
+ nn.BatchNorm2d,
21
+ nn.BatchNorm3d,
22
+ nn.InstanceNorm1d,
23
+ nn.InstanceNorm2d,
24
+ nn.InstanceNorm3d,
25
+ nn.GroupNorm,
26
+ nn.SyncBatchNorm,
27
+ ),
28
+ )
29
+ or "BatchNorm" in m.__class__.__name__
30
+ ):
31
  if m.weight is not None:
32
  self._init_gamma(m.weight.data)
33
  if m.bias is not None:
34
  self._init_beta(m.bias.data)
35
  else:
36
+ if getattr(m, "weight", None) is not None:
37
  self._init_weight(m.weight.data)
38
+ if getattr(m, "bias", None) is not None:
39
  self._init_bias(m.bias.data)
40
 
41
  if self.local_init:
42
+ object.__setattr__(m, "__initialized", True)
43
 
44
  def _init_weight(self, data):
45
  nn.init.uniform_(data, -0.07, 0.07)
 
84
  center = scale - 0.5 * (1 + kernel_size % 2)
85
 
86
  og = np.ogrid[:kernel_size, :kernel_size]
87
+ kernel = (1 - np.abs(og[0] - center) / scale) * (
88
+ 1 - np.abs(og[1] - center) / scale
89
+ )
90
 
91
  return torch.tensor(kernel, dtype=torch.float32)
92
 
93
 
94
  class XavierGluon(Initializer):
95
+ def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3, **kwargs):
96
  super().__init__(**kwargs)
97
 
98
  self.rnd_type = rnd_type
 
102
  def _init_weight(self, arr):
103
  fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
104
 
105
+ if self.factor_type == "avg":
106
  factor = (fan_in + fan_out) / 2.0
107
+ elif self.factor_type == "in":
108
  factor = fan_in
109
+ elif self.factor_type == "out":
110
  factor = fan_out
111
  else:
112
+ raise ValueError("Incorrect factor type")
113
  scale = np.sqrt(self.magnitude / factor)
114
 
115
+ if self.rnd_type == "uniform":
116
  nn.init.uniform_(arr, -scale, scale)
117
+ elif self.rnd_type == "gaussian":
118
  nn.init.normal_(arr, 0, scale)
119
  else:
120
+ raise ValueError("Unknown random type")
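A quick numeric check of the bilinear kernel expression reformatted above. The `scale` line is not part of this hunk; the usual `(kernel_size + 1) // 2` is assumed here, so treat this as an illustration rather than repository code.

```python
import numpy as np

def bilinear_kernel(kernel_size):
    scale = (kernel_size + 1) // 2                  # assumed, not shown in the hunk
    center = scale - 0.5 * (1 + kernel_size % 2)
    og = np.ogrid[:kernel_size, :kernel_size]
    return (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)

# Outer product of [0.25, 0.75, 0.75, 0.25] with itself.
print(bilinear_kernel(4))
```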
isegm/model/is_deeplab_model.py CHANGED
@@ -1,25 +1,44 @@
1
  import torch.nn as nn
2
 
 
3
  from isegm.utils.serialization import serialize
 
4
  from .is_model import ISModel
5
- from .modeling.deeplab_v3 import DeepLabV3Plus
6
  from .modeling.basic_blocks import SepConvHead
7
- from isegm.model.modifiers import LRMult
8
 
9
 
10
  class DeeplabModel(ISModel):
11
  @serialize
12
- def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
13
- backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs):
 
 
 
 
 
 
 
 
14
  super().__init__(norm_layer=norm_layer, **kwargs)
15
 
16
- self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout,
17
- norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer)
 
 
 
 
 
18
  self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult))
19
- self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
20
- num_layers=2, norm_layer=norm_layer)
 
 
 
 
 
21
 
22
  def backbone_forward(self, image, coord_features=None):
23
  backbone_features = self.feature_extractor(image, coord_features)
24
 
25
- return {'instances': self.head(backbone_features[0])}
 
1
  import torch.nn as nn
2
 
3
+ from isegm.model.modifiers import LRMult
4
  from isegm.utils.serialization import serialize
5
+
6
  from .is_model import ISModel
 
7
  from .modeling.basic_blocks import SepConvHead
8
+ from .modeling.deeplab_v3 import DeepLabV3Plus
9
 
10
 
11
  class DeeplabModel(ISModel):
12
  @serialize
13
+ def __init__(
14
+ self,
15
+ backbone="resnet50",
16
+ deeplab_ch=256,
17
+ aspp_dropout=0.5,
18
+ backbone_norm_layer=None,
19
+ backbone_lr_mult=0.1,
20
+ norm_layer=nn.BatchNorm2d,
21
+ **kwargs
22
+ ):
23
  super().__init__(norm_layer=norm_layer, **kwargs)
24
 
25
+ self.feature_extractor = DeepLabV3Plus(
26
+ backbone=backbone,
27
+ ch=deeplab_ch,
28
+ project_dropout=aspp_dropout,
29
+ norm_layer=norm_layer,
30
+ backbone_norm_layer=backbone_norm_layer,
31
+ )
32
  self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult))
33
+ self.head = SepConvHead(
34
+ 1,
35
+ in_channels=deeplab_ch,
36
+ mid_channels=deeplab_ch // 2,
37
+ num_layers=2,
38
+ norm_layer=norm_layer,
39
+ )
40
 
41
  def backbone_forward(self, image, coord_features=None):
42
  backbone_features = self.feature_extractor(image, coord_features)
43
 
44
+ return {"instances": self.head(backbone_features[0])}
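DeeplabModel applies LRMult(backbone_lr_mult) so the backbone trains at a reduced learning rate. LRMult itself is not shown in this diff; the sketch below is a generic version of that pattern (tagging parameters with an lr_mult attribute via Module.apply), not the repository's exact implementation.

```python
import torch.nn as nn

class LRMultSketch:
    # Hypothetical stand-in for isegm.model.modifiers.LRMult.
    def __init__(self, lr_mult=1.0):
        self.lr_mult = lr_mult

    def __call__(self, module):
        for param in module.parameters():
            param.lr_mult = self.lr_mult  # tag every parameter of this module

backbone = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 16, 3))
backbone.apply(LRMultSketch(0.1))
print({name: p.lr_mult for name, p in backbone.named_parameters()})
```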
isegm/model/is_hrnet_model.py CHANGED
@@ -1,19 +1,32 @@
1
  import torch.nn as nn
2
 
 
3
  from isegm.utils.serialization import serialize
 
4
  from .is_model import ISModel
5
  from .modeling.hrnet_ocr import HighResolutionNet
6
- from isegm.model.modifiers import LRMult
7
 
8
 
9
  class HRNetModel(ISModel):
10
  @serialize
11
- def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1,
12
- norm_layer=nn.BatchNorm2d, **kwargs):
 
 
 
 
 
 
 
13
  super().__init__(norm_layer=norm_layer, **kwargs)
14
 
15
- self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
16
- num_classes=1, norm_layer=norm_layer)
 
 
 
 
 
17
  self.feature_extractor.apply(LRMult(backbone_lr_mult))
18
  if ocr_width > 0:
19
  self.feature_extractor.ocr_distri_head.apply(LRMult(1.0))
@@ -23,4 +36,4 @@ class HRNetModel(ISModel):
23
  def backbone_forward(self, image, coord_features=None):
24
  net_outputs = self.feature_extractor(image, coord_features)
25
 
26
- return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]}
 
1
  import torch.nn as nn
2
 
3
+ from isegm.model.modifiers import LRMult
4
  from isegm.utils.serialization import serialize
5
+
6
  from .is_model import ISModel
7
  from .modeling.hrnet_ocr import HighResolutionNet
 
8
 
9
 
10
  class HRNetModel(ISModel):
11
  @serialize
12
+ def __init__(
13
+ self,
14
+ width=48,
15
+ ocr_width=256,
16
+ small=False,
17
+ backbone_lr_mult=0.1,
18
+ norm_layer=nn.BatchNorm2d,
19
+ **kwargs
20
+ ):
21
  super().__init__(norm_layer=norm_layer, **kwargs)
22
 
23
+ self.feature_extractor = HighResolutionNet(
24
+ width=width,
25
+ ocr_width=ocr_width,
26
+ small=small,
27
+ num_classes=1,
28
+ norm_layer=norm_layer,
29
+ )
30
  self.feature_extractor.apply(LRMult(backbone_lr_mult))
31
  if ocr_width > 0:
32
  self.feature_extractor.ocr_distri_head.apply(LRMult(1.0))
 
36
  def backbone_forward(self, image, coord_features=None):
37
  net_outputs = self.feature_extractor(image, coord_features)
38
 
39
+ return {"instances": net_outputs[0], "instances_aux": net_outputs[1]}
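Continuing the previous sketch, per-parameter lr_mult tags are typically turned into optimizer parameter groups so the HRNet backbone runs at 0.1x the base learning rate. The helper below illustrates that idea and is not code from this repository; the usage line is commented out because it needs a real model.

```python
import torch

def make_param_groups(model, base_lr):
    groups = {}
    for param in model.parameters():
        mult = getattr(param, "lr_mult", 1.0)   # default: full learning rate
        groups.setdefault(mult, []).append(param)
    return [{"params": params, "lr": base_lr * mult} for mult, params in groups.items()]

# optimizer = torch.optim.Adam(make_param_groups(model, base_lr=5e-4))
```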
isegm/model/is_model.py CHANGED
@@ -1,17 +1,27 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import numpy as np
4
 
5
- from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize
6
  from isegm.model.modifiers import LRMult
 
7
 
8
 
9
  class ISModel(nn.Module):
10
- def __init__(self, use_rgb_conv=True, with_aux_output=False,
11
- norm_radius=260, use_disks=False, cpu_dist_maps=False,
12
- clicks_groups=None, with_prev_mask=False, use_leaky_relu=False,
13
- binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d,
14
- norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
15
  super().__init__()
16
  self.with_aux_output = with_aux_output
17
  self.clicks_groups = clicks_groups
@@ -28,35 +38,64 @@ class ISModel(nn.Module):
28
 
29
  if use_rgb_conv:
30
  rgb_conv_layers = [
31
- nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1),
 
 
 
 
32
  norm_layer(6 + self.coord_feature_ch),
33
- nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
34
- nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1)
 
 
 
 
35
  ]
36
  self.rgb_conv = nn.Sequential(*rgb_conv_layers)
37
  elif conv_extend:
38
  self.rgb_conv = None
39
- self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64,
40
- kernel_size=3, stride=2, padding=1)
 
 
 
 
 
41
  self.maps_transform.apply(LRMult(0.1))
42
  else:
43
  self.rgb_conv = None
44
  mt_layers = [
45
- nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1),
46
- nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
47
- nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
48
- ScaleLayer(init_value=0.05, lr_mult=1)
 
 
 
 
 
 
49
  ]
50
  self.maps_transform = nn.Sequential(*mt_layers)
51
 
52
  if self.clicks_groups is not None:
53
  self.dist_maps = nn.ModuleList()
54
  for click_radius in self.clicks_groups:
55
- self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0,
56
- cpu_mode=cpu_dist_maps, use_disks=use_disks))
 
 
 
 
 
 
57
  else:
58
- self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
59
- cpu_mode=cpu_dist_maps, use_disks=use_disks)
 
 
 
 
60
 
61
  def forward(self, image, points):
62
  image, prev_mask = self.prepare_input(image)
@@ -69,11 +108,19 @@ class ISModel(nn.Module):
69
  coord_features = self.maps_transform(coord_features)
70
  outputs = self.backbone_forward(image, coord_features)
71
 
72
- outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:],
73
- mode='bilinear', align_corners=True)
 
 
 
 
74
  if self.with_aux_output:
75
- outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:],
76
- mode='bilinear', align_corners=True)
 
 
 
 
77
 
78
  return outputs
79
 
@@ -93,8 +140,13 @@ class ISModel(nn.Module):
93
 
94
  def get_coord_features(self, image, prev_mask, points):
95
  if self.clicks_groups is not None:
96
- points_groups = split_points_by_order(points, groups=(2,) + (1, ) * (len(self.clicks_groups) - 2) + (-1,))
97
- coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)]
 
 
 
 
 
98
  coord_features = torch.cat(coord_features, dim=1)
99
  else:
100
  coord_features = self.dist_maps(image, points)
@@ -112,8 +164,7 @@ def split_points_by_order(tpoints: torch.Tensor, groups):
112
  num_points = points.shape[1] // 2
113
 
114
  groups = [x if x > 0 else num_points for x in groups]
115
- group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32)
116
- for x in groups]
117
 
118
  last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int)
119
  for group_indx, group_size in enumerate(groups):
@@ -127,7 +178,9 @@ def split_points_by_order(tpoints: torch.Tensor, groups):
127
  continue
128
 
129
  is_negative = int(pindx >= num_points)
130
- if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click
 
 
131
  group_id = num_groups - 1
132
 
133
  new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
@@ -135,7 +188,9 @@ def split_points_by_order(tpoints: torch.Tensor, groups):
135
 
136
  group_points[group_id][bindx, new_point_indx, :] = point
137
 
138
- group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
139
- for x in group_points]
 
 
140
 
141
  return group_points
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn as nn
 
4
 
 
5
  from isegm.model.modifiers import LRMult
6
+ from isegm.model.ops import BatchImageNormalize, DistMaps, ScaleLayer
7
 
8
 
9
  class ISModel(nn.Module):
10
+ def __init__(
11
+ self,
12
+ use_rgb_conv=True,
13
+ with_aux_output=False,
14
+ norm_radius=260,
15
+ use_disks=False,
16
+ cpu_dist_maps=False,
17
+ clicks_groups=None,
18
+ with_prev_mask=False,
19
+ use_leaky_relu=False,
20
+ binary_prev_mask=False,
21
+ conv_extend=False,
22
+ norm_layer=nn.BatchNorm2d,
23
+ norm_mean_std=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
+ ):
25
  super().__init__()
26
  self.with_aux_output = with_aux_output
27
  self.clicks_groups = clicks_groups
 
38
 
39
  if use_rgb_conv:
40
  rgb_conv_layers = [
41
+ nn.Conv2d(
42
+ in_channels=3 + self.coord_feature_ch,
43
+ out_channels=6 + self.coord_feature_ch,
44
+ kernel_size=1,
45
+ ),
46
  norm_layer(6 + self.coord_feature_ch),
47
+ nn.LeakyReLU(negative_slope=0.2)
48
+ if use_leaky_relu
49
+ else nn.ReLU(inplace=True),
50
+ nn.Conv2d(
51
+ in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1
52
+ ),
53
  ]
54
  self.rgb_conv = nn.Sequential(*rgb_conv_layers)
55
  elif conv_extend:
56
  self.rgb_conv = None
57
+ self.maps_transform = nn.Conv2d(
58
+ in_channels=self.coord_feature_ch,
59
+ out_channels=64,
60
+ kernel_size=3,
61
+ stride=2,
62
+ padding=1,
63
+ )
64
  self.maps_transform.apply(LRMult(0.1))
65
  else:
66
  self.rgb_conv = None
67
  mt_layers = [
68
+ nn.Conv2d(
69
+ in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1
70
+ ),
71
+ nn.LeakyReLU(negative_slope=0.2)
72
+ if use_leaky_relu
73
+ else nn.ReLU(inplace=True),
74
+ nn.Conv2d(
75
+ in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1
76
+ ),
77
+ ScaleLayer(init_value=0.05, lr_mult=1),
78
  ]
79
  self.maps_transform = nn.Sequential(*mt_layers)
80
 
81
  if self.clicks_groups is not None:
82
  self.dist_maps = nn.ModuleList()
83
  for click_radius in self.clicks_groups:
84
+ self.dist_maps.append(
85
+ DistMaps(
86
+ norm_radius=click_radius,
87
+ spatial_scale=1.0,
88
+ cpu_mode=cpu_dist_maps,
89
+ use_disks=use_disks,
90
+ )
91
+ )
92
  else:
93
+ self.dist_maps = DistMaps(
94
+ norm_radius=norm_radius,
95
+ spatial_scale=1.0,
96
+ cpu_mode=cpu_dist_maps,
97
+ use_disks=use_disks,
98
+ )
99
 
100
  def forward(self, image, points):
101
  image, prev_mask = self.prepare_input(image)
 
108
  coord_features = self.maps_transform(coord_features)
109
  outputs = self.backbone_forward(image, coord_features)
110
 
111
+ outputs["instances"] = nn.functional.interpolate(
112
+ outputs["instances"],
113
+ size=image.size()[2:],
114
+ mode="bilinear",
115
+ align_corners=True,
116
+ )
117
  if self.with_aux_output:
118
+ outputs["instances_aux"] = nn.functional.interpolate(
119
+ outputs["instances_aux"],
120
+ size=image.size()[2:],
121
+ mode="bilinear",
122
+ align_corners=True,
123
+ )
124
 
125
  return outputs
126
 
 
140
 
141
  def get_coord_features(self, image, prev_mask, points):
142
  if self.clicks_groups is not None:
143
+ points_groups = split_points_by_order(
144
+ points, groups=(2,) + (1,) * (len(self.clicks_groups) - 2) + (-1,)
145
+ )
146
+ coord_features = [
147
+ dist_map(image, pg)
148
+ for dist_map, pg in zip(self.dist_maps, points_groups)
149
+ ]
150
  coord_features = torch.cat(coord_features, dim=1)
151
  else:
152
  coord_features = self.dist_maps(image, points)
 
164
  num_points = points.shape[1] // 2
165
 
166
  groups = [x if x > 0 else num_points for x in groups]
167
+ group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) for x in groups]
 
168
 
169
  last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int)
170
  for group_indx, group_size in enumerate(groups):
 
178
  continue
179
 
180
  is_negative = int(pindx >= num_points)
181
+ if group_id >= num_groups or (
182
+ group_id == 0 and is_negative
183
+ ): # disable negative first click
184
  group_id = num_groups - 1
185
 
186
  new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
 
188
 
189
  group_points[group_id][bindx, new_point_indx, :] = point
190
 
191
+ group_points = [
192
+ torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
193
+ for x in group_points
194
+ ]
195
 
196
  return group_points
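split_points_by_order above relies on the click-tensor layout used throughout ISModel: shape (batch, 2 * max_points, 3), positives in the first half, negatives in the second half, rows of -1 as padding. The packing sketch below follows that layout; treating the third column as a simple click index is an approximation made for the example.

```python
import torch

def pack_clicks(pos, neg, max_points):
    rows = [[r, c, i] for i, (r, c) in enumerate(pos)]
    rows += [[-1, -1, -1]] * (max_points - len(pos))      # pad positives
    rows += [[r, c, len(pos) + i] for i, (r, c) in enumerate(neg)]
    rows += [[-1, -1, -1]] * (max_points - len(neg))      # pad negatives
    return torch.tensor([rows], dtype=torch.float32)

points = pack_clicks(pos=[(120, 240)], neg=[(50, 60)], max_points=2)
print(points.shape)  # torch.Size([1, 4, 3])
```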
isegm/model/losses.py CHANGED
@@ -7,10 +7,20 @@ from isegm.utils import misc
7
 
8
 
9
  class NormalizedFocalLossSigmoid(nn.Module):
10
- def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
11
- from_sigmoid=False, detach_delimeter=True,
12
- batch_axis=0, weight=None, size_average=True,
13
- ignore_label=-1):
  super(NormalizedFocalLossSigmoid, self).__init__()
15
  self._axis = axis
16
  self._alpha = alpha
@@ -34,8 +44,12 @@ class NormalizedFocalLossSigmoid(nn.Module):
34
  if not self._from_logits:
35
  pred = torch.sigmoid(pred)
36
 
37
- alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
38
- pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
 
 
 
 
39
 
40
  beta = (1 - pt) ** self._gamma
41
 
@@ -49,37 +63,69 @@ class NormalizedFocalLossSigmoid(nn.Module):
49
  beta = torch.clamp_max(beta, self._max_mult)
50
 
51
  with torch.no_grad():
52
- ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
53
- sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
 
 
 
 
 
 
54
  if np.any(ignore_area == 0):
55
- self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
 
 
56
 
57
  beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
58
  beta_pmax = beta_pmax.mean().item()
59
  self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax
60
 
61
- loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
62
  loss = self._weight * (loss * sample_weight)
63
 
64
  if self._size_average:
65
- bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
66
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
 
 
 
 
 
67
  else:
68
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
 
 
69
 
70
  return loss
71
 
72
  def log_states(self, sw, name, global_step):
73
- sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
74
- sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
75
 
76
 
77
  class FocalLoss(nn.Module):
78
- def __init__(self, axis=-1, alpha=0.25, gamma=2,
79
- from_logits=False, batch_axis=0,
80
- weight=None, num_class=None,
81
- eps=1e-9, size_average=True, scale=1.0,
82
- ignore_label=-1):
83
  super(FocalLoss, self).__init__()
84
  self._axis = axis
85
  self._alpha = alpha
@@ -101,19 +147,38 @@ class FocalLoss(nn.Module):
101
  if not self._from_logits:
102
  pred = torch.sigmoid(pred)
103
 
104
- alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
105
- pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
 
 
 
 
106
 
107
  beta = (1 - pt) ** self._gamma
108
 
109
- loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
 
 
 
 
 
 
 
 
110
  loss = self._weight * (loss * sample_weight)
111
 
112
  if self._size_average:
113
- tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
114
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
 
 
 
 
 
115
  else:
116
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
 
 
117
 
118
  return self._scale * loss
119
 
@@ -131,8 +196,9 @@ class SoftIoU(nn.Module):
131
  if not self._from_sigmoid:
132
  pred = torch.sigmoid(pred)
133
 
134
- loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \
135
- / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8)
 
136
 
137
  return loss
138
 
@@ -154,8 +220,12 @@ class SigmoidBinaryCrossEntropyLoss(nn.Module):
154
  loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
155
  else:
156
  eps = 1e-12
157
- loss = -(torch.log(pred + eps) * label
158
- + torch.log(1. - pred + eps) * (1. - label))
 
 
159
 
160
  loss = self._weight * (loss * sample_weight)
161
- return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
 
 
 
7
 
8
 
9
  class NormalizedFocalLossSigmoid(nn.Module):
10
+ def __init__(
11
+ self,
12
+ axis=-1,
13
+ alpha=0.25,
14
+ gamma=2,
15
+ max_mult=-1,
16
+ eps=1e-12,
17
+ from_sigmoid=False,
18
+ detach_delimeter=True,
19
+ batch_axis=0,
20
+ weight=None,
21
+ size_average=True,
22
+ ignore_label=-1,
23
+ ):
24
  super(NormalizedFocalLossSigmoid, self).__init__()
25
  self._axis = axis
26
  self._alpha = alpha
 
44
  if not self._from_logits:
45
  pred = torch.sigmoid(pred)
46
 
47
+ alpha = torch.where(
48
+ one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
49
+ )
50
+ pt = torch.where(
51
+ sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
52
+ )
53
 
54
  beta = (1 - pt) ** self._gamma
55
 
 
63
  beta = torch.clamp_max(beta, self._max_mult)
64
 
65
  with torch.no_grad():
66
+ ignore_area = (
67
+ torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim())))
68
+ .cpu()
69
+ .numpy()
70
+ )
71
+ sample_mult = (
72
+ torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
73
+ )
74
  if np.any(ignore_area == 0):
75
+ self._k_sum = (
76
+ 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
77
+ )
78
 
79
  beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
80
  beta_pmax = beta_pmax.mean().item()
81
  self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax
82
 
83
+ loss = (
84
+ -alpha
85
+ * beta
86
+ * torch.log(
87
+ torch.min(
88
+ pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
89
+ )
90
+ )
91
+ )
92
  loss = self._weight * (loss * sample_weight)
93
 
94
  if self._size_average:
95
+ bsum = torch.sum(
96
+ sample_weight,
97
+ dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis),
98
+ )
99
+ loss = torch.sum(
100
+ loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
101
+ ) / (bsum + self._eps)
102
  else:
103
+ loss = torch.sum(
104
+ loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
105
+ )
106
 
107
  return loss
108
 
109
  def log_states(self, sw, name, global_step):
110
+ sw.add_scalar(tag=name + "_k", value=self._k_sum, global_step=global_step)
111
+ sw.add_scalar(tag=name + "_m", value=self._m_max, global_step=global_step)
112
 
113
 
114
  class FocalLoss(nn.Module):
115
+ def __init__(
116
+ self,
117
+ axis=-1,
118
+ alpha=0.25,
119
+ gamma=2,
120
+ from_logits=False,
121
+ batch_axis=0,
122
+ weight=None,
123
+ num_class=None,
124
+ eps=1e-9,
125
+ size_average=True,
126
+ scale=1.0,
127
+ ignore_label=-1,
128
+ ):
129
  super(FocalLoss, self).__init__()
130
  self._axis = axis
131
  self._alpha = alpha
 
147
  if not self._from_logits:
148
  pred = torch.sigmoid(pred)
149
 
150
+ alpha = torch.where(
151
+ one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
152
+ )
153
+ pt = torch.where(
154
+ sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
155
+ )
156
 
157
  beta = (1 - pt) ** self._gamma
158
 
159
+ loss = (
160
+ -alpha
161
+ * beta
162
+ * torch.log(
163
+ torch.min(
164
+ pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
165
+ )
166
+ )
167
+ )
168
  loss = self._weight * (loss * sample_weight)
169
 
170
  if self._size_average:
171
+ tsum = torch.sum(
172
+ sample_weight,
173
+ dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis),
174
+ )
175
+ loss = torch.sum(
176
+ loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
177
+ ) / (tsum + self._eps)
178
  else:
179
+ loss = torch.sum(
180
+ loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
181
+ )
182
 
183
  return self._scale * loss
184
 
 
196
  if not self._from_sigmoid:
197
  pred = torch.sigmoid(pred)
198
 
199
+ loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) / (
200
+ torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8
201
+ )
202
 
203
  return loss
204
 
 
220
  loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
221
  else:
222
  eps = 1e-12
223
+ loss = -(
224
+ torch.log(pred + eps) * label
225
+ + torch.log(1.0 - pred + eps) * (1.0 - label)
226
+ )
227
 
228
  loss = self._weight * (loss * sample_weight)
229
+ return torch.mean(
230
+ loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
231
+ )
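A toy numeric check (not repository code) of the focal term reformatted above: with p_t = 1 - |label - pred|, the per-pixel loss is -alpha * (1 - p_t)**gamma * log(p_t + eps), so poorly classified pixels dominate.

```python
import torch

pred = torch.tensor([0.9, 0.3])   # sigmoid outputs for two pixels
label = torch.tensor([1.0, 1.0])  # both pixels belong to the object
alpha, gamma, eps = 0.25, 2.0, 1e-9

pt = 1.0 - torch.abs(label - pred)
loss = -alpha * (1 - pt) ** gamma * torch.log(pt + eps)
print(loss)  # the hard pixel (0.3) contributes far more than the easy one (0.9)
```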
isegm/model/metrics.py CHANGED
@@ -1,5 +1,5 @@
1
- import torch
2
  import numpy as np
 
3
 
4
  from isegm.utils import misc
5
 
@@ -27,9 +27,17 @@ class TrainMetric(object):
27
 
28
 
29
  class AdaptiveIoU(TrainMetric):
30
- def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
31
- ignore_label=-1, from_logits=True,
32
- pred_output='instances', gt_output='instances'):
 
 
 
 
 
 
 
 
33
  super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
34
  self._ignore_label = ignore_label
35
  self._from_logits = from_logits
@@ -59,7 +67,9 @@ class AdaptiveIoU(TrainMetric):
59
  max_iou = temp_iou
60
  best_thresh = t
61
 
62
- self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
 
 
63
  self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
64
  self._epoch_iou_sum += max_iou
65
  self._epoch_batch_count += 1
@@ -75,8 +85,14 @@ class AdaptiveIoU(TrainMetric):
75
  self._epoch_batch_count = 0
76
 
77
  def log_states(self, sw, tag_prefix, global_step):
78
- sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
79
- sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
 
 
 
 
 
 
80
 
81
  @property
82
  def iou_thresh(self):
@@ -88,8 +104,18 @@ def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
88
  pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
89
 
90
  reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
91
- union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
92
- intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
93
  nonzero = union > 0
94
 
95
  iou = intersection[nonzero] / union[nonzero]
 
 
1
  import numpy as np
2
+ import torch
3
 
4
  from isegm.utils import misc
5
 
 
27
 
28
 
29
  class AdaptiveIoU(TrainMetric):
30
+ def __init__(
31
+ self,
32
+ init_thresh=0.4,
33
+ thresh_step=0.025,
34
+ thresh_beta=0.99,
35
+ iou_beta=0.9,
36
+ ignore_label=-1,
37
+ from_logits=True,
38
+ pred_output="instances",
39
+ gt_output="instances",
40
+ ):
41
  super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
42
  self._ignore_label = ignore_label
43
  self._from_logits = from_logits
 
67
  max_iou = temp_iou
68
  best_thresh = t
69
 
70
+ self._iou_thresh = (
71
+ self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
72
+ )
73
  self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
74
  self._epoch_iou_sum += max_iou
75
  self._epoch_batch_count += 1
 
85
  self._epoch_batch_count = 0
86
 
87
  def log_states(self, sw, tag_prefix, global_step):
88
+ sw.add_scalar(
89
+ tag=tag_prefix + "_ema_iou", value=self._ema_iou, global_step=global_step
90
+ )
91
+ sw.add_scalar(
92
+ tag=tag_prefix + "_iou_thresh",
93
+ value=self._iou_thresh,
94
+ global_step=global_step,
95
+ )
96
 
97
  @property
98
  def iou_thresh(self):
 
104
  pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
105
 
106
  reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
107
+ union = (
108
+ torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims)
109
+ .detach()
110
+ .cpu()
111
+ .numpy()
112
+ )
113
+ intersection = (
114
+ torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims)
115
+ .detach()
116
+ .cpu()
117
+ .numpy()
118
+ )
119
  nonzero = union > 0
120
 
121
  iou = intersection[nonzero] / union[nonzero]
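The two running statistics AdaptiveIoU maintains, its binarization threshold and its IoU estimate, are plain exponential moving averages. A tiny illustration with made-up per-batch values:

```python
# Illustration only: EMA updates as used by AdaptiveIoU above.
thresh_beta, iou_beta = 0.99, 0.9
iou_thresh, ema_iou = 0.4, 0.0

for best_thresh, max_iou in [(0.45, 0.80), (0.50, 0.82), (0.42, 0.85)]:
    iou_thresh = thresh_beta * iou_thresh + (1 - thresh_beta) * best_thresh
    ema_iou = iou_beta * ema_iou + (1 - iou_beta) * max_iou

print(round(iou_thresh, 4), round(ema_iou, 4))
```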
isegm/model/modeling/basic_blocks.py CHANGED
@@ -4,18 +4,28 @@ from isegm.model import ops
4
 
5
 
6
  class ConvHead(nn.Module):
7
- def __init__(self, out_channels, in_channels=32, num_layers=1,
8
- kernel_size=3, padding=1,
9
- norm_layer=nn.BatchNorm2d):
 
 
 
 
 
 
10
  super(ConvHead, self).__init__()
11
  convhead = []
12
 
13
  for i in range(num_layers):
14
- convhead.extend([
15
- nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
16
- nn.ReLU(),
17
- norm_layer(in_channels) if norm_layer is not None else nn.Identity()
18
- ])
 
 
 
 
19
  convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
20
 
21
  self.convhead = nn.Sequential(*convhead)
@@ -25,25 +35,43 @@ class ConvHead(nn.Module):
25
 
26
 
27
  class SepConvHead(nn.Module):
28
- def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
29
- kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
30
- norm_layer=nn.BatchNorm2d):
 
 
 
 
 
 
 
 
 
31
  super(SepConvHead, self).__init__()
32
 
33
  sepconvhead = []
34
 
35
  for i in range(num_layers):
36
  sepconvhead.append(
37
- SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
38
- out_channels=mid_channels,
39
- dw_kernel=kernel_size, dw_padding=padding,
40
- norm_layer=norm_layer, activation='relu')
 
 
 
 
41
  )
42
  if dropout_ratio > 0 and dropout_indx == i:
43
  sepconvhead.append(nn.Dropout(dropout_ratio))
44
 
45
  sepconvhead.append(
46
- nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
 
 
 
 
 
47
  )
48
 
49
  self.layers = nn.Sequential(*sepconvhead)
@@ -55,16 +83,34 @@ class SepConvHead(nn.Module):
55
 
56
 
57
  class SeparableConv2d(nn.Module):
58
- def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
59
- activation=None, use_bias=False, norm_layer=None):
  super(SeparableConv2d, self).__init__()
61
  _activation = ops.select_activation_function(activation)
62
  self.body = nn.Sequential(
63
- nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
64
- padding=dw_padding, bias=use_bias, groups=in_channels),
65
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
 
 
 
 
 
 
 
 
 
66
  norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
67
- _activation()
68
  )
69
 
70
  def forward(self, x):
 
4
 
5
 
6
  class ConvHead(nn.Module):
7
+ def __init__(
8
+ self,
9
+ out_channels,
10
+ in_channels=32,
11
+ num_layers=1,
12
+ kernel_size=3,
13
+ padding=1,
14
+ norm_layer=nn.BatchNorm2d,
15
+ ):
16
  super(ConvHead, self).__init__()
17
  convhead = []
18
 
19
  for i in range(num_layers):
20
+ convhead.extend(
21
+ [
22
+ nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
23
+ nn.ReLU(),
24
+ norm_layer(in_channels)
25
+ if norm_layer is not None
26
+ else nn.Identity(),
27
+ ]
28
+ )
29
  convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
30
 
31
  self.convhead = nn.Sequential(*convhead)
 
35
 
36
 
37
  class SepConvHead(nn.Module):
38
+ def __init__(
39
+ self,
40
+ num_outputs,
41
+ in_channels,
42
+ mid_channels,
43
+ num_layers=1,
44
+ kernel_size=3,
45
+ padding=1,
46
+ dropout_ratio=0.0,
47
+ dropout_indx=0,
48
+ norm_layer=nn.BatchNorm2d,
49
+ ):
50
  super(SepConvHead, self).__init__()
51
 
52
  sepconvhead = []
53
 
54
  for i in range(num_layers):
55
  sepconvhead.append(
56
+ SeparableConv2d(
57
+ in_channels=in_channels if i == 0 else mid_channels,
58
+ out_channels=mid_channels,
59
+ dw_kernel=kernel_size,
60
+ dw_padding=padding,
61
+ norm_layer=norm_layer,
62
+ activation="relu",
63
+ )
64
  )
65
  if dropout_ratio > 0 and dropout_indx == i:
66
  sepconvhead.append(nn.Dropout(dropout_ratio))
67
 
68
  sepconvhead.append(
69
+ nn.Conv2d(
70
+ in_channels=mid_channels,
71
+ out_channels=num_outputs,
72
+ kernel_size=1,
73
+ padding=0,
74
+ )
75
  )
76
 
77
  self.layers = nn.Sequential(*sepconvhead)
 
83
 
84
 
85
  class SeparableConv2d(nn.Module):
86
+ def __init__(
87
+ self,
88
+ in_channels,
89
+ out_channels,
90
+ dw_kernel,
91
+ dw_padding,
92
+ dw_stride=1,
93
+ activation=None,
94
+ use_bias=False,
95
+ norm_layer=None,
96
+ ):
97
  super(SeparableConv2d, self).__init__()
98
  _activation = ops.select_activation_function(activation)
99
  self.body = nn.Sequential(
100
+ nn.Conv2d(
101
+ in_channels,
102
+ in_channels,
103
+ kernel_size=dw_kernel,
104
+ stride=dw_stride,
105
+ padding=dw_padding,
106
+ bias=use_bias,
107
+ groups=in_channels,
108
+ ),
109
+ nn.Conv2d(
110
+ in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias
111
+ ),
112
  norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
113
+ _activation(),
114
  )
115
 
116
  def forward(self, x):
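SeparableConv2d above wraps the standard depthwise-separable pattern: a grouped (depthwise) convolution followed by a 1x1 pointwise convolution, then normalization and activation. A self-contained equivalent built from plain torch.nn layers, with channel sizes chosen arbitrarily for the example:

```python
import torch
import torch.nn as nn

depthwise = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32, bias=False)
pointwise = nn.Conv2d(32, 64, kernel_size=1, bias=False)
sep_conv = nn.Sequential(depthwise, pointwise, nn.BatchNorm2d(64), nn.ReLU())

x = torch.rand(1, 32, 56, 56)
print(sep_conv(x).shape)  # torch.Size([1, 64, 56, 56])
```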
isegm/model/modeling/deeplab_v3.py CHANGED
@@ -1,21 +1,26 @@
1
  from contextlib import ExitStack
2
 
3
  import torch
4
- from torch import nn
5
  import torch.nn.functional as F
 
 
 
6
 
7
  from .basic_blocks import SeparableConv2d
8
  from .resnet import ResNetBackbone
9
- from isegm.model import ops
10
 
11
 
12
  class DeepLabV3Plus(nn.Module):
13
- def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
14
- backbone_norm_layer=None,
15
- ch=256,
16
- project_dropout=0.5,
17
- inference_mode=False,
18
- **kwargs):
 
 
 
 
19
  super(DeepLabV3Plus, self).__init__()
20
  if backbone_norm_layer is None:
21
  backbone_norm_layer = norm_layer
@@ -29,28 +34,44 @@ class DeepLabV3Plus(nn.Module):
29
  self.skip_project_in_channels = 256 # layer 1 out_channels
30
 
31
  self._kwargs = kwargs
32
- if backbone == 'resnet34':
33
  self.aspp_in_channels = 512
34
  self.skip_project_in_channels = 64
35
 
36
- self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
37
- norm_layer=self.backbone_norm_layer, **kwargs)
 
 
 
 
38
 
39
- self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
40
- norm_layer=self.norm_layer)
41
- self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
42
- self.aspp = _ASPP(in_channels=self.aspp_in_channels,
43
- atrous_rates=[12, 24, 36],
44
- out_channels=ch,
45
- project_dropout=project_dropout,
46
- norm_layer=self.norm_layer)
 
 
 
 
 
 
 
 
47
 
48
  if inference_mode:
49
  self.set_prediction_mode()
50
 
51
  def load_pretrained_weights(self):
52
- pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
53
- norm_layer=self.backbone_norm_layer, **self._kwargs)
 
 
 
 
54
  backbone_state_dict = self.backbone.state_dict()
55
  pretrained_state_dict = pretrained.state_dict()
56
 
@@ -74,11 +95,11 @@ class DeepLabV3Plus(nn.Module):
74
  c1 = self.skip_project(c1)
75
 
76
  x = self.aspp(c4)
77
- x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
78
  x = torch.cat((x, c1), dim=1)
79
  x = self.head(x)
80
 
81
- return x,
82
 
83
 
84
  class _SkipProject(nn.Module):
@@ -89,7 +110,7 @@ class _SkipProject(nn.Module):
89
  self.skip_project = nn.Sequential(
90
  nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
91
  norm_layer(out_channels),
92
- _activation()
93
  )
94
 
95
  def forward(self, x):
@@ -97,15 +118,31 @@ class _SkipProject(nn.Module):
97
 
98
 
99
  class _DeepLabHead(nn.Module):
100
- def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
 
 
101
  super(_DeepLabHead, self).__init__()
102
 
103
  self.block = nn.Sequential(
104
- SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
105
- dw_padding=1, activation='relu', norm_layer=norm_layer),
106
- SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
107
- dw_padding=1, activation='relu', norm_layer=norm_layer),
108
- nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
109
  )
110
 
111
  def forward(self, x):
@@ -113,14 +150,25 @@ class _DeepLabHead(nn.Module):
113
 
114
 
115
  class _ASPP(nn.Module):
116
- def __init__(self, in_channels, atrous_rates, out_channels=256,
117
- project_dropout=0.5, norm_layer=nn.BatchNorm2d):
 
 
 
 
 
 
118
  super(_ASPP, self).__init__()
119
 
120
  b0 = nn.Sequential(
121
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
 
 
 
 
 
122
  norm_layer(out_channels),
123
- nn.ReLU()
124
  )
125
 
126
  rate1, rate2, rate3 = tuple(atrous_rates)
@@ -132,10 +180,14 @@ class _ASPP(nn.Module):
132
  self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
133
 
134
  project = [
135
- nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
136
- kernel_size=1, bias=False),
 
 
 
 
137
  norm_layer(out_channels),
138
- nn.ReLU()
139
  ]
140
  if project_dropout > 0:
141
  project.append(nn.Dropout(project_dropout))
@@ -153,24 +205,33 @@ class _AsppPooling(nn.Module):
153
 
154
  self.gap = nn.Sequential(
155
  nn.AdaptiveAvgPool2d((1, 1)),
156
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
157
- kernel_size=1, bias=False),
 
 
 
 
158
  norm_layer(out_channels),
159
- nn.ReLU()
160
  )
161
 
162
  def forward(self, x):
163
  pool = self.gap(x)
164
- return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
165
 
166
 
167
  def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
168
  block = nn.Sequential(
169
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
170
- kernel_size=3, padding=atrous_rate,
171
- dilation=atrous_rate, bias=False),
 
 
 
 
 
172
  norm_layer(out_channels),
173
- nn.ReLU()
174
  )
175
 
176
  return block
 
1
  from contextlib import ExitStack
2
 
3
  import torch
 
4
  import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from isegm.model import ops
8
 
9
  from .basic_blocks import SeparableConv2d
10
  from .resnet import ResNetBackbone
 
11
 
12
 
13
  class DeepLabV3Plus(nn.Module):
14
+ def __init__(
15
+ self,
16
+ backbone="resnet50",
17
+ norm_layer=nn.BatchNorm2d,
18
+ backbone_norm_layer=None,
19
+ ch=256,
20
+ project_dropout=0.5,
21
+ inference_mode=False,
22
+ **kwargs
23
+ ):
24
  super(DeepLabV3Plus, self).__init__()
25
  if backbone_norm_layer is None:
26
  backbone_norm_layer = norm_layer
 
34
  self.skip_project_in_channels = 256 # layer 1 out_channels
35
 
36
  self._kwargs = kwargs
37
+ if backbone == "resnet34":
38
  self.aspp_in_channels = 512
39
  self.skip_project_in_channels = 64
40
 
41
+ self.backbone = ResNetBackbone(
42
+ backbone=self.backbone_name,
43
+ pretrained_base=False,
44
+ norm_layer=self.backbone_norm_layer,
45
+ **kwargs
46
+ )
47
 
48
+ self.head = _DeepLabHead(
49
+ in_channels=ch + 32,
50
+ mid_channels=ch,
51
+ out_channels=ch,
52
+ norm_layer=self.norm_layer,
53
+ )
54
+ self.skip_project = _SkipProject(
55
+ self.skip_project_in_channels, 32, norm_layer=self.norm_layer
56
+ )
57
+ self.aspp = _ASPP(
58
+ in_channels=self.aspp_in_channels,
59
+ atrous_rates=[12, 24, 36],
60
+ out_channels=ch,
61
+ project_dropout=project_dropout,
62
+ norm_layer=self.norm_layer,
63
+ )
64
 
65
  if inference_mode:
66
  self.set_prediction_mode()
67
 
68
  def load_pretrained_weights(self):
69
+ pretrained = ResNetBackbone(
70
+ backbone=self.backbone_name,
71
+ pretrained_base=True,
72
+ norm_layer=self.backbone_norm_layer,
73
+ **self._kwargs
74
+ )
75
  backbone_state_dict = self.backbone.state_dict()
76
  pretrained_state_dict = pretrained.state_dict()
77
 
 
95
  c1 = self.skip_project(c1)
96
 
97
  x = self.aspp(c4)
98
+ x = F.interpolate(x, c1.size()[2:], mode="bilinear", align_corners=True)
99
  x = torch.cat((x, c1), dim=1)
100
  x = self.head(x)
101
 
102
+ return (x,)
103
 
104
 
105
  class _SkipProject(nn.Module):
 
110
  self.skip_project = nn.Sequential(
111
  nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
112
  norm_layer(out_channels),
113
+ _activation(),
114
  )
115
 
116
  def forward(self, x):
 
118
 
119
 
120
  class _DeepLabHead(nn.Module):
121
+ def __init__(
122
+ self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d
123
+ ):
124
  super(_DeepLabHead, self).__init__()
125
 
126
  self.block = nn.Sequential(
127
+ SeparableConv2d(
128
+ in_channels=in_channels,
129
+ out_channels=mid_channels,
130
+ dw_kernel=3,
131
+ dw_padding=1,
132
+ activation="relu",
133
+ norm_layer=norm_layer,
134
+ ),
135
+ SeparableConv2d(
136
+ in_channels=mid_channels,
137
+ out_channels=mid_channels,
138
+ dw_kernel=3,
139
+ dw_padding=1,
140
+ activation="relu",
141
+ norm_layer=norm_layer,
142
+ ),
143
+ nn.Conv2d(
144
+ in_channels=mid_channels, out_channels=out_channels, kernel_size=1
145
+ ),
146
  )
147
 
148
  def forward(self, x):
 
150
 
151
 
152
  class _ASPP(nn.Module):
153
+ def __init__(
154
+ self,
155
+ in_channels,
156
+ atrous_rates,
157
+ out_channels=256,
158
+ project_dropout=0.5,
159
+ norm_layer=nn.BatchNorm2d,
160
+ ):
161
  super(_ASPP, self).__init__()
162
 
163
  b0 = nn.Sequential(
164
+ nn.Conv2d(
165
+ in_channels=in_channels,
166
+ out_channels=out_channels,
167
+ kernel_size=1,
168
+ bias=False,
169
+ ),
170
  norm_layer(out_channels),
171
+ nn.ReLU(),
172
  )
173
 
174
  rate1, rate2, rate3 = tuple(atrous_rates)
 
180
  self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
181
 
182
  project = [
183
+ nn.Conv2d(
184
+ in_channels=5 * out_channels,
185
+ out_channels=out_channels,
186
+ kernel_size=1,
187
+ bias=False,
188
+ ),
189
  norm_layer(out_channels),
190
+ nn.ReLU(),
191
  ]
192
  if project_dropout > 0:
193
  project.append(nn.Dropout(project_dropout))
 
205
 
206
  self.gap = nn.Sequential(
207
  nn.AdaptiveAvgPool2d((1, 1)),
208
+ nn.Conv2d(
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ kernel_size=1,
212
+ bias=False,
213
+ ),
214
  norm_layer(out_channels),
215
+ nn.ReLU(),
216
  )
217
 
218
  def forward(self, x):
219
  pool = self.gap(x)
220
+ return F.interpolate(pool, x.size()[2:], mode="bilinear", align_corners=True)
221
 
222
 
223
  def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
224
  block = nn.Sequential(
225
+ nn.Conv2d(
226
+ in_channels=in_channels,
227
+ out_channels=out_channels,
228
+ kernel_size=3,
229
+ padding=atrous_rate,
230
+ dilation=atrous_rate,
231
+ bias=False,
232
+ ),
233
  norm_layer(out_channels),
234
+ nn.ReLU(),
235
  )
236
 
237
  return block
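
A note on the ASPP branches refactored above: each `_ASPPConv` branch sets `padding=atrous_rate` together with `dilation=atrous_rate`, which keeps every branch at the input's spatial size so the five branches can be concatenated channel-wise before the 1x1 projection. A minimal, self-contained sketch of that invariant in plain PyTorch (the channel counts and rates below are illustrative only, not the repo's defaults):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 256, 33, 33)  # arbitrary feature map

    # One branch per atrous rate, as in _ASPPConv: padding == dilation keeps H x W fixed.
    branches = [
        nn.Conv2d(256, 64, kernel_size=3, padding=rate, dilation=rate, bias=False)
        for rate in (12, 24, 36)
    ]
    outs = [branch(x) for branch in branches]
    print([tuple(o.shape) for o in outs])  # every branch: (1, 64, 33, 33)

    fused = torch.cat(outs, dim=1)  # concatenation is safe because the sizes match
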
isegm/model/modeling/hrnet_ocr.py CHANGED
@@ -1,19 +1,30 @@
1
  import os
 
2
  import numpy as np
3
  import torch
4
- import torch.nn as nn
5
  import torch._utils
 
6
  import torch.nn.functional as F
7
- from .ocr import SpatialOCR_Module, SpatialGather_Module
 
8
  from .resnetv1b import BasicBlockV1b, BottleneckV1b
9
 
10
  relu_inplace = True
11
 
12
 
13
  class HighResolutionModule(nn.Module):
14
- def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
15
- num_channels, fuse_method,multi_scale_output=True,
16
- norm_layer=nn.BatchNorm2d, align_corners=True):
 
 
 
 
 
 
 
 
 
17
  super(HighResolutionModule, self).__init__()
18
  self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
19
 
@@ -26,48 +37,67 @@ class HighResolutionModule(nn.Module):
26
  self.multi_scale_output = multi_scale_output
27
 
28
  self.branches = self._make_branches(
29
- num_branches, blocks, num_blocks, num_channels)
 
30
  self.fuse_layers = self._make_fuse_layers()
31
  self.relu = nn.ReLU(inplace=relu_inplace)
32
 
33
  def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
34
  if num_branches != len(num_blocks):
35
- error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
36
- num_branches, len(num_blocks))
 
37
  raise ValueError(error_msg)
38
 
39
  if num_branches != len(num_channels):
40
- error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
41
- num_branches, len(num_channels))
 
42
  raise ValueError(error_msg)
43
 
44
  if num_branches != len(num_inchannels):
45
- error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
46
- num_branches, len(num_inchannels))
 
47
  raise ValueError(error_msg)
48
 
49
- def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
50
- stride=1):
51
  downsample = None
52
- if stride != 1 or \
53
- self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
 
 
 
54
  downsample = nn.Sequential(
55
- nn.Conv2d(self.num_inchannels[branch_index],
56
- num_channels[branch_index] * block.expansion,
57
- kernel_size=1, stride=stride, bias=False),
 
 
 
 
58
  self.norm_layer(num_channels[branch_index] * block.expansion),
59
  )
60
 
61
  layers = []
62
- layers.append(block(self.num_inchannels[branch_index],
63
- num_channels[branch_index], stride,
64
- downsample=downsample, norm_layer=self.norm_layer))
65
- self.num_inchannels[branch_index] = \
66
- num_channels[branch_index] * block.expansion
 
 
 
 
 
67
  for i in range(1, num_blocks[branch_index]):
68
- layers.append(block(self.num_inchannels[branch_index],
69
- num_channels[branch_index],
70
- norm_layer=self.norm_layer))
 
 
 
 
71
 
72
  return nn.Sequential(*layers)
73
 
@@ -75,8 +105,7 @@ class HighResolutionModule(nn.Module):
75
  branches = []
76
 
77
  for i in range(num_branches):
78
- branches.append(
79
- self._make_one_branch(i, block, num_blocks, num_channels))
80
 
81
  return nn.ModuleList(branches)
82
 
@@ -91,12 +120,17 @@ class HighResolutionModule(nn.Module):
91
  fuse_layer = []
92
  for j in range(num_branches):
93
  if j > i:
94
- fuse_layer.append(nn.Sequential(
95
- nn.Conv2d(in_channels=num_inchannels[j],
96
- out_channels=num_inchannels[i],
97
- kernel_size=1,
98
- bias=False),
99
- self.norm_layer(num_inchannels[i])))
 
 
 
 
 
100
  elif j == i:
101
  fuse_layer.append(None)
102
  else:
@@ -104,19 +138,35 @@ class HighResolutionModule(nn.Module):
104
  for k in range(i - j):
105
  if k == i - j - 1:
106
  num_outchannels_conv3x3 = num_inchannels[i]
107
- conv3x3s.append(nn.Sequential(
108
- nn.Conv2d(num_inchannels[j],
109
- num_outchannels_conv3x3,
110
- kernel_size=3, stride=2, padding=1, bias=False),
111
- self.norm_layer(num_outchannels_conv3x3)))
 
 
 
 
 
 
 
 
112
  else:
113
  num_outchannels_conv3x3 = num_inchannels[j]
114
- conv3x3s.append(nn.Sequential(
115
- nn.Conv2d(num_inchannels[j],
116
- num_outchannels_conv3x3,
117
- kernel_size=3, stride=2, padding=1, bias=False),
118
- self.norm_layer(num_outchannels_conv3x3),
119
- nn.ReLU(inplace=relu_inplace)))
 
 
 
 
 
 
 
 
120
  fuse_layer.append(nn.Sequential(*conv3x3s))
121
  fuse_layers.append(nn.ModuleList(fuse_layer))
122
 
@@ -144,7 +194,9 @@ class HighResolutionModule(nn.Module):
144
  y = y + F.interpolate(
145
  self.fuse_layers[i][j](x[j]),
146
  size=[height_output, width_output],
147
- mode='bilinear', align_corners=self.align_corners)
 
 
148
  else:
149
  y = y + self.fuse_layers[i][j](x[j])
150
  x_fuse.append(self.relu(y))
@@ -153,8 +205,15 @@ class HighResolutionModule(nn.Module):
153
 
154
 
155
  class HighResolutionNet(nn.Module):
156
- def __init__(self, width, num_classes, ocr_width=256, small=False,
157
- norm_layer=nn.BatchNorm2d, align_corners=True):
 
 
 
 
 
 
 
158
  super(HighResolutionNet, self).__init__()
159
  self.norm_layer = norm_layer
160
  self.width = width
@@ -170,40 +229,61 @@ class HighResolutionNet(nn.Module):
170
  num_blocks = 2 if small else 4
171
 
172
  stage1_num_channels = 64
173
- self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
 
 
174
  stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
175
 
176
  self.stage2_num_branches = 2
177
  num_channels = [width, 2 * width]
178
  num_inchannels = [
179
- num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
 
180
  self.transition1 = self._make_transition_layer(
181
- [stage1_out_channel], num_inchannels)
 
182
  self.stage2, pre_stage_channels = self._make_stage(
183
- BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
184
- num_blocks=2 * [num_blocks], num_channels=num_channels)
 
 
 
 
 
185
 
186
  self.stage3_num_branches = 3
187
  num_channels = [width, 2 * width, 4 * width]
188
  num_inchannels = [
189
- num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
 
190
  self.transition2 = self._make_transition_layer(
191
- pre_stage_channels, num_inchannels)
 
192
  self.stage3, pre_stage_channels = self._make_stage(
193
- BasicBlockV1b, num_inchannels=num_inchannels,
194
- num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
195
- num_blocks=3 * [num_blocks], num_channels=num_channels)
 
 
 
 
196
 
197
  self.stage4_num_branches = 4
198
  num_channels = [width, 2 * width, 4 * width, 8 * width]
199
  num_inchannels = [
200
- num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
 
201
  self.transition3 = self._make_transition_layer(
202
- pre_stage_channels, num_inchannels)
 
203
  self.stage4, pre_stage_channels = self._make_stage(
204
- BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
 
 
205
  num_branches=self.stage4_num_branches,
206
- num_blocks=4 * [num_blocks], num_channels=num_channels)
 
 
207
 
208
  last_inp_channels = np.int(np.sum(pre_stage_channels))
209
  if self.ocr_width > 0:
@@ -211,43 +291,77 @@ class HighResolutionNet(nn.Module):
211
  ocr_key_channels = self.ocr_width
212
 
213
  self.conv3x3_ocr = nn.Sequential(
214
- nn.Conv2d(last_inp_channels, ocr_mid_channels,
215
- kernel_size=3, stride=1, padding=1),
 
 
 
 
 
216
  norm_layer(ocr_mid_channels),
217
  nn.ReLU(inplace=relu_inplace),
218
  )
219
  self.ocr_gather_head = SpatialGather_Module(num_classes)
220
 
221
- self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
222
- key_channels=ocr_key_channels,
223
- out_channels=ocr_mid_channels,
224
- scale=1,
225
- dropout=0.05,
226
- norm_layer=norm_layer,
227
- align_corners=align_corners)
 
 
228
  self.cls_head = nn.Conv2d(
229
- ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
 
 
 
 
 
 
230
 
231
  self.aux_head = nn.Sequential(
232
- nn.Conv2d(last_inp_channels, last_inp_channels,
233
- kernel_size=1, stride=1, padding=0),
 
 
 
 
 
234
  norm_layer(last_inp_channels),
235
  nn.ReLU(inplace=relu_inplace),
236
- nn.Conv2d(last_inp_channels, num_classes,
237
- kernel_size=1, stride=1, padding=0, bias=True)
 
 
 
 
 
 
238
  )
239
  else:
240
  self.cls_head = nn.Sequential(
241
- nn.Conv2d(last_inp_channels, last_inp_channels,
242
- kernel_size=3, stride=1, padding=1),
 
 
 
 
 
243
  norm_layer(last_inp_channels),
244
  nn.ReLU(inplace=relu_inplace),
245
- nn.Conv2d(last_inp_channels, num_classes,
246
- kernel_size=1, stride=1, padding=0, bias=True)
 
 
 
 
 
 
247
  )
248
 
249
- def _make_transition_layer(
250
- self, num_channels_pre_layer, num_channels_cur_layer):
251
  num_branches_cur = len(num_channels_cur_layer)
252
  num_branches_pre = len(num_channels_pre_layer)
253
 
@@ -255,28 +369,45 @@ class HighResolutionNet(nn.Module):
255
  for i in range(num_branches_cur):
256
  if i < num_branches_pre:
257
  if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
258
- transition_layers.append(nn.Sequential(
259
- nn.Conv2d(num_channels_pre_layer[i],
260
- num_channels_cur_layer[i],
261
- kernel_size=3,
262
- stride=1,
263
- padding=1,
264
- bias=False),
265
- self.norm_layer(num_channels_cur_layer[i]),
266
- nn.ReLU(inplace=relu_inplace)))
 
 
 
 
 
267
  else:
268
  transition_layers.append(None)
269
  else:
270
  conv3x3s = []
271
  for j in range(i + 1 - num_branches_pre):
272
  inchannels = num_channels_pre_layer[-1]
273
- outchannels = num_channels_cur_layer[i] \
274
- if j == i - num_branches_pre else inchannels
275
- conv3x3s.append(nn.Sequential(
276
- nn.Conv2d(inchannels, outchannels,
277
- kernel_size=3, stride=2, padding=1, bias=False),
278
- self.norm_layer(outchannels),
279
- nn.ReLU(inplace=relu_inplace)))
 
 
 
 
 
 
 
 
 
 
 
 
280
  transition_layers.append(nn.Sequential(*conv3x3s))
281
 
282
  return nn.ModuleList(transition_layers)
@@ -285,24 +416,43 @@ class HighResolutionNet(nn.Module):
285
  downsample = None
286
  if stride != 1 or inplanes != planes * block.expansion:
287
  downsample = nn.Sequential(
288
- nn.Conv2d(inplanes, planes * block.expansion,
289
- kernel_size=1, stride=stride, bias=False),
 
 
 
 
 
290
  self.norm_layer(planes * block.expansion),
291
  )
292
 
293
  layers = []
294
- layers.append(block(inplanes, planes, stride,
295
- downsample=downsample, norm_layer=self.norm_layer))
 
 
 
 
 
 
 
296
  inplanes = planes * block.expansion
297
  for i in range(1, blocks):
298
  layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
299
 
300
  return nn.Sequential(*layers)
301
 
302
- def _make_stage(self, block, num_inchannels,
303
- num_modules, num_branches, num_blocks, num_channels,
304
- fuse_method='SUM',
305
- multi_scale_output=True):
 
 
 
 
 
 
 
306
  modules = []
307
  for i in range(num_modules):
308
  # multi_scale_output is only used last module
@@ -311,15 +461,17 @@ class HighResolutionNet(nn.Module):
311
  else:
312
  reset_multi_scale_output = True
313
  modules.append(
314
- HighResolutionModule(num_branches,
315
- block,
316
- num_blocks,
317
- num_inchannels,
318
- num_channels,
319
- fuse_method,
320
- reset_multi_scale_output,
321
- norm_layer=self.norm_layer,
322
- align_corners=self.align_corners)
 
 
323
  )
324
  num_inchannels = modules[-1].get_num_inchannels()
325
 
@@ -387,30 +539,38 @@ class HighResolutionNet(nn.Module):
387
  def aggregate_hrnet_features(self, x):
388
  # Upsampling
389
  x0_h, x0_w = x[0].size(2), x[0].size(3)
390
- x1 = F.interpolate(x[1], size=(x0_h, x0_w),
391
- mode='bilinear', align_corners=self.align_corners)
392
- x2 = F.interpolate(x[2], size=(x0_h, x0_w),
393
- mode='bilinear', align_corners=self.align_corners)
394
- x3 = F.interpolate(x[3], size=(x0_h, x0_w),
395
- mode='bilinear', align_corners=self.align_corners)
 
 
 
396
 
397
  return torch.cat([x[0], x1, x2, x3], 1)
398
 
399
- def load_pretrained_weights(self, pretrained_path=''):
400
  model_dict = self.state_dict()
401
 
402
  if not os.path.exists(pretrained_path):
403
  print(f'\nFile "{pretrained_path}" does not exist.')
404
- print('You need to specify the correct path to the pre-trained weights.\n'
405
- 'You can download the weights for HRNet from the repository:\n'
406
- 'https://github.com/HRNet/HRNet-Image-Classification')
 
 
407
  exit(1)
408
- pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
409
- pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
410
- pretrained_dict.items()}
411
-
412
- pretrained_dict = {k: v for k, v in pretrained_dict.items()
413
- if k in model_dict.keys()}
 
 
 
414
 
415
  model_dict.update(pretrained_dict)
416
  self.load_state_dict(model_dict)
 
1
  import os
2
+
3
  import numpy as np
4
  import torch
 
5
  import torch._utils
6
+ import torch.nn as nn
7
  import torch.nn.functional as F
8
+
9
+ from .ocr import SpatialGather_Module, SpatialOCR_Module
10
  from .resnetv1b import BasicBlockV1b, BottleneckV1b
11
 
12
  relu_inplace = True
13
 
14
 
15
  class HighResolutionModule(nn.Module):
16
+ def __init__(
17
+ self,
18
+ num_branches,
19
+ blocks,
20
+ num_blocks,
21
+ num_inchannels,
22
+ num_channels,
23
+ fuse_method,
24
+ multi_scale_output=True,
25
+ norm_layer=nn.BatchNorm2d,
26
+ align_corners=True,
27
+ ):
28
  super(HighResolutionModule, self).__init__()
29
  self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
30
 
 
37
  self.multi_scale_output = multi_scale_output
38
 
39
  self.branches = self._make_branches(
40
+ num_branches, blocks, num_blocks, num_channels
41
+ )
42
  self.fuse_layers = self._make_fuse_layers()
43
  self.relu = nn.ReLU(inplace=relu_inplace)
44
 
45
  def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
46
  if num_branches != len(num_blocks):
47
+ error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
48
+ num_branches, len(num_blocks)
49
+ )
50
  raise ValueError(error_msg)
51
 
52
  if num_branches != len(num_channels):
53
+ error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
54
+ num_branches, len(num_channels)
55
+ )
56
  raise ValueError(error_msg)
57
 
58
  if num_branches != len(num_inchannels):
59
+ error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
60
+ num_branches, len(num_inchannels)
61
+ )
62
  raise ValueError(error_msg)
63
 
64
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
 
65
  downsample = None
66
+ if (
67
+ stride != 1
68
+ or self.num_inchannels[branch_index]
69
+ != num_channels[branch_index] * block.expansion
70
+ ):
71
  downsample = nn.Sequential(
72
+ nn.Conv2d(
73
+ self.num_inchannels[branch_index],
74
+ num_channels[branch_index] * block.expansion,
75
+ kernel_size=1,
76
+ stride=stride,
77
+ bias=False,
78
+ ),
79
  self.norm_layer(num_channels[branch_index] * block.expansion),
80
  )
81
 
82
  layers = []
83
+ layers.append(
84
+ block(
85
+ self.num_inchannels[branch_index],
86
+ num_channels[branch_index],
87
+ stride,
88
+ downsample=downsample,
89
+ norm_layer=self.norm_layer,
90
+ )
91
+ )
92
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
93
  for i in range(1, num_blocks[branch_index]):
94
+ layers.append(
95
+ block(
96
+ self.num_inchannels[branch_index],
97
+ num_channels[branch_index],
98
+ norm_layer=self.norm_layer,
99
+ )
100
+ )
101
 
102
  return nn.Sequential(*layers)
103
 
 
105
  branches = []
106
 
107
  for i in range(num_branches):
108
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
 
109
 
110
  return nn.ModuleList(branches)
111
 
 
120
  fuse_layer = []
121
  for j in range(num_branches):
122
  if j > i:
123
+ fuse_layer.append(
124
+ nn.Sequential(
125
+ nn.Conv2d(
126
+ in_channels=num_inchannels[j],
127
+ out_channels=num_inchannels[i],
128
+ kernel_size=1,
129
+ bias=False,
130
+ ),
131
+ self.norm_layer(num_inchannels[i]),
132
+ )
133
+ )
134
  elif j == i:
135
  fuse_layer.append(None)
136
  else:
 
138
  for k in range(i - j):
139
  if k == i - j - 1:
140
  num_outchannels_conv3x3 = num_inchannels[i]
141
+ conv3x3s.append(
142
+ nn.Sequential(
143
+ nn.Conv2d(
144
+ num_inchannels[j],
145
+ num_outchannels_conv3x3,
146
+ kernel_size=3,
147
+ stride=2,
148
+ padding=1,
149
+ bias=False,
150
+ ),
151
+ self.norm_layer(num_outchannels_conv3x3),
152
+ )
153
+ )
154
  else:
155
  num_outchannels_conv3x3 = num_inchannels[j]
156
+ conv3x3s.append(
157
+ nn.Sequential(
158
+ nn.Conv2d(
159
+ num_inchannels[j],
160
+ num_outchannels_conv3x3,
161
+ kernel_size=3,
162
+ stride=2,
163
+ padding=1,
164
+ bias=False,
165
+ ),
166
+ self.norm_layer(num_outchannels_conv3x3),
167
+ nn.ReLU(inplace=relu_inplace),
168
+ )
169
+ )
170
  fuse_layer.append(nn.Sequential(*conv3x3s))
171
  fuse_layers.append(nn.ModuleList(fuse_layer))
172
 
 
194
  y = y + F.interpolate(
195
  self.fuse_layers[i][j](x[j]),
196
  size=[height_output, width_output],
197
+ mode="bilinear",
198
+ align_corners=self.align_corners,
199
+ )
200
  else:
201
  y = y + self.fuse_layers[i][j](x[j])
202
  x_fuse.append(self.relu(y))
 
205
 
206
 
207
  class HighResolutionNet(nn.Module):
208
+ def __init__(
209
+ self,
210
+ width,
211
+ num_classes,
212
+ ocr_width=256,
213
+ small=False,
214
+ norm_layer=nn.BatchNorm2d,
215
+ align_corners=True,
216
+ ):
217
  super(HighResolutionNet, self).__init__()
218
  self.norm_layer = norm_layer
219
  self.width = width
 
229
  num_blocks = 2 if small else 4
230
 
231
  stage1_num_channels = 64
232
+ self.layer1 = self._make_layer(
233
+ BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks
234
+ )
235
  stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
236
 
237
  self.stage2_num_branches = 2
238
  num_channels = [width, 2 * width]
239
  num_inchannels = [
240
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
241
+ ]
242
  self.transition1 = self._make_transition_layer(
243
+ [stage1_out_channel], num_inchannels
244
+ )
245
  self.stage2, pre_stage_channels = self._make_stage(
246
+ BasicBlockV1b,
247
+ num_inchannels=num_inchannels,
248
+ num_modules=1,
249
+ num_branches=self.stage2_num_branches,
250
+ num_blocks=2 * [num_blocks],
251
+ num_channels=num_channels,
252
+ )
253
 
254
  self.stage3_num_branches = 3
255
  num_channels = [width, 2 * width, 4 * width]
256
  num_inchannels = [
257
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
258
+ ]
259
  self.transition2 = self._make_transition_layer(
260
+ pre_stage_channels, num_inchannels
261
+ )
262
  self.stage3, pre_stage_channels = self._make_stage(
263
+ BasicBlockV1b,
264
+ num_inchannels=num_inchannels,
265
+ num_modules=3 if small else 4,
266
+ num_branches=self.stage3_num_branches,
267
+ num_blocks=3 * [num_blocks],
268
+ num_channels=num_channels,
269
+ )
270
 
271
  self.stage4_num_branches = 4
272
  num_channels = [width, 2 * width, 4 * width, 8 * width]
273
  num_inchannels = [
274
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
275
+ ]
276
  self.transition3 = self._make_transition_layer(
277
+ pre_stage_channels, num_inchannels
278
+ )
279
  self.stage4, pre_stage_channels = self._make_stage(
280
+ BasicBlockV1b,
281
+ num_inchannels=num_inchannels,
282
+ num_modules=2 if small else 3,
283
  num_branches=self.stage4_num_branches,
284
+ num_blocks=4 * [num_blocks],
285
+ num_channels=num_channels,
286
+ )
287
 
288
  last_inp_channels = int(np.sum(pre_stage_channels))
289
  if self.ocr_width > 0:
 
291
  ocr_key_channels = self.ocr_width
292
 
293
  self.conv3x3_ocr = nn.Sequential(
294
+ nn.Conv2d(
295
+ last_inp_channels,
296
+ ocr_mid_channels,
297
+ kernel_size=3,
298
+ stride=1,
299
+ padding=1,
300
+ ),
301
  norm_layer(ocr_mid_channels),
302
  nn.ReLU(inplace=relu_inplace),
303
  )
304
  self.ocr_gather_head = SpatialGather_Module(num_classes)
305
 
306
+ self.ocr_distri_head = SpatialOCR_Module(
307
+ in_channels=ocr_mid_channels,
308
+ key_channels=ocr_key_channels,
309
+ out_channels=ocr_mid_channels,
310
+ scale=1,
311
+ dropout=0.05,
312
+ norm_layer=norm_layer,
313
+ align_corners=align_corners,
314
+ )
315
  self.cls_head = nn.Conv2d(
316
+ ocr_mid_channels,
317
+ num_classes,
318
+ kernel_size=1,
319
+ stride=1,
320
+ padding=0,
321
+ bias=True,
322
+ )
323
 
324
  self.aux_head = nn.Sequential(
325
+ nn.Conv2d(
326
+ last_inp_channels,
327
+ last_inp_channels,
328
+ kernel_size=1,
329
+ stride=1,
330
+ padding=0,
331
+ ),
332
  norm_layer(last_inp_channels),
333
  nn.ReLU(inplace=relu_inplace),
334
+ nn.Conv2d(
335
+ last_inp_channels,
336
+ num_classes,
337
+ kernel_size=1,
338
+ stride=1,
339
+ padding=0,
340
+ bias=True,
341
+ ),
342
  )
343
  else:
344
  self.cls_head = nn.Sequential(
345
+ nn.Conv2d(
346
+ last_inp_channels,
347
+ last_inp_channels,
348
+ kernel_size=3,
349
+ stride=1,
350
+ padding=1,
351
+ ),
352
  norm_layer(last_inp_channels),
353
  nn.ReLU(inplace=relu_inplace),
354
+ nn.Conv2d(
355
+ last_inp_channels,
356
+ num_classes,
357
+ kernel_size=1,
358
+ stride=1,
359
+ padding=0,
360
+ bias=True,
361
+ ),
362
  )
363
 
364
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
 
365
  num_branches_cur = len(num_channels_cur_layer)
366
  num_branches_pre = len(num_channels_pre_layer)
367
 
 
369
  for i in range(num_branches_cur):
370
  if i < num_branches_pre:
371
  if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
372
+ transition_layers.append(
373
+ nn.Sequential(
374
+ nn.Conv2d(
375
+ num_channels_pre_layer[i],
376
+ num_channels_cur_layer[i],
377
+ kernel_size=3,
378
+ stride=1,
379
+ padding=1,
380
+ bias=False,
381
+ ),
382
+ self.norm_layer(num_channels_cur_layer[i]),
383
+ nn.ReLU(inplace=relu_inplace),
384
+ )
385
+ )
386
  else:
387
  transition_layers.append(None)
388
  else:
389
  conv3x3s = []
390
  for j in range(i + 1 - num_branches_pre):
391
  inchannels = num_channels_pre_layer[-1]
392
+ outchannels = (
393
+ num_channels_cur_layer[i]
394
+ if j == i - num_branches_pre
395
+ else inchannels
396
+ )
397
+ conv3x3s.append(
398
+ nn.Sequential(
399
+ nn.Conv2d(
400
+ inchannels,
401
+ outchannels,
402
+ kernel_size=3,
403
+ stride=2,
404
+ padding=1,
405
+ bias=False,
406
+ ),
407
+ self.norm_layer(outchannels),
408
+ nn.ReLU(inplace=relu_inplace),
409
+ )
410
+ )
411
  transition_layers.append(nn.Sequential(*conv3x3s))
412
 
413
  return nn.ModuleList(transition_layers)
 
416
  downsample = None
417
  if stride != 1 or inplanes != planes * block.expansion:
418
  downsample = nn.Sequential(
419
+ nn.Conv2d(
420
+ inplanes,
421
+ planes * block.expansion,
422
+ kernel_size=1,
423
+ stride=stride,
424
+ bias=False,
425
+ ),
426
  self.norm_layer(planes * block.expansion),
427
  )
428
 
429
  layers = []
430
+ layers.append(
431
+ block(
432
+ inplanes,
433
+ planes,
434
+ stride,
435
+ downsample=downsample,
436
+ norm_layer=self.norm_layer,
437
+ )
438
+ )
439
  inplanes = planes * block.expansion
440
  for i in range(1, blocks):
441
  layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
442
 
443
  return nn.Sequential(*layers)
444
 
445
+ def _make_stage(
446
+ self,
447
+ block,
448
+ num_inchannels,
449
+ num_modules,
450
+ num_branches,
451
+ num_blocks,
452
+ num_channels,
453
+ fuse_method="SUM",
454
+ multi_scale_output=True,
455
+ ):
456
  modules = []
457
  for i in range(num_modules):
458
  # multi_scale_output is only used by the last module
 
461
  else:
462
  reset_multi_scale_output = True
463
  modules.append(
464
+ HighResolutionModule(
465
+ num_branches,
466
+ block,
467
+ num_blocks,
468
+ num_inchannels,
469
+ num_channels,
470
+ fuse_method,
471
+ reset_multi_scale_output,
472
+ norm_layer=self.norm_layer,
473
+ align_corners=self.align_corners,
474
+ )
475
  )
476
  num_inchannels = modules[-1].get_num_inchannels()
477
 
 
539
  def aggregate_hrnet_features(self, x):
540
  # Upsampling
541
  x0_h, x0_w = x[0].size(2), x[0].size(3)
542
+ x1 = F.interpolate(
543
+ x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners
544
+ )
545
+ x2 = F.interpolate(
546
+ x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners
547
+ )
548
+ x3 = F.interpolate(
549
+ x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners
550
+ )
551
 
552
  return torch.cat([x[0], x1, x2, x3], 1)
553
 
554
+ def load_pretrained_weights(self, pretrained_path=""):
555
  model_dict = self.state_dict()
556
 
557
  if not os.path.exists(pretrained_path):
558
  print(f'\nFile "{pretrained_path}" does not exist.')
559
+ print(
560
+ "You need to specify the correct path to the pre-trained weights.\n"
561
+ "You can download the weights for HRNet from the repository:\n"
562
+ "https://github.com/HRNet/HRNet-Image-Classification"
563
+ )
564
  exit(1)
565
+ pretrained_dict = torch.load(pretrained_path, map_location={"cuda:0": "cpu"})
566
+ pretrained_dict = {
567
+ k.replace("last_layer", "aux_head").replace("model.", ""): v
568
+ for k, v in pretrained_dict.items()
569
+ }
570
+
571
+ pretrained_dict = {
572
+ k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
573
+ }
574
 
575
  model_dict.update(pretrained_dict)
576
  self.load_state_dict(model_dict)
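
The fuse layers assembled in `_make_fuse_layers` above always follow the same two patterns: a lower-resolution branch is projected with a 1x1 conv and bilinearly upsampled before being summed into a higher-resolution branch, while the opposite direction stacks strided 3x3 convs. A minimal two-branch sketch of that exchange in plain PyTorch (channel counts and spatial sizes are illustrative, not the network's defaults):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x_hi = torch.randn(1, 18, 64, 64)  # higher-resolution branch
    x_lo = torch.randn(1, 36, 32, 32)  # lower-resolution branch

    # low -> high: 1x1 projection + bilinear upsampling (the j > i case)
    proj = nn.Sequential(nn.Conv2d(36, 18, kernel_size=1, bias=False), nn.BatchNorm2d(18))
    y_hi = x_hi + F.interpolate(
        proj(x_lo), size=x_hi.shape[2:], mode="bilinear", align_corners=True
    )

    # high -> low: strided 3x3 conv (the j < i case)
    down = nn.Sequential(
        nn.Conv2d(18, 36, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(36),
    )
    y_lo = x_lo + down(x_hi)

    print(y_hi.shape, y_lo.shape)  # both branch resolutions are preserved
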
isegm/model/modeling/ocr.py CHANGED
@@ -1,14 +1,14 @@
1
  import torch
2
- import torch.nn as nn
3
  import torch._utils
 
4
  import torch.nn.functional as F
5
 
6
 
7
  class SpatialGather_Module(nn.Module):
8
  """
9
- Aggregate the context features according to the initial
10
- predicted probability distribution.
11
- Employ the soft-weighted method to aggregate the context.
12
  """
13
 
14
  def __init__(self, cls_num=0, scale=1):
@@ -22,8 +22,9 @@ class SpatialGather_Module(nn.Module):
22
  feats = feats.view(batch_size, feats.size(1), -1)
23
  feats = feats.permute(0, 2, 1) # batch x hw x c
24
  probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
25
- ocr_context = torch.matmul(probs, feats) \
26
- .permute(0, 2, 1).unsqueeze(3) # batch x k x c
 
27
  return ocr_context
28
 
29
 
@@ -33,23 +34,26 @@ class SpatialOCR_Module(nn.Module):
33
  We aggregate the global object representation to update the representation for each pixel.
34
  """
35
 
36
- def __init__(self,
37
- in_channels,
38
- key_channels,
39
- out_channels,
40
- scale=1,
41
- dropout=0.1,
42
- norm_layer=nn.BatchNorm2d,
43
- align_corners=True):
 
 
44
  super(SpatialOCR_Module, self).__init__()
45
- self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
46
- norm_layer, align_corners)
 
47
  _in_channels = 2 * in_channels
48
 
49
  self.conv_bn_dropout = nn.Sequential(
50
  nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
51
  nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
52
- nn.Dropout2d(dropout)
53
  )
54
 
55
  def forward(self, feats, proxy_feats):
@@ -61,7 +65,7 @@ class SpatialOCR_Module(nn.Module):
61
 
62
 
63
  class ObjectAttentionBlock2D(nn.Module):
64
- '''
65
  The basic implementation for object context block
66
  Input:
67
  N X C X H X W
@@ -72,14 +76,16 @@ class ObjectAttentionBlock2D(nn.Module):
72
  bn_type : specify the bn type
73
  Return:
74
  N X C X H X W
75
- '''
76
-
77
- def __init__(self,
78
- in_channels,
79
- key_channels,
80
- scale=1,
81
- norm_layer=nn.BatchNorm2d,
82
- align_corners=True):
 
 
83
  super(ObjectAttentionBlock2D, self).__init__()
84
  self.scale = scale
85
  self.in_channels = in_channels
@@ -88,30 +94,66 @@ class ObjectAttentionBlock2D(nn.Module):
88
 
89
  self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
90
  self.f_pixel = nn.Sequential(
91
- nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
92
- kernel_size=1, stride=1, padding=0, bias=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
94
- nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
95
- kernel_size=1, stride=1, padding=0, bias=False),
96
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
97
  )
98
  self.f_object = nn.Sequential(
99
- nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
100
- kernel_size=1, stride=1, padding=0, bias=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
102
- nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
103
- kernel_size=1, stride=1, padding=0, bias=False),
104
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
105
  )
106
  self.f_down = nn.Sequential(
107
- nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
108
- kernel_size=1, stride=1, padding=0, bias=False),
109
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
 
 
 
 
 
 
110
  )
111
  self.f_up = nn.Sequential(
112
- nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
113
- kernel_size=1, stride=1, padding=0, bias=False),
114
- nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
 
 
 
 
 
 
115
  )
116
 
117
  def forward(self, x, proxy):
@@ -126,7 +168,7 @@ class ObjectAttentionBlock2D(nn.Module):
126
  value = value.permute(0, 2, 1)
127
 
128
  sim_map = torch.matmul(query, key)
129
- sim_map = (self.key_channels ** -.5) * sim_map
130
  sim_map = F.softmax(sim_map, dim=-1)
131
 
132
  # add bg context ...
@@ -135,7 +177,11 @@ class ObjectAttentionBlock2D(nn.Module):
135
  context = context.view(batch_size, self.key_channels, *x.size()[2:])
136
  context = self.f_up(context)
137
  if self.scale > 1:
138
- context = F.interpolate(input=context, size=(h, w),
139
- mode='bilinear', align_corners=self.align_corners)
 
 
 
 
140
 
141
  return context
 
1
  import torch
 
2
  import torch._utils
3
+ import torch.nn as nn
4
  import torch.nn.functional as F
5
 
6
 
7
  class SpatialGather_Module(nn.Module):
8
  """
9
+ Aggregate the context features according to the initial
10
+ predicted probability distribution.
11
+ Employ the soft-weighted method to aggregate the context.
12
  """
13
 
14
  def __init__(self, cls_num=0, scale=1):
 
22
  feats = feats.view(batch_size, feats.size(1), -1)
23
  feats = feats.permute(0, 2, 1) # batch x hw x c
24
  probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
25
+ ocr_context = (
26
+ torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3)
27
+ ) # batch x k x c
28
  return ocr_context
29
 
30
 
 
34
  We aggregate the global object representation to update the representation for each pixel.
35
  """
36
 
37
+ def __init__(
38
+ self,
39
+ in_channels,
40
+ key_channels,
41
+ out_channels,
42
+ scale=1,
43
+ dropout=0.1,
44
+ norm_layer=nn.BatchNorm2d,
45
+ align_corners=True,
46
+ ):
47
  super(SpatialOCR_Module, self).__init__()
48
+ self.object_context_block = ObjectAttentionBlock2D(
49
+ in_channels, key_channels, scale, norm_layer, align_corners
50
+ )
51
  _in_channels = 2 * in_channels
52
 
53
  self.conv_bn_dropout = nn.Sequential(
54
  nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
55
  nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
56
+ nn.Dropout2d(dropout),
57
  )
58
 
59
  def forward(self, feats, proxy_feats):
 
65
 
66
 
67
  class ObjectAttentionBlock2D(nn.Module):
68
+ """
69
  The basic implementation for object context block
70
  Input:
71
  N X C X H X W
 
76
  bn_type : specify the bn type
77
  Return:
78
  N X C X H X W
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ in_channels,
84
+ key_channels,
85
+ scale=1,
86
+ norm_layer=nn.BatchNorm2d,
87
+ align_corners=True,
88
+ ):
89
  super(ObjectAttentionBlock2D, self).__init__()
90
  self.scale = scale
91
  self.in_channels = in_channels
 
94
 
95
  self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
96
  self.f_pixel = nn.Sequential(
97
+ nn.Conv2d(
98
+ in_channels=self.in_channels,
99
+ out_channels=self.key_channels,
100
+ kernel_size=1,
101
+ stride=1,
102
+ padding=0,
103
+ bias=False,
104
+ ),
105
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
106
+ nn.Conv2d(
107
+ in_channels=self.key_channels,
108
+ out_channels=self.key_channels,
109
+ kernel_size=1,
110
+ stride=1,
111
+ padding=0,
112
+ bias=False,
113
+ ),
114
  nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
 
 
 
115
  )
116
  self.f_object = nn.Sequential(
117
+ nn.Conv2d(
118
+ in_channels=self.in_channels,
119
+ out_channels=self.key_channels,
120
+ kernel_size=1,
121
+ stride=1,
122
+ padding=0,
123
+ bias=False,
124
+ ),
125
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
126
+ nn.Conv2d(
127
+ in_channels=self.key_channels,
128
+ out_channels=self.key_channels,
129
+ kernel_size=1,
130
+ stride=1,
131
+ padding=0,
132
+ bias=False,
133
+ ),
134
  nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
 
 
 
135
  )
136
  self.f_down = nn.Sequential(
137
+ nn.Conv2d(
138
+ in_channels=self.in_channels,
139
+ out_channels=self.key_channels,
140
+ kernel_size=1,
141
+ stride=1,
142
+ padding=0,
143
+ bias=False,
144
+ ),
145
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
146
  )
147
  self.f_up = nn.Sequential(
148
+ nn.Conv2d(
149
+ in_channels=self.key_channels,
150
+ out_channels=self.in_channels,
151
+ kernel_size=1,
152
+ stride=1,
153
+ padding=0,
154
+ bias=False,
155
+ ),
156
+ nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)),
157
  )
158
 
159
  def forward(self, x, proxy):
 
168
  value = value.permute(0, 2, 1)
169
 
170
  sim_map = torch.matmul(query, key)
171
+ sim_map = (self.key_channels**-0.5) * sim_map
172
  sim_map = F.softmax(sim_map, dim=-1)
173
 
174
  # add bg context ...
 
177
  context = context.view(batch_size, self.key_channels, *x.size()[2:])
178
  context = self.f_up(context)
179
  if self.scale > 1:
180
+ context = F.interpolate(
181
+ input=context,
182
+ size=(h, w),
183
+ mode="bilinear",
184
+ align_corners=self.align_corners,
185
+ )
186
 
187
  return context
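
Taken together, the two modules above implement the OCR scheme used by `HighResolutionNet`: `SpatialGather_Module` pools pixel features into one descriptor per class, weighted by the coarse auxiliary predictions, and `SpatialOCR_Module` redistributes those descriptors onto every pixel through the attention block. A rough usage sketch with illustrative tensor sizes (the channel widths below are arbitrary, not the model's defaults):

    import torch

    from isegm.model.modeling.ocr import SpatialGather_Module, SpatialOCR_Module

    feats = torch.randn(2, 128, 32, 32)     # pixel features from the backbone
    aux_logits = torch.randn(2, 1, 32, 32)  # coarse per-class predictions (one class here)

    gather = SpatialGather_Module(cls_num=1)
    distribute = SpatialOCR_Module(in_channels=128, key_channels=64, out_channels=128)

    context = gather(feats, aux_logits)   # (2, 128, 1, 1): one object descriptor per class
    refined = distribute(feats, context)  # (2, 128, 32, 32): pixels enriched with object context
    print(context.shape, refined.shape)
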
isegm/model/modeling/resnet.py CHANGED
@@ -1,21 +1,32 @@
1
  import torch
 
2
  from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
3
 
4
 
5
  class ResNetBackbone(torch.nn.Module):
6
- def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
 
 
7
  super(ResNetBackbone, self).__init__()
8
 
9
- if backbone == 'resnet34':
10
- pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
11
- elif backbone == 'resnet50':
12
- pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
13
- elif backbone == 'resnet101':
14
- pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
15
- elif backbone == 'resnet152':
16
- pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
 
 
 
 
 
 
 
 
17
  else:
18
- raise RuntimeError(f'unknown backbone: {backbone}')
19
 
20
  self.conv1 = pretrained.conv1
21
  self.bn1 = pretrained.bn1
@@ -31,9 +42,12 @@ class ResNetBackbone(torch.nn.Module):
31
  x = self.bn1(x)
32
  x = self.relu(x)
33
  if additional_features is not None:
34
- x = x + torch.nn.functional.pad(additional_features,
35
- [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
36
- mode='constant', value=0)
 
 
 
37
  x = self.maxpool(x)
38
  c1 = self.layer1(x)
39
  c2 = self.layer2(c1)
 
1
  import torch
2
+
3
  from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
4
 
5
 
6
  class ResNetBackbone(torch.nn.Module):
7
+ def __init__(
8
+ self, backbone="resnet50", pretrained_base=True, dilated=True, **kwargs
9
+ ):
10
  super(ResNetBackbone, self).__init__()
11
 
12
+ if backbone == "resnet34":
13
+ pretrained = resnet34_v1b(
14
+ pretrained=pretrained_base, dilated=dilated, **kwargs
15
+ )
16
+ elif backbone == "resnet50":
17
+ pretrained = resnet50_v1s(
18
+ pretrained=pretrained_base, dilated=dilated, **kwargs
19
+ )
20
+ elif backbone == "resnet101":
21
+ pretrained = resnet101_v1s(
22
+ pretrained=pretrained_base, dilated=dilated, **kwargs
23
+ )
24
+ elif backbone == "resnet152":
25
+ pretrained = resnet152_v1s(
26
+ pretrained=pretrained_base, dilated=dilated, **kwargs
27
+ )
28
  else:
29
+ raise RuntimeError(f"unknown backbone: {backbone}")
30
 
31
  self.conv1 = pretrained.conv1
32
  self.bn1 = pretrained.bn1
 
42
  x = self.bn1(x)
43
  x = self.relu(x)
44
  if additional_features is not None:
45
+ x = x + torch.nn.functional.pad(
46
+ additional_features,
47
+ [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
48
+ mode="constant",
49
+ value=0,
50
+ )
51
  x = self.maxpool(x)
52
  c1 = self.layer1(x)
53
  c2 = self.layer2(c1)
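
For orientation, the wrapper above exposes the intermediate ResNet stages that the DeepLab decoder in deeplab_v3.py consumes (`c1` feeds the skip projection, `c4` feeds the ASPP module). A hedged usage sketch, assuming the forward pass returns the four stage outputs and using `pretrained_base=False` so no torch.hub download is triggered:

    import torch

    from isegm.model.modeling.resnet import ResNetBackbone

    backbone = ResNetBackbone(backbone="resnet34", pretrained_base=False)
    backbone.eval()

    with torch.no_grad():
        c1, c2, c3, c4 = backbone(torch.randn(1, 3, 256, 256))

    print(c1.shape, c4.shape)  # stride-4 skip features and the deepest (dilated) stage
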
isegm/model/modeling/resnetv1b.py CHANGED
@@ -1,19 +1,42 @@
1
  import torch
2
  import torch.nn as nn
3
- GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
 
4
 
5
 
6
  class BasicBlockV1b(nn.Module):
7
  expansion = 1
8
 
9
- def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
10
- previous_dilation=1, norm_layer=nn.BatchNorm2d):
 
 
 
 
 
 
 
 
11
  super(BasicBlockV1b, self).__init__()
12
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
13
- padding=dilation, dilation=dilation, bias=False)
 
 
 
 
 
 
 
14
  self.bn1 = norm_layer(planes)
15
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
16
- padding=previous_dilation, dilation=previous_dilation, bias=False)
 
 
 
 
 
 
 
17
  self.bn2 = norm_layer(planes)
18
 
19
  self.relu = nn.ReLU(inplace=True)
@@ -42,17 +65,34 @@ class BasicBlockV1b(nn.Module):
42
  class BottleneckV1b(nn.Module):
43
  expansion = 4
44
 
45
- def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
46
- previous_dilation=1, norm_layer=nn.BatchNorm2d):
 
 
 
 
 
 
 
 
47
  super(BottleneckV1b, self).__init__()
48
  self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49
  self.bn1 = norm_layer(planes)
50
 
51
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52
- padding=dilation, dilation=dilation, bias=False)
 
 
 
 
 
 
 
53
  self.bn2 = norm_layer(planes)
54
 
55
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
 
 
56
  self.bn3 = norm_layer(planes * self.expansion)
57
 
58
  self.relu = nn.ReLU(inplace=True)
@@ -83,7 +123,7 @@ class BottleneckV1b(nn.Module):
83
 
84
 
85
  class ResNetV1b(nn.Module):
86
- """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
87
 
88
  Parameters
89
  ----------
@@ -111,86 +151,198 @@ class ResNetV1b(nn.Module):
111
 
112
  - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
113
  """
114
- def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
115
- avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
116
- self.inplanes = stem_width*2 if deep_stem else 64
 
 
 
 
 
 
 
 
 
 
 
117
  super(ResNetV1b, self).__init__()
118
  if not deep_stem:
119
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
 
 
120
  else:
121
  self.conv1 = nn.Sequential(
122
- nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
 
 
123
  norm_layer(stem_width),
124
  nn.ReLU(True),
125
- nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
 
 
 
 
 
 
 
126
  norm_layer(stem_width),
127
  nn.ReLU(True),
128
- nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
 
 
 
 
 
 
 
129
  )
130
  self.bn1 = norm_layer(self.inplanes)
131
  self.relu = nn.ReLU(True)
132
  self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
133
- self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
134
- norm_layer=norm_layer)
135
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
136
- norm_layer=norm_layer)
 
 
137
  if dilated:
138
- self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
139
- avg_down=avg_down, norm_layer=norm_layer)
140
- self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
141
- avg_down=avg_down, norm_layer=norm_layer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  else:
143
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144
- avg_down=avg_down, norm_layer=norm_layer)
145
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146
- avg_down=avg_down, norm_layer=norm_layer)
 
 
 
 
 
 
 
 
 
 
 
 
147
  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148
  self.drop = None
149
  if final_drop > 0.0:
150
  self.drop = nn.Dropout(final_drop)
151
  self.fc = nn.Linear(512 * block.expansion, classes)
152
 
153
- def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
154
- avg_down=False, norm_layer=nn.BatchNorm2d):
 
 
 
 
 
 
 
 
155
  downsample = None
156
  if stride != 1 or self.inplanes != planes * block.expansion:
157
  downsample = []
158
  if avg_down:
159
  if dilation == 1:
160
  downsample.append(
161
- nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
 
 
 
 
 
162
  )
163
  else:
164
  downsample.append(
165
- nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
 
 
 
 
 
166
  )
167
- downsample.extend([
168
- nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
169
- kernel_size=1, stride=1, bias=False),
170
- norm_layer(planes * block.expansion)
171
- ])
 
 
 
 
 
 
 
172
  downsample = nn.Sequential(*downsample)
173
  else:
174
  downsample = nn.Sequential(
175
- nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
176
- kernel_size=1, stride=stride, bias=False),
177
- norm_layer(planes * block.expansion)
 
 
 
 
 
178
  )
179
 
180
  layers = []
181
  if dilation in (1, 2):
182
- layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
183
- previous_dilation=dilation, norm_layer=norm_layer))
 
 
 
 
 
 
 
 
 
184
  elif dilation == 4:
185
- layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
186
- previous_dilation=dilation, norm_layer=norm_layer))
 
 
 
 
 
 
 
 
 
187
  else:
188
  raise RuntimeError("=> unknown dilation size: {}".format(dilation))
189
 
190
  self.inplanes = planes * block.expansion
191
  for _ in range(1, blocks):
192
- layers.append(block(self.inplanes, planes, dilation=dilation,
193
- previous_dilation=dilation, norm_layer=norm_layer))
 
 
 
 
 
 
 
194
 
195
  return nn.Sequential(*layers)
196
 
@@ -229,8 +381,10 @@ def resnet34_v1b(pretrained=False, **kwargs):
229
  if pretrained:
230
  model_dict = model.state_dict()
231
  filtered_orig_dict = _safe_state_dict_filtering(
232
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
233
- model_dict.keys()
 
 
234
  )
235
  model_dict.update(filtered_orig_dict)
236
  model.load_state_dict(model_dict)
@@ -238,12 +392,16 @@ def resnet34_v1b(pretrained=False, **kwargs):
238
 
239
 
240
  def resnet50_v1s(pretrained=False, **kwargs):
241
- model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
 
 
242
  if pretrained:
243
  model_dict = model.state_dict()
244
  filtered_orig_dict = _safe_state_dict_filtering(
245
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
246
- model_dict.keys()
 
 
247
  )
248
  model_dict.update(filtered_orig_dict)
249
  model.load_state_dict(model_dict)
@@ -251,12 +409,16 @@ def resnet50_v1s(pretrained=False, **kwargs):
251
 
252
 
253
  def resnet101_v1s(pretrained=False, **kwargs):
254
- model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
 
 
255
  if pretrained:
256
  model_dict = model.state_dict()
257
  filtered_orig_dict = _safe_state_dict_filtering(
258
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
259
- model_dict.keys()
 
 
260
  )
261
  model_dict.update(filtered_orig_dict)
262
  model.load_state_dict(model_dict)
@@ -264,12 +426,16 @@ def resnet101_v1s(pretrained=False, **kwargs):
264
 
265
 
266
  def resnet152_v1s(pretrained=False, **kwargs):
267
- model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
 
 
268
  if pretrained:
269
  model_dict = model.state_dict()
270
  filtered_orig_dict = _safe_state_dict_filtering(
271
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
272
- model_dict.keys()
 
 
273
  )
274
  model_dict.update(filtered_orig_dict)
275
  model.load_state_dict(model_dict)
 
1
  import torch
2
  import torch.nn as nn
3
+
4
+ GLUON_RESNET_TORCH_HUB = "rwightman/pytorch-pretrained-gluonresnet"
5
 
6
 
7
  class BasicBlockV1b(nn.Module):
8
  expansion = 1
9
 
10
+ def __init__(
11
+ self,
12
+ inplanes,
13
+ planes,
14
+ stride=1,
15
+ dilation=1,
16
+ downsample=None,
17
+ previous_dilation=1,
18
+ norm_layer=nn.BatchNorm2d,
19
+ ):
20
  super(BasicBlockV1b, self).__init__()
21
+ self.conv1 = nn.Conv2d(
22
+ inplanes,
23
+ planes,
24
+ kernel_size=3,
25
+ stride=stride,
26
+ padding=dilation,
27
+ dilation=dilation,
28
+ bias=False,
29
+ )
30
  self.bn1 = norm_layer(planes)
31
+ self.conv2 = nn.Conv2d(
32
+ planes,
33
+ planes,
34
+ kernel_size=3,
35
+ stride=1,
36
+ padding=previous_dilation,
37
+ dilation=previous_dilation,
38
+ bias=False,
39
+ )
40
  self.bn2 = norm_layer(planes)
41
 
42
  self.relu = nn.ReLU(inplace=True)
 
65
  class BottleneckV1b(nn.Module):
66
  expansion = 4
67
 
68
+ def __init__(
69
+ self,
70
+ inplanes,
71
+ planes,
72
+ stride=1,
73
+ dilation=1,
74
+ downsample=None,
75
+ previous_dilation=1,
76
+ norm_layer=nn.BatchNorm2d,
77
+ ):
78
  super(BottleneckV1b, self).__init__()
79
  self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
80
  self.bn1 = norm_layer(planes)
81
 
82
+ self.conv2 = nn.Conv2d(
83
+ planes,
84
+ planes,
85
+ kernel_size=3,
86
+ stride=stride,
87
+ padding=dilation,
88
+ dilation=dilation,
89
+ bias=False,
90
+ )
91
  self.bn2 = norm_layer(planes)
92
 
93
+ self.conv3 = nn.Conv2d(
94
+ planes, planes * self.expansion, kernel_size=1, bias=False
95
+ )
96
  self.bn3 = norm_layer(planes * self.expansion)
97
 
98
  self.relu = nn.ReLU(inplace=True)
 
123
 
124
 
125
  class ResNetV1b(nn.Module):
126
+ """Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
127
 
128
  Parameters
129
  ----------
 
151
 
152
  - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
153
  """
154
+
155
+ def __init__(
156
+ self,
157
+ block,
158
+ layers,
159
+ classes=1000,
160
+ dilated=True,
161
+ deep_stem=False,
162
+ stem_width=32,
163
+ avg_down=False,
164
+ final_drop=0.0,
165
+ norm_layer=nn.BatchNorm2d,
166
+ ):
167
+ self.inplanes = stem_width * 2 if deep_stem else 64
168
  super(ResNetV1b, self).__init__()
169
  if not deep_stem:
170
+ self.conv1 = nn.Conv2d(
171
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False
172
+ )
173
  else:
174
  self.conv1 = nn.Sequential(
175
+ nn.Conv2d(
176
+ 3, stem_width, kernel_size=3, stride=2, padding=1, bias=False
177
+ ),
178
  norm_layer(stem_width),
179
  nn.ReLU(True),
180
+ nn.Conv2d(
181
+ stem_width,
182
+ stem_width,
183
+ kernel_size=3,
184
+ stride=1,
185
+ padding=1,
186
+ bias=False,
187
+ ),
188
  norm_layer(stem_width),
189
  nn.ReLU(True),
190
+ nn.Conv2d(
191
+ stem_width,
192
+ 2 * stem_width,
193
+ kernel_size=3,
194
+ stride=1,
195
+ padding=1,
196
+ bias=False,
197
+ ),
198
  )
199
  self.bn1 = norm_layer(self.inplanes)
200
  self.relu = nn.ReLU(True)
201
  self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
202
+ self.layer1 = self._make_layer(
203
+ block, 64, layers[0], avg_down=avg_down, norm_layer=norm_layer
204
+ )
205
+ self.layer2 = self._make_layer(
206
+ block, 128, layers[1], stride=2, avg_down=avg_down, norm_layer=norm_layer
207
+ )
208
  if dilated:
209
+ self.layer3 = self._make_layer(
210
+ block,
211
+ 256,
212
+ layers[2],
213
+ stride=1,
214
+ dilation=2,
215
+ avg_down=avg_down,
216
+ norm_layer=norm_layer,
217
+ )
218
+ self.layer4 = self._make_layer(
219
+ block,
220
+ 512,
221
+ layers[3],
222
+ stride=1,
223
+ dilation=4,
224
+ avg_down=avg_down,
225
+ norm_layer=norm_layer,
226
+ )
227
  else:
228
+ self.layer3 = self._make_layer(
229
+ block,
230
+ 256,
231
+ layers[2],
232
+ stride=2,
233
+ avg_down=avg_down,
234
+ norm_layer=norm_layer,
235
+ )
236
+ self.layer4 = self._make_layer(
237
+ block,
238
+ 512,
239
+ layers[3],
240
+ stride=2,
241
+ avg_down=avg_down,
242
+ norm_layer=norm_layer,
243
+ )
244
  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
245
  self.drop = None
246
  if final_drop > 0.0:
247
  self.drop = nn.Dropout(final_drop)
248
  self.fc = nn.Linear(512 * block.expansion, classes)
249
 
250
+ def _make_layer(
251
+ self,
252
+ block,
253
+ planes,
254
+ blocks,
255
+ stride=1,
256
+ dilation=1,
257
+ avg_down=False,
258
+ norm_layer=nn.BatchNorm2d,
259
+ ):
260
  downsample = None
261
  if stride != 1 or self.inplanes != planes * block.expansion:
262
  downsample = []
263
  if avg_down:
264
  if dilation == 1:
265
  downsample.append(
266
+ nn.AvgPool2d(
267
+ kernel_size=stride,
268
+ stride=stride,
269
+ ceil_mode=True,
270
+ count_include_pad=False,
271
+ )
272
  )
273
  else:
274
  downsample.append(
275
+ nn.AvgPool2d(
276
+ kernel_size=1,
277
+ stride=1,
278
+ ceil_mode=True,
279
+ count_include_pad=False,
280
+ )
281
  )
282
+ downsample.extend(
283
+ [
284
+ nn.Conv2d(
285
+ self.inplanes,
286
+ out_channels=planes * block.expansion,
287
+ kernel_size=1,
288
+ stride=1,
289
+ bias=False,
290
+ ),
291
+ norm_layer(planes * block.expansion),
292
+ ]
293
+ )
294
  downsample = nn.Sequential(*downsample)
295
  else:
296
  downsample = nn.Sequential(
297
+ nn.Conv2d(
298
+ self.inplanes,
299
+ out_channels=planes * block.expansion,
300
+ kernel_size=1,
301
+ stride=stride,
302
+ bias=False,
303
+ ),
304
+ norm_layer(planes * block.expansion),
305
  )
306
 
307
  layers = []
308
  if dilation in (1, 2):
309
+ layers.append(
310
+ block(
311
+ self.inplanes,
312
+ planes,
313
+ stride,
314
+ dilation=1,
315
+ downsample=downsample,
316
+ previous_dilation=dilation,
317
+ norm_layer=norm_layer,
318
+ )
319
+ )
320
  elif dilation == 4:
321
+ layers.append(
322
+ block(
323
+ self.inplanes,
324
+ planes,
325
+ stride,
326
+ dilation=2,
327
+ downsample=downsample,
328
+ previous_dilation=dilation,
329
+ norm_layer=norm_layer,
330
+ )
331
+ )
332
  else:
333
  raise RuntimeError("=> unknown dilation size: {}".format(dilation))
334
 
335
  self.inplanes = planes * block.expansion
336
  for _ in range(1, blocks):
337
+ layers.append(
338
+ block(
339
+ self.inplanes,
340
+ planes,
341
+ dilation=dilation,
342
+ previous_dilation=dilation,
343
+ norm_layer=norm_layer,
344
+ )
345
+ )
346
 
347
  return nn.Sequential(*layers)
348
 
 
381
  if pretrained:
382
  model_dict = model.state_dict()
383
  filtered_orig_dict = _safe_state_dict_filtering(
384
+ torch.hub.load(
385
+ GLUON_RESNET_TORCH_HUB, "gluon_resnet34_v1b", pretrained=True
386
+ ).state_dict(),
387
+ model_dict.keys(),
388
  )
389
  model_dict.update(filtered_orig_dict)
390
  model.load_state_dict(model_dict)
 
392
 
393
 
394
  def resnet50_v1s(pretrained=False, **kwargs):
395
+ model = ResNetV1b(
396
+ BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs
397
+ )
398
  if pretrained:
399
  model_dict = model.state_dict()
400
  filtered_orig_dict = _safe_state_dict_filtering(
401
+ torch.hub.load(
402
+ GLUON_RESNET_TORCH_HUB, "gluon_resnet50_v1s", pretrained=True
403
+ ).state_dict(),
404
+ model_dict.keys(),
405
  )
406
  model_dict.update(filtered_orig_dict)
407
  model.load_state_dict(model_dict)
 
409
 
410
 
411
  def resnet101_v1s(pretrained=False, **kwargs):
412
+ model = ResNetV1b(
413
+ BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs
414
+ )
415
  if pretrained:
416
  model_dict = model.state_dict()
417
  filtered_orig_dict = _safe_state_dict_filtering(
418
+ torch.hub.load(
419
+ GLUON_RESNET_TORCH_HUB, "gluon_resnet101_v1s", pretrained=True
420
+ ).state_dict(),
421
+ model_dict.keys(),
422
  )
423
  model_dict.update(filtered_orig_dict)
424
  model.load_state_dict(model_dict)
 
426
 
427
 
428
  def resnet152_v1s(pretrained=False, **kwargs):
429
+ model = ResNetV1b(
430
+ BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs
431
+ )
432
  if pretrained:
433
  model_dict = model.state_dict()
434
  filtered_orig_dict = _safe_state_dict_filtering(
435
+ torch.hub.load(
436
+ GLUON_RESNET_TORCH_HUB, "gluon_resnet152_v1s", pretrained=True
437
+ ).state_dict(),
438
+ model_dict.keys(),
439
  )
440
  model_dict.update(filtered_orig_dict)
441
  model.load_state_dict(model_dict)
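
All four constructors above share the same partial-loading pattern: pull a GluonCV-converted checkpoint from torch.hub, drop the entries the current model cannot accept, and load the merged state dict. A self-contained sketch of that pattern with two toy models (no torch.hub download; the explicit shape check below stands in for what `_safe_state_dict_filtering`, whose body is not shown in this diff, is assumed to do):

    import torch.nn as nn

    donor = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))    # "pretrained" weights
    target = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 16, 3))  # partially compatible model

    pretrained_dict = donor.state_dict()
    model_dict = target.state_dict()

    # Keep only entries whose name and shape match the target model.
    filtered = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(filtered)
    target.load_state_dict(model_dict)
    print(sorted(filtered))  # only the compatible layers were transferred
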
isegm/model/modifiers.py CHANGED
@@ -1,11 +1,9 @@
1
-
2
-
3
  class LRMult(object):
4
- def __init__(self, lr_mult=1.):
5
  self.lr_mult = lr_mult
6
 
7
  def __call__(self, m):
8
- if getattr(m, 'weight', None) is not None:
9
  m.weight.lr_mult = self.lr_mult
10
- if getattr(m, 'bias', None) is not None:
11
  m.bias.lr_mult = self.lr_mult
 
 
 
1
  class LRMult(object):
2
+ def __init__(self, lr_mult=1.0):
3
  self.lr_mult = lr_mult
4
 
5
  def __call__(self, m):
6
+ if getattr(m, "weight", None) is not None:
7
  m.weight.lr_mult = self.lr_mult
8
+ if getattr(m, "bias", None) is not None:
9
  m.bias.lr_mult = self.lr_mult
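
`LRMult` only tags weight and bias parameters with an `lr_mult` attribute; something else has to read that tag back when building the optimizer. A minimal sketch of the pattern — the per-parameter grouping below is an illustrative consumer, not the repo's exact optimizer code:

    import torch
    import torch.nn as nn

    from isegm.model.modifiers import LRMult

    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.Conv2d(8, 1, 1))
    net[0].apply(LRMult(0.1))  # e.g. slow down a pretrained stem by a factor of 10

    base_lr = 5e-4
    param_groups = [
        {"params": [p], "lr": base_lr * getattr(p, "lr_mult", 1.0)}
        for p in net.parameters()
    ]
    optimizer = torch.optim.Adam(param_groups)
    print([group["lr"] for group in optimizer.param_groups])
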
isegm/model/ops.py CHANGED
@@ -1,14 +1,15 @@
 
1
  import torch
2
  from torch import nn as nn
3
- import numpy as np
4
  import isegm.model.initializer as initializer
5
 
6
 
7
  def select_activation_function(activation):
8
  if isinstance(activation, str):
9
- if activation.lower() == 'relu':
10
  return nn.ReLU
11
- elif activation.lower() == 'softplus':
12
  return nn.Softplus
13
  else:
14
  raise ValueError(f"Unknown activation type {activation}")
@@ -24,14 +25,18 @@ class BilinearConvTranspose2d(nn.ConvTranspose2d):
24
  self.scale = scale
25
 
26
  super().__init__(
27
- in_channels, out_channels,
 
28
  kernel_size=kernel_size,
29
  stride=scale,
30
  padding=1,
31
  groups=groups,
32
- bias=False)
 
33
 
34
- self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
 
 
35
 
36
 
37
  class DistMaps(nn.Module):
@@ -43,29 +48,47 @@ class DistMaps(nn.Module):
43
  self.use_disks = use_disks
44
  if self.cpu_mode:
45
  from isegm.utils.cython import get_dist_maps
 
46
  self._get_dist_maps = get_dist_maps
47
 
48
  def get_coord_features(self, points, batchsize, rows, cols):
49
  if self.cpu_mode:
50
  coords = []
51
  for i in range(batchsize):
52
- norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius
53
- coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
54
- norm_delimeter))
55
- coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
 
 
 
 
 
 
 
56
  else:
57
  num_points = points.shape[1] // 2
58
  points = points.view(-1, points.size(2))
59
  points, points_order = torch.split(points, [2, 1], dim=1)
60
 
61
  invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
62
- row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
63
- col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)
 
 
 
 
64
 
65
  coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
66
- coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
67
-
68
- add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
 
 
 
 
 
 
69
  coords.add_(-add_xy)
70
  if not self.use_disks:
71
  coords.div_(self.norm_radius * self.spatial_scale)
 
1
+ import numpy as np
2
  import torch
3
  from torch import nn as nn
4
+
5
  import isegm.model.initializer as initializer
6
 
7
 
8
  def select_activation_function(activation):
9
  if isinstance(activation, str):
10
+ if activation.lower() == "relu":
11
  return nn.ReLU
12
+ elif activation.lower() == "softplus":
13
  return nn.Softplus
14
  else:
15
  raise ValueError(f"Unknown activation type {activation}")
 
25
  self.scale = scale
26
 
27
  super().__init__(
28
+ in_channels,
29
+ out_channels,
30
  kernel_size=kernel_size,
31
  stride=scale,
32
  padding=1,
33
  groups=groups,
34
+ bias=False,
35
+ )
36
 
37
+ self.apply(
38
+ initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)
39
+ )
40
 
41
 
42
  class DistMaps(nn.Module):
 
48
  self.use_disks = use_disks
49
  if self.cpu_mode:
50
  from isegm.utils.cython import get_dist_maps
51
+
52
  self._get_dist_maps = get_dist_maps
53
 
54
  def get_coord_features(self, points, batchsize, rows, cols):
55
  if self.cpu_mode:
56
  coords = []
57
  for i in range(batchsize):
58
+ norm_delimeter = (
59
+ 1.0 if self.use_disks else self.spatial_scale * self.norm_radius
60
+ )
61
+ coords.append(
62
+ self._get_dist_maps(
63
+ points[i].cpu().float().numpy(), rows, cols, norm_delimeter
64
+ )
65
+ )
66
+ coords = (
67
+ torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
68
+ )
69
  else:
70
  num_points = points.shape[1] // 2
71
  points = points.view(-1, points.size(2))
72
  points, points_order = torch.split(points, [2, 1], dim=1)
73
 
74
  invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
75
+ row_array = torch.arange(
76
+ start=0, end=rows, step=1, dtype=torch.float32, device=points.device
77
+ )
78
+ col_array = torch.arange(
79
+ start=0, end=cols, step=1, dtype=torch.float32, device=points.device
80
+ )
81
 
82
  coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
83
+ coords = (
84
+ torch.stack((coord_rows, coord_cols), dim=0)
85
+ .unsqueeze(0)
86
+ .repeat(points.size(0), 1, 1, 1)
87
+ )
88
+
89
+ add_xy = (points * self.spatial_scale).view(
90
+ points.size(0), points.size(1), 1, 1
91
+ )
92
  coords.add_(-add_xy)
93
  if not self.use_disks:
94
  coords.div_(self.norm_radius * self.spatial_scale)
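
The GPU branch of `get_coord_features` above is built around a broadcasted coordinate grid: a (row, col) mesh is shifted by every click position, and the result is either thresholded into disks (`use_disks=True`) or normalised into soft distance maps. A self-contained sketch of that trick for a single positive click, in plain PyTorch and independent of `DistMaps`:

    import torch

    rows, cols, radius = 8, 8, 2.0
    click = torch.tensor([3.0, 5.0])  # (row, col) of a positive click

    rr, cc = torch.meshgrid(
        torch.arange(rows, dtype=torch.float32),
        torch.arange(cols, dtype=torch.float32),
    )
    coords = torch.stack((rr, cc), dim=0)  # (2, H, W) coordinate grid
    coords = coords - click.view(2, 1, 1)  # shift so the click sits at the origin

    dist2 = (coords ** 2).sum(dim=0)       # squared distance of every pixel to the click
    disk = (dist2 <= radius ** 2).float()  # use_disks=True: a binary disk around the click
    print(disk)
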
isegm/utils/cython/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  # noinspection PyUnresolvedReferences
2
- from .dist_maps import get_dist_maps
 
1
  # noinspection PyUnresolvedReferences
2
+ from .dist_maps import get_dist_maps
isegm/utils/cython/_get_dist_maps.pyx CHANGED
@@ -1,7 +1,8 @@
1
  import numpy as np
 
2
  cimport cython
3
  cimport numpy as np
4
- from libc.stdlib cimport malloc, free
5
 
6
  ctypedef struct qnode:
7
  int row
 
1
  import numpy as np
2
+
3
  cimport cython
4
  cimport numpy as np
5
+ from libc.stdlib cimport free, malloc
6
 
7
  ctypedef struct qnode:
8
  int row
isegm/utils/cython/dist_maps.py CHANGED
@@ -1,3 +1,5 @@
1
- import pyximport; pyximport.install(pyximport=True, language_level=3)
 
 
2
  # noinspection PyUnresolvedReferences
3
- from ._get_dist_maps import get_dist_maps
 
1
+ import pyximport
2
+
3
+ pyximport.install(pyximport=True, language_level=3)
4
  # noinspection PyUnresolvedReferences
5
+ from ._get_dist_maps import get_dist_maps
isegm/utils/distributed.py CHANGED
@@ -10,7 +10,11 @@ def get_rank():
10
 
11
 
12
  def synchronize():
13
- if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1:
 
 
 
 
14
  return
15
  dist.barrier()
16
 
@@ -58,10 +62,15 @@ def get_sampler(dataset, shuffle, distributed):
58
 
59
 
60
  def get_dp_wrapper(distributed):
61
- class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel):
 
 
 
 
62
  def __getattr__(self, name):
63
  try:
64
  return super().__getattr__(name)
65
  except AttributeError:
66
  return getattr(self.module, name)
 
67
  return DPWrapper
 
10
 
11
 
12
  def synchronize():
13
+ if (
14
+ not dist.is_available()
15
+ or not dist.is_initialized()
16
+ or dist.get_world_size() == 1
17
+ ):
18
  return
19
  dist.barrier()
20
 
 
62
 
63
 
64
  def get_dp_wrapper(distributed):
65
+ class DPWrapper(
66
+ torch.nn.parallel.DistributedDataParallel
67
+ if distributed
68
+ else torch.nn.DataParallel
69
+ ):
70
  def __getattr__(self, name):
71
  try:
72
  return super().__getattr__(name)
73
  except AttributeError:
74
  return getattr(self.module, name)
75
+
76
  return DPWrapper
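
The point of the `__getattr__` override in the wrapper above is that custom attributes of the wrapped model remain reachable through the parallel wrapper. A rough usage sketch (single process, so the non-distributed `DataParallel` branch is taken; `feature_dim` is a made-up, model-specific attribute used only for illustration):

    import torch.nn as nn

    from isegm.utils.distributed import get_dp_wrapper

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.feature_dim = 32  # hypothetical attribute a training loop might query
            self.conv = nn.Conv2d(3, self.feature_dim, 3, padding=1)

        def forward(self, x):
            return self.conv(x)

    wrapped = get_dp_wrapper(distributed=False)(ToyModel())
    print(wrapped.feature_dim)  # resolved on the underlying module via __getattr__
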