diff --git a/app.py b/app.py
index e8b4434e4a2b29542d8d78cf48f4bf5cfc7585c7..60a1adaf00c6a5f75eed0011d82eab403606acdf 100644
--- a/app.py
+++ b/app.py
@@ -1,94 +1,127 @@
+import os
+from typing import Dict, List
+
+import cv2
+import numpy as np
 import streamlit as st
 import torch
-import numpy as np
-import cv2
 import wget
-import os
-
 from PIL import Image
 from streamlit_drawable_canvas import st_canvas
 
 from isegm.inference import clicker as ck
 from isegm.inference import utils
-from isegm.inference.predictors import get_predictor
+from isegm.inference.predictors import BasePredictor, get_predictor
+
+###################################
+# Global scope objects.
+###################################
+URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
+CANVAS_HEIGHT, CANVAS_WIDTH = 600, 600
+POS_COLOR, NEG_COLOR = "#3498DB", "#C70039"
+ERR_X, ERR_Y = 5.5, 1.0
+MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+clicker = ck.Clicker()
+predictor = None
+image = None
+
 
-@st.experimental_memo
-def load_model(model_path, device):
+###################################
+# Functions.
+###################################
+# @st.cache_resource
+def load_model(model_path: str, device: torch.device) -> BasePredictor:
     model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
     predictor_params = {"brs_mode": "NoBRS"}
     predictor = get_predictor(model, device=device, **predictor_params)
     return predictor
 
 
-# Objects in the global scope
-url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
-models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
-clicker = ck.Clicker()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-pos_color, neg_color = "#3498DB", "#C70039"
-canvas_height, canvas_width = 600, 600
-err_x, err_y = 5.5, 1.0
-predictor = None
-image = None
+def feed_clicks(
+    clicker: ck.Clicker,
+    clicks: List[Dict[str, float]],
+    image_width: int,
+    image_height: int,
+) -> None:
+    # Map each canvas click back into image coordinates before feeding the clicker.
+    ratio_h, ratio_w = image_height / CANVAS_HEIGHT, image_width / CANVAS_WIDTH
+    for click in clicks:
+        x, y = (click["left"] + ERR_X) * ratio_w, (click["top"] + ERR_Y) * ratio_h
+        x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
+
+        is_positive = click["stroke"] == POS_COLOR
+        click = ck.Click(is_positive=is_positive, coords=(y, x))
+        clicker.add_click(click)
+
+
+def predict(
+    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
+) -> np.ndarray:
+    predictor.set_input_image(np.array(image))
+    with st.spinner("Wait for prediction..."):
+        pred = predictor.get_prediction(clicker, prev_mask=mask)
+        pred = cv2.resize(
+            pred,
+            dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),  # cv2 dsize is (width, height)
+            interpolation=cv2.INTER_CUBIC,
+        )
+        pred = np.where(pred > threshold, 1.0, 0.0)
+    return pred
+
 
+###################################
+# Sidebar GUI
+###################################
 # Items in the sidebar.
-model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
+model = st.sidebar.selectbox("Select a Model:", tuple(MODELS.keys()))
 threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
-marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
+marking_type = st.sidebar.radio("Click Type:", ("Positive", "Negative"))
 image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
+if image_path:
+    image = Image.open(image_path).convert("RGB")
 
-# Objects for prediction.
+###################################
+# Preparation
+###################################
+# Model.
 with st.spinner("Wait for downloading a model..."):
-    if not os.path.exists(models[model]):
-        _ = wget.download(f"{url_prefix}/{models[model]}")
-
+    if not os.path.exists(MODELS[model]):
+        _ = wget.download(f"{URL_PREFIX}/{MODELS[model]}")
 
+# Predictor.
 with st.spinner("Wait for loading a model..."):
-    predictor = load_model(models[model], device)
+    predictor = load_model(MODELS[model], device)
 
 
+###################################
+# GUI
+###################################
 # Create a canvas component.
-if image_path:
-    image = Image.open(image_path).convert("RGB")
-
 st.title("Canvas:")
 canvas_result = st_canvas(
-    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
-    stroke_width=3,
-    stroke_color=pos_color if marking_type == "positive" else neg_color,
-    background_color="#eee",
-    background_image=image,
-    update_streamlit=True,
-    drawing_mode="point",
-    point_display_radius=3,
-    key="canvas",
-    width=canvas_width,
-    height=canvas_height,
+    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
+    stroke_width=3,
+    stroke_color=POS_COLOR if marking_type == "Positive" else NEG_COLOR,
+    background_color="#eee",
+    background_image=image,
+    update_streamlit=True,
+    drawing_mode="point",
+    point_display_radius=3,
+    key="canvas",
+    width=CANVAS_WIDTH,
+    height=CANVAS_HEIGHT,
 )
 
+###################################
+# Prediction
+###################################
 # Check the user inputs and execute predictions.
 st.title("Prediction:")
 if canvas_result.json_data and canvas_result.json_data["objects"] and image:
-    objects = canvas_result.json_data["objects"]
     image_width, image_height = image.size
-    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
-
-    pos_clicks, neg_clicks = [], []
-    for click in objects:
-        x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
-        x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
-
-        is_positive = click["stroke"] == pos_color
-        click = ck.Click(is_positive=is_positive, coords=(y, x))
-        clicker.add_click(click)
+    feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
 
     # Run prediction.
-    pred = None
-    predictor.set_input_image(np.array(image))
-    init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
-
-    with st.spinner("Wait for prediction..."):
-        pred = predictor.get_prediction(clicker, prev_mask=init_mask)
-        pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
-        pred = np.where(pred > threshold, 1.0, 0)
+    mask = torch.zeros((1, 1, image_height, image_width), device=device)  # NCHW
+    pred = predict(image, mask, threshold)
 
     # Show the prediction result.
st.image(pred, caption="") diff --git a/isegm/data/base.py b/isegm/data/base.py index ee2a532643d2ebf9c234a19e46df652ac56497cb..84ccd9d0017fdc636470989a1842e60d2f803071 100644 --- a/isegm/data/base.py +++ b/isegm/data/base.py @@ -1,22 +1,26 @@ -import random import pickle +import random + import numpy as np import torch from torchvision import transforms + from .points_sampler import MultiPointSampler from .sample import DSample class ISDataset(torch.utils.data.dataset.Dataset): - def __init__(self, - augmentator=None, - points_sampler=MultiPointSampler(max_num_points=12), - min_object_area=0, - keep_background_prob=0.0, - with_image_info=False, - samples_scores_path=None, - samples_scores_gamma=1.0, - epoch_len=-1): + def __init__( + self, + augmentator=None, + points_sampler=MultiPointSampler(max_num_points=12), + min_object_area=0, + keep_background_prob=0.0, + with_image_info=False, + samples_scores_path=None, + samples_scores_gamma=1.0, + epoch_len=-1, + ): super(ISDataset, self).__init__() self.epoch_len = epoch_len self.augmentator = augmentator @@ -24,15 +28,19 @@ class ISDataset(torch.utils.data.dataset.Dataset): self.keep_background_prob = keep_background_prob self.points_sampler = points_sampler self.with_image_info = with_image_info - self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma) + self.samples_precomputed_scores = self._load_samples_scores( + samples_scores_path, samples_scores_gamma + ) self.to_tensor = transforms.ToTensor() self.dataset_samples = None def __getitem__(self, index): if self.samples_precomputed_scores is not None: - index = np.random.choice(self.samples_precomputed_scores['indices'], - p=self.samples_precomputed_scores['probs']) + index = np.random.choice( + self.samples_precomputed_scores["indices"], + p=self.samples_precomputed_scores["probs"], + ) else: if self.epoch_len > 0: index = random.randrange(0, len(self.dataset_samples)) @@ -46,13 +54,13 @@ class ISDataset(torch.utils.data.dataset.Dataset): mask = self.points_sampler.selected_mask output = { - 'images': self.to_tensor(sample.image), - 'points': points.astype(np.float32), - 'instances': mask + "images": self.to_tensor(sample.image), + "points": points.astype(np.float32), + "instances": mask, } if self.with_image_info: - output['image_info'] = sample.sample_id + output["image_info"] = sample.sample_id return output @@ -63,8 +71,10 @@ class ISDataset(torch.utils.data.dataset.Dataset): valid_augmentation = False while not valid_augmentation: sample.augment(self.augmentator) - keep_sample = (self.keep_background_prob < 0.0 or - random.random() < self.keep_background_prob) + keep_sample = ( + self.keep_background_prob < 0.0 + or random.random() < self.keep_background_prob + ) valid_augmentation = len(sample) > 0 or keep_sample return sample @@ -86,14 +96,11 @@ class ISDataset(torch.utils.data.dataset.Dataset): if samples_scores_path is None: return None - with open(samples_scores_path, 'rb') as f: + with open(samples_scores_path, "rb") as f: images_scores = pickle.load(f) probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]) probs /= probs.sum() - samples_scores = { - 'indices': [x[0] for x in images_scores], - 'probs': probs - } - print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}') + samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs} + print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}") return samples_scores diff --git a/isegm/data/compose.py 
b/isegm/data/compose.py index e6e458cfd5693a3b5a73b9717c268213914f8430..e94bc6d93dedc170ee5a53cab2e58f8e7343a8ce 100644 --- a/isegm/data/compose.py +++ b/isegm/data/compose.py @@ -1,5 +1,7 @@ -import numpy as np from math import isclose + +import numpy as np + from .base import ISDataset @@ -10,7 +12,9 @@ class ComposeDataset(ISDataset): self._datasets = datasets self.dataset_samples = [] for dataset_indx, dataset in enumerate(self._datasets): - self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) + self.dataset_samples.extend( + [(dataset_indx, i) for i in range(len(dataset))] + ) def get_sample(self, index): dataset_indx, sample_indx = self.dataset_samples[index] @@ -21,16 +25,18 @@ class ProportionalComposeDataset(ISDataset): def __init__(self, datasets, ratios, **kwargs): super().__init__(**kwargs) - assert len(ratios) == len(datasets),\ - "The number of datasets must match the number of ratios" - assert isclose(sum(ratios), 1.0),\ - "The sum of ratios must be equal to 1" + assert len(ratios) == len( + datasets + ), "The number of datasets must match the number of ratios" + assert isclose(sum(ratios), 1.0), "The sum of ratios must be equal to 1" self._ratios = ratios self._datasets = datasets self.dataset_samples = [] for dataset_indx, dataset in enumerate(self._datasets): - self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) + self.dataset_samples.extend( + [(dataset_indx, i) for i in range(len(dataset))] + ) def get_sample(self, index): dataset_indx = np.random.choice(len(self._datasets), p=self._ratios) diff --git a/isegm/data/datasets/__init__.py b/isegm/data/datasets/__init__.py index 966ffff2028cd494f785011eb037628890c06b94..bd463cdd61c6420e07663ff9ae10542bb7f96a56 100644 --- a/isegm/data/datasets/__init__.py +++ b/isegm/data/datasets/__init__.py @@ -1,12 +1,13 @@ from isegm.data.compose import ComposeDataset, ProportionalComposeDataset + +from .ade20k import ADE20kDataset from .berkeley import BerkeleyDataset from .coco import CocoDataset +from .coco_lvis import CocoLvisDataset from .davis import DavisDataset from .grabcut import GrabCutDataset -from .coco_lvis import CocoLvisDataset +from .images_dir import ImagesDirDataset from .lvis import LvisDataset from .openimages import OpenImagesDataset -from .sbd import SBDDataset, SBDEvaluationDataset -from .images_dir import ImagesDirDataset -from .ade20k import ADE20kDataset from .pascalvoc import PascalVocDataset +from .sbd import SBDDataset, SBDEvaluationDataset diff --git a/isegm/data/datasets/ade20k.py b/isegm/data/datasets/ade20k.py index 6791b8353a2d34c5e6e36ca5cdc6e4bdb62339c2..36b597c8f5e86d093cf62daf4100dde26a985df0 100644 --- a/isegm/data/datasets/ade20k.py +++ b/isegm/data/datasets/ade20k.py @@ -1,6 +1,6 @@ import os -import random import pickle as pkl +import random from pathlib import Path import cv2 @@ -12,18 +12,18 @@ from isegm.utils.misc import get_labels_with_sizes class ADE20kDataset(ISDataset): - def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): + def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs): super().__init__(**kwargs) - assert split in {'train', 'val'} + assert split in {"train", "val"} self.dataset_path = Path(dataset_path) self.dataset_split = split - self.dataset_split_folder = 'training' if split == 'train' else 'validation' + self.dataset_split_folder = "training" if split == "train" else "validation" self.stuff_prob = stuff_prob - anno_path = self.dataset_path / 
f'{split}-annotations-object-segmentation.pkl' + anno_path = self.dataset_path / f"{split}-annotations-object-segmentation.pkl" if os.path.exists(anno_path): - with anno_path.open('rb') as f: + with anno_path.open("rb") as f: annotations = pkl.load(f) else: raise RuntimeError(f"Can't find annotations at {anno_path}") @@ -34,21 +34,23 @@ class ADE20kDataset(ISDataset): image_id = self.dataset_samples[index] sample_annos = self.annotations[image_id] - image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg') + image_path = str(self.dataset_path / sample_annos["folder"] / f"{image_id}.jpg") image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # select random mask for an image - layer = random.choice(sample_annos['layers']) - mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name']) - instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0] # the B channel holds instances + layer = random.choice(sample_annos["layers"]) + mask_path = str(self.dataset_path / sample_annos["folder"] / layer["mask_name"]) + instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[ + :, :, 0 + ] # the B channel holds instances instances_mask = instances_mask.astype(np.int32) object_ids, _ = get_labels_with_sizes(instances_mask) if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob): # remove stuff objects for i, object_id in enumerate(object_ids): - if i in layer['stuff_instances']: + if i in layer["stuff_instances"]: instances_mask[instances_mask == object_id] = 0 object_ids, _ = get_labels_with_sizes(instances_mask) diff --git a/isegm/data/datasets/berkeley.py b/isegm/data/datasets/berkeley.py index 5c269d84afdc8350cf92f0deddf732c5b62c0687..fe45e39d464414d798d68290387fdec93e64ac29 100644 --- a/isegm/data/datasets/berkeley.py +++ b/isegm/data/datasets/berkeley.py @@ -3,4 +3,6 @@ from .grabcut import GrabCutDataset class BerkeleyDataset(GrabCutDataset): def __init__(self, dataset_path, **kwargs): - super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs) + super().__init__( + dataset_path, images_dir_name="images", masks_dir_name="masks", **kwargs + ) diff --git a/isegm/data/datasets/coco.py b/isegm/data/datasets/coco.py index 985eb768579636ca9fcae68c56654af94d477f2a..65fe77f87bbf4871d46bbf5132350a1fc86e9c5b 100644 --- a/isegm/data/datasets/coco.py +++ b/isegm/data/datasets/coco.py @@ -1,14 +1,16 @@ -import cv2 import json import random -import numpy as np from pathlib import Path + +import cv2 +import numpy as np + from isegm.data.base import ISDataset from isegm.data.sample import DSample class CocoDataset(ISDataset): - def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): + def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs): super(CocoDataset, self).__init__(**kwargs) self.split = split self.dataset_path = Path(dataset_path) @@ -17,26 +19,28 @@ class CocoDataset(ISDataset): self.load_samples() def load_samples(self): - annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json' - self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}' + annotation_path = ( + self.dataset_path / "annotations" / f"panoptic_{self.split}.json" + ) + self.labels_path = self.dataset_path / "annotations" / f"panoptic_{self.split}" self.images_path = self.dataset_path / self.split - with open(annotation_path, 'r') as f: + with open(annotation_path, "r") as f: annotation = json.load(f) - self.dataset_samples = 
annotation['annotations'] + self.dataset_samples = annotation["annotations"] - self._categories = annotation['categories'] - self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0] - self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1] + self._categories = annotation["categories"] + self._stuff_labels = [x["id"] for x in self._categories if x["isthing"] == 0] + self._things_labels = [x["id"] for x in self._categories if x["isthing"] == 1] self._things_labels_set = set(self._things_labels) self._stuff_labels_set = set(self._stuff_labels) def get_sample(self, index) -> DSample: dataset_sample = self.dataset_samples[index] - image_path = self.images_path / self.get_image_name(dataset_sample['file_name']) - label_path = self.labels_path / dataset_sample['file_name'] + image_path = self.images_path / self.get_image_name(dataset_sample["file_name"]) + label_path = self.labels_path / dataset_sample["file_name"] image = cv2.imread(str(image_path)) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -47,11 +51,11 @@ class CocoDataset(ISDataset): things_ids = [] stuff_ids = [] - for segment in dataset_sample['segments_info']: - class_id = segment['category_id'] - obj_id = segment['id'] + for segment in dataset_sample["segments_info"]: + class_id = segment["category_id"] + obj_id = segment["id"] if class_id in self._things_labels_set: - if segment['iscrowd'] == 1: + if segment["iscrowd"] == 1: continue things_ids.append(obj_id) else: @@ -71,4 +75,4 @@ class CocoDataset(ISDataset): @classmethod def get_image_name(cls, panoptic_name): - return panoptic_name.replace('.png', '.jpg') + return panoptic_name.replace(".png", ".jpg") diff --git a/isegm/data/datasets/coco_lvis.py b/isegm/data/datasets/coco_lvis.py index 03691036178a588e02167faf512ed473cae10e25..ce8def53e70444e0e5fec77d73fdb30878dd2171 100644 --- a/isegm/data/datasets/coco_lvis.py +++ b/isegm/data/datasets/coco_lvis.py @@ -1,66 +1,78 @@ -from pathlib import Path +import json import pickle import random -import numpy as np -import json -import cv2 from copy import deepcopy +from pathlib import Path + +import cv2 +import numpy as np + from isegm.data.base import ISDataset from isegm.data.sample import DSample class CocoLvisDataset(ISDataset): - def __init__(self, dataset_path, split='train', stuff_prob=0.0, - allow_list_name=None, anno_file='hannotation.pickle', **kwargs): + def __init__( + self, + dataset_path, + split="train", + stuff_prob=0.0, + allow_list_name=None, + anno_file="hannotation.pickle", + **kwargs, + ): super(CocoLvisDataset, self).__init__(**kwargs) dataset_path = Path(dataset_path) self._split_path = dataset_path / split self.split = split - self._images_path = self._split_path / 'images' - self._masks_path = self._split_path / 'masks' + self._images_path = self._split_path / "images" + self._masks_path = self._split_path / "masks" self.stuff_prob = stuff_prob - with open(self._split_path / anno_file, 'rb') as f: + with open(self._split_path / anno_file, "rb") as f: self.dataset_samples = sorted(pickle.load(f).items()) if allow_list_name is not None: allow_list_path = self._split_path / allow_list_name - with open(allow_list_path, 'r') as f: + with open(allow_list_path, "r") as f: allow_images_ids = json.load(f) allow_images_ids = set(allow_images_ids) - self.dataset_samples = [sample for sample in self.dataset_samples - if sample[0] in allow_images_ids] + self.dataset_samples = [ + sample + for sample in self.dataset_samples + if sample[0] in allow_images_ids + ] def 
get_sample(self, index) -> DSample: image_id, sample = self.dataset_samples[index] - image_path = self._images_path / f'{image_id}.jpg' + image_path = self._images_path / f"{image_id}.jpg" image = cv2.imread(str(image_path)) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - packed_masks_path = self._masks_path / f'{image_id}.pickle' - with open(packed_masks_path, 'rb') as f: + packed_masks_path = self._masks_path / f"{image_id}.pickle" + with open(packed_masks_path, "rb") as f: encoded_layers, objs_mapping = pickle.load(f) layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers] layers = np.stack(layers, axis=2) - instances_info = deepcopy(sample['hierarchy']) + instances_info = deepcopy(sample["hierarchy"]) for inst_id, inst_info in list(instances_info.items()): if inst_info is None: - inst_info = {'children': [], 'parent': None, 'node_level': 0} + inst_info = {"children": [], "parent": None, "node_level": 0} instances_info[inst_id] = inst_info - inst_info['mapping'] = objs_mapping[inst_id] + inst_info["mapping"] = objs_mapping[inst_id] if self.stuff_prob > 0 and random.random() < self.stuff_prob: - for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): + for inst_id in range(sample["num_instance_masks"], len(objs_mapping)): instances_info[inst_id] = { - 'mapping': objs_mapping[inst_id], - 'parent': None, - 'children': [] + "mapping": objs_mapping[inst_id], + "parent": None, + "children": [], } else: - for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): + for inst_id in range(sample["num_instance_masks"], len(objs_mapping)): layer_indx, mask_id = objs_mapping[inst_id] layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0 diff --git a/isegm/data/datasets/davis.py b/isegm/data/datasets/davis.py index de36b96be27f12a286865086a4e070a452987169..f5bca7b316c660a3fe9e362e7796ef6f0c373cc1 100644 --- a/isegm/data/datasets/davis.py +++ b/isegm/data/datasets/davis.py @@ -8,22 +8,22 @@ from isegm.data.sample import DSample class DavisDataset(ISDataset): - def __init__(self, dataset_path, - images_dir_name='img', masks_dir_name='gt', - **kwargs): + def __init__( + self, dataset_path, images_dir_name="img", masks_dir_name="gt", **kwargs + ): super(DavisDataset, self).__init__(**kwargs) self.dataset_path = Path(dataset_path) self._images_path = self.dataset_path / images_dir_name self._insts_path = self.dataset_path / masks_dir_name - self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] - self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} + self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))] + self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")} def get_sample(self, index) -> DSample: image_name = self.dataset_samples[index] image_path = str(self._images_path / image_name) - mask_path = str(self._masks_paths[image_name.split('.')[0]]) + mask_path = str(self._masks_paths[image_name.split(".")[0]]) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) diff --git a/isegm/data/datasets/grabcut.py b/isegm/data/datasets/grabcut.py index ff00446d613183e3a0deed29cd8ed8dae53fd5b3..b5d466663fae3d61d13b64daf5477192dfb8844e 100644 --- a/isegm/data/datasets/grabcut.py +++ b/isegm/data/datasets/grabcut.py @@ -8,22 +8,26 @@ from isegm.data.sample import DSample class GrabCutDataset(ISDataset): - def __init__(self, dataset_path, - images_dir_name='data_GT', masks_dir_name='boundary_GT', - **kwargs): + def __init__( + self, + dataset_path, + 
images_dir_name="data_GT", + masks_dir_name="boundary_GT", + **kwargs + ): super(GrabCutDataset, self).__init__(**kwargs) self.dataset_path = Path(dataset_path) self._images_path = self.dataset_path / images_dir_name self._insts_path = self.dataset_path / masks_dir_name - self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] - self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} + self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))] + self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")} def get_sample(self, index) -> DSample: image_name = self.dataset_samples[index] image_path = str(self._images_path / image_name) - mask_path = str(self._masks_paths[image_name.split('.')[0]]) + mask_path = str(self._masks_paths[image_name.split(".")[0]]) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -31,4 +35,6 @@ class GrabCutDataset(ISDataset): instances_mask[instances_mask == 128] = -1 instances_mask[instances_mask > 128] = 1 - return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index) + return DSample( + image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index + ) diff --git a/isegm/data/datasets/images_dir.py b/isegm/data/datasets/images_dir.py index db7d0fa6288ebdd7a7865648965e64183543ac87..403f5211ca5a7b21f9e42fdea570d1587d74ffa8 100644 --- a/isegm/data/datasets/images_dir.py +++ b/isegm/data/datasets/images_dir.py @@ -1,49 +1,50 @@ +from pathlib import Path + import cv2 import numpy as np -from pathlib import Path from isegm.data.base import ISDataset from isegm.data.sample import DSample class ImagesDirDataset(ISDataset): - def __init__(self, dataset_path, - images_dir_name='images', masks_dir_name='masks', - **kwargs): + def __init__( + self, dataset_path, images_dir_name="images", masks_dir_name="masks", **kwargs + ): super(ImagesDirDataset, self).__init__(**kwargs) self.dataset_path = Path(dataset_path) self._images_path = self.dataset_path / images_dir_name self._insts_path = self.dataset_path / masks_dir_name - images_list = [x for x in sorted(self._images_path.glob('*.*'))] + images_list = [x for x in sorted(self._images_path.glob("*.*"))] - samples = {x.stem: {'image': x, 'masks': []} for x in images_list} - for mask_path in self._insts_path.glob('*.*'): + samples = {x.stem: {"image": x, "masks": []} for x in images_list} + for mask_path in self._insts_path.glob("*.*"): mask_name = mask_path.stem if mask_name in samples: - samples[mask_name]['masks'].append(mask_path) + samples[mask_name]["masks"].append(mask_path) continue - mask_name_split = mask_name.split('_') + mask_name_split = mask_name.split("_") if mask_name_split[-1].isdigit(): - mask_name = '_'.join(mask_name_split[:-1]) + mask_name = "_".join(mask_name_split[:-1]) assert mask_name in samples - samples[mask_name]['masks'].append(mask_path) + samples[mask_name]["masks"].append(mask_path) for x in samples.values(): - assert len(x['masks']) > 0, x['image'] + assert len(x["masks"]) > 0, x["image"] self.dataset_samples = [v for k, v in sorted(samples.items())] def get_sample(self, index) -> DSample: sample = self.dataset_samples[index] - image_path = str(sample['image']) + image_path = str(sample["image"]) objects = [] ignored_regions = [] masks = [] - for indx, mask_path in enumerate(sample['masks']): + for indx, mask_path in enumerate(sample["masks"]): gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32) instances_mask = np.zeros_like(gt_mask) 
instances_mask[gt_mask == 128] = 2 @@ -55,5 +56,10 @@ class ImagesDirDataset(ISDataset): image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return DSample(image, np.stack(masks, axis=2), - objects_ids=objects, ignore_ids=ignored_regions, sample_id=index) + return DSample( + image, + np.stack(masks, axis=2), + objects_ids=objects, + ignore_ids=ignored_regions, + sample_id=index, + ) diff --git a/isegm/data/datasets/lvis.py b/isegm/data/datasets/lvis.py index fd94b431d97effcff96ee3bee607f97375b88325..5543825cf0635185ba31a1f3c27b56987670ae40 100644 --- a/isegm/data/datasets/lvis.py +++ b/isegm/data/datasets/lvis.py @@ -11,42 +11,41 @@ from isegm.data.sample import DSample class LvisDataset(ISDataset): - def __init__(self, dataset_path, split='train', - max_overlap_ratio=0.5, - **kwargs): + def __init__(self, dataset_path, split="train", max_overlap_ratio=0.5, **kwargs): super(LvisDataset, self).__init__(**kwargs) dataset_path = Path(dataset_path) - train_categories_path = dataset_path / 'train_categories.json' - self._train_path = dataset_path / 'train' - self._val_path = dataset_path / 'val' + train_categories_path = dataset_path / "train_categories.json" + self._train_path = dataset_path / "train" + self._val_path = dataset_path / "val" self.split = split self.max_overlap_ratio = max_overlap_ratio - with open( dataset_path / split / f'lvis_{self.split}.json', 'r') as f: + with open(dataset_path / split / f"lvis_{self.split}.json", "r") as f: json_annotation = json.loads(f.read()) self.annotations = defaultdict(list) - for x in json_annotation['annotations']: - self.annotations[x['image_id']].append(x) + for x in json_annotation["annotations"]: + self.annotations[x["image_id"]].append(x) if not train_categories_path.exists(): self.generate_train_categories(dataset_path, train_categories_path) - self.dataset_samples = [x for x in json_annotation['images'] - if len(self.annotations[x['id']]) > 0] + self.dataset_samples = [ + x for x in json_annotation["images"] if len(self.annotations[x["id"]]) > 0 + ] def get_sample(self, index) -> DSample: image_info = self.dataset_samples[index] - image_id, image_url = image_info['id'], image_info['coco_url'] - image_filename = image_url.split('/')[-1] + image_id, image_url = image_info["id"], image_info["coco_url"] + image_filename = image_url.split("/")[-1] image_annotations = self.annotations[image_id] random.shuffle(image_annotations) # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017) - if 'train2017' in image_url: - image_path = self._train_path / 'images' / image_filename + if "train2017" in image_url: + image_path = self._train_path / "images" / image_filename else: - image_path = self._val_path / 'images' / image_filename + image_path = self._val_path / "images" / image_filename image = cv2.imread(str(image_path)) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -62,9 +61,14 @@ class LvisDataset(ISDataset): instances_mask = np.zeros_like(object_mask, dtype=np.int32) overlap_ids = np.bincount(instances_mask[object_mask].flatten()) - overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids) - if overlap_area > 0 and inst_id > 0] - overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area + overlap_areas = [ + overlap_area / instances_area[inst_id] + for inst_id, overlap_area in enumerate(overlap_ids) + if overlap_area > 0 and inst_id > 0 + ] + overlap_ratio = ( + np.logical_and(object_mask, 
instances_mask > 0).sum() / object_area + ) if overlap_areas: overlap_ratio = max(overlap_ratio, max(overlap_areas)) if overlap_ratio > self.max_overlap_ratio: @@ -77,11 +81,10 @@ class LvisDataset(ISDataset): return DSample(image, instances_mask, objects_ids=objects_ids) - @staticmethod def get_mask_from_polygon(annotation, image): mask = np.zeros(image.shape[:2], dtype=np.int32) - for contour_points in annotation['segmentation']: + for contour_points in annotation["segmentation"]: contour_points = np.array(contour_points).reshape((-1, 2)) contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :] cv2.fillPoly(mask, contour_points, 1) @@ -90,8 +93,8 @@ class LvisDataset(ISDataset): @staticmethod def generate_train_categories(dataset_path, train_categories_path): - with open(dataset_path / 'train/lvis_train.json', 'r') as f: + with open(dataset_path / "train/lvis_train.json", "r") as f: annotation = json.load(f) - with open(train_categories_path, 'w') as f: - json.dump(annotation['categories'], f, indent=1) + with open(train_categories_path, "w") as f: + json.dump(annotation["categories"], f, indent=1) diff --git a/isegm/data/datasets/openimages.py b/isegm/data/datasets/openimages.py index d0a81cfbf08b9b5ddd3fe565a00e778733e9ee4a..2c6f46d5466455617fefe4b44dd58837bfb375ee 100644 --- a/isegm/data/datasets/openimages.py +++ b/isegm/data/datasets/openimages.py @@ -1,6 +1,6 @@ import os -import random import pickle as pkl +import random from pathlib import Path import cv2 @@ -11,29 +11,31 @@ from isegm.data.sample import DSample class OpenImagesDataset(ISDataset): - def __init__(self, dataset_path, split='train', **kwargs): + def __init__(self, dataset_path, split="train", **kwargs): super().__init__(**kwargs) - assert split in {'train', 'val', 'test'} + assert split in {"train", "val", "test"} self.dataset_path = Path(dataset_path) self._split_path = self.dataset_path / split - self._images_path = self._split_path / 'images' - self._masks_path = self._split_path / 'masks' + self._images_path = self._split_path / "images" + self._masks_path = self._split_path / "masks" self.dataset_split = split - clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl' + clean_anno_path = ( + self._split_path / f"{split}-annotations-object-segmentation_clean.pkl" + ) if os.path.exists(clean_anno_path): - with clean_anno_path.open('rb') as f: + with clean_anno_path.open("rb") as f: annotations = pkl.load(f) else: raise RuntimeError(f"Can't find annotations at {clean_anno_path}") - self.image_id_to_masks = annotations['image_id_to_masks'] - self.dataset_samples = annotations['dataset_samples'] + self.image_id_to_masks = annotations["image_id_to_masks"] + self.dataset_samples = annotations["dataset_samples"] def get_sample(self, index) -> DSample: image_id = self.dataset_samples[index] - image_path = str(self._images_path / f'{image_id}.jpg') + image_path = str(self._images_path / f"{image_id}.jpg") image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -49,9 +51,16 @@ class OpenImagesDataset(ISDataset): min_height = min(image.shape[0], instances_mask.shape[0]) if image.shape[0] != min_height or image.shape[1] != min_width: - image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR) - if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width: - instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST) + image = cv2.resize( + image, (min_width, 
min_height), interpolation=cv2.INTER_LINEAR + ) + if ( + instances_mask.shape[0] != min_height + or instances_mask.shape[1] != min_width + ): + instances_mask = cv2.resize( + instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST + ) object_ids = [1] if instances_mask.sum() > 0 else [] diff --git a/isegm/data/datasets/pascalvoc.py b/isegm/data/datasets/pascalvoc.py index 4e1ad488f2228c1a94040d1bde21cd421ff70b3e..1cf84113bb3162c6c84000621a6e2858f76ea0bb 100644 --- a/isegm/data/datasets/pascalvoc.py +++ b/isegm/data/datasets/pascalvoc.py @@ -9,32 +9,38 @@ from isegm.data.sample import DSample class PascalVocDataset(ISDataset): - def __init__(self, dataset_path, split='train', **kwargs): + def __init__(self, dataset_path, split="train", **kwargs): super().__init__(**kwargs) - assert split in {'train', 'val', 'trainval', 'test'} + assert split in {"train", "val", "trainval", "test"} self.dataset_path = Path(dataset_path) self._images_path = self.dataset_path / "JPEGImages" self._insts_path = self.dataset_path / "SegmentationObject" self.dataset_split = split - if split == 'test': - with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f: + if split == "test": + with open( + self.dataset_path / f"ImageSets/Segmentation/test.pickle", "rb" + ) as f: self.dataset_samples, self.instance_ids = pkl.load(f) else: - with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f: + with open( + self.dataset_path / f"ImageSets/Segmentation/{split}.txt", "r" + ) as f: self.dataset_samples = [name.strip() for name in f.readlines()] def get_sample(self, index) -> DSample: sample_id = self.dataset_samples[index] - image_path = str(self._images_path / f'{sample_id}.jpg') - mask_path = str(self._insts_path / f'{sample_id}.png') + image_path = str(self._images_path / f"{sample_id}.jpg") + mask_path = str(self._insts_path / f"{sample_id}.png") image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) instances_mask = cv2.imread(mask_path) - instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) - if self.dataset_split == 'test': + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( + np.int32 + ) + if self.dataset_split == "test": instance_id = self.instance_ids[index] mask = np.zeros_like(instances_mask) mask[instances_mask == 220] = 220 # ignored area @@ -45,4 +51,10 @@ class PascalVocDataset(ISDataset): objects_ids = np.unique(instances_mask) objects_ids = [x for x in objects_ids if x != 0 and x != 220] - return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) + return DSample( + image, + instances_mask, + objects_ids=objects_ids, + ignore_ids=[220], + sample_id=index, + ) diff --git a/isegm/data/datasets/sbd.py b/isegm/data/datasets/sbd.py index b6a05e4b370f4b6486ebc24ceb961f545f256f81..fecee137cbe2ba584019a102a570bae65a439b85 100644 --- a/isegm/data/datasets/sbd.py +++ b/isegm/data/datasets/sbd.py @@ -5,38 +5,42 @@ import cv2 import numpy as np from scipy.io import loadmat -from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes from isegm.data.base import ISDataset from isegm.data.sample import DSample +from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes class SBDDataset(ISDataset): - def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs): + def __init__(self, dataset_path, split="train", buggy_mask_thresh=0.08, **kwargs): super(SBDDataset, self).__init__(**kwargs) - assert split 
in {'train', 'val'} + assert split in {"train", "val"} self.dataset_path = Path(dataset_path) self.dataset_split = split - self._images_path = self.dataset_path / 'img' - self._insts_path = self.dataset_path / 'inst' + self._images_path = self.dataset_path / "img" + self._insts_path = self.dataset_path / "inst" self._buggy_objects = dict() self._buggy_mask_thresh = buggy_mask_thresh - with open(self.dataset_path / f'{split}.txt', 'r') as f: + with open(self.dataset_path / f"{split}.txt", "r") as f: self.dataset_samples = [x.strip() for x in f.readlines()] def get_sample(self, index): image_name = self.dataset_samples[index] - image_path = str(self._images_path / f'{image_name}.jpg') - inst_info_path = str(self._insts_path / f'{image_name}.mat') + image_path = str(self._images_path / f"{image_name}.jpg") + inst_info_path = str(self._insts_path / f"{image_name}.mat") image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype( + np.int32 + ) instances_mask = self.remove_buggy_masks(index, instances_mask) instances_ids, _ = get_labels_with_sizes(instances_mask) - return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index) + return DSample( + image, instances_mask, objects_ids=instances_ids, sample_id=index + ) def remove_buggy_masks(self, index, instances_mask): if self._buggy_mask_thresh > 0.0: @@ -61,51 +65,55 @@ class SBDDataset(ISDataset): class SBDEvaluationDataset(ISDataset): - def __init__(self, dataset_path, split='val', **kwargs): + def __init__(self, dataset_path, split="val", **kwargs): super(SBDEvaluationDataset, self).__init__(**kwargs) - assert split in {'train', 'val'} + assert split in {"train", "val"} self.dataset_path = Path(dataset_path) self.dataset_split = split - self._images_path = self.dataset_path / 'img' - self._insts_path = self.dataset_path / 'inst' + self._images_path = self.dataset_path / "img" + self._insts_path = self.dataset_path / "inst" - with open(self.dataset_path / f'{split}.txt', 'r') as f: + with open(self.dataset_path / f"{split}.txt", "r") as f: self.dataset_samples = [x.strip() for x in f.readlines()] self.dataset_samples = self.get_sbd_images_and_ids_list() def get_sample(self, index) -> DSample: image_name, instance_id = self.dataset_samples[index] - image_path = str(self._images_path / f'{image_name}.jpg') - inst_info_path = str(self._insts_path / f'{image_name}.mat') + image_path = str(self._images_path / f"{image_name}.jpg") + inst_info_path = str(self._insts_path / f"{image_name}.mat") image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype( + np.int32 + ) instances_mask[instances_mask != instance_id] = 0 instances_mask[instances_mask > 0] = 1 return DSample(image, instances_mask, objects_ids=[1], sample_id=index) def get_sbd_images_and_ids_list(self): - pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl' + pkl_path = self.dataset_path / f"{self.dataset_split}_images_and_ids_list.pkl" if pkl_path.exists(): - with open(str(pkl_path), 'rb') as fp: + with open(str(pkl_path), "rb") as fp: images_and_ids_list = pkl.load(fp) else: images_and_ids_list = [] for sample in self.dataset_samples: - inst_info_path = str(self._insts_path / f'{sample}.mat') - 
instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + inst_info_path = str(self._insts_path / f"{sample}.mat") + instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype( + np.int32 + ) instances_ids, _ = get_labels_with_sizes(instances_mask) for instances_id in instances_ids: images_and_ids_list.append((sample, instances_id)) - with open(str(pkl_path), 'wb') as fp: + with open(str(pkl_path), "wb") as fp: pkl.dump(images_and_ids_list, fp) return images_and_ids_list diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py index 43cc6380b173599457099ab35bd288db5cee0193..2e956ff0749349b7ba4c2bada38bbdfef2a46400 100644 --- a/isegm/data/points_sampler.py +++ b/isegm/data/points_sampler.py @@ -1,8 +1,10 @@ -import cv2 import math import random -import numpy as np from functools import lru_cache + +import cv2 +import numpy as np + from .sample import DSample @@ -28,13 +30,25 @@ class BasePointSampler: class MultiPointSampler(BasePointSampler): - def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1, - positive_erode_prob=0.9, positive_erode_iters=3, - negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5, - merge_objects_prob=0.0, max_num_merged_objects=2, - use_hierarchy=False, soft_targets=False, - first_click_center=False, only_one_first_click=False, - sfc_inner_k=1.7, sfc_full_inner_prob=0.0): + def __init__( + self, + max_num_points, + prob_gamma=0.7, + expand_ratio=0.1, + positive_erode_prob=0.9, + positive_erode_iters=3, + negative_bg_prob=0.1, + negative_other_prob=0.4, + negative_border_prob=0.5, + merge_objects_prob=0.0, + max_num_merged_objects=2, + use_hierarchy=False, + soft_targets=False, + first_click_center=False, + only_one_first_click=False, + sfc_inner_k=1.7, + sfc_full_inner_prob=0.0, + ): super().__init__() self.max_num_points = max_num_points self.expand_ratio = expand_ratio @@ -52,8 +66,12 @@ class MultiPointSampler(BasePointSampler): max_num_merged_objects = max_num_points self.max_num_merged_objects = max_num_merged_objects - self.neg_strategies = ['bg', 'other', 'border'] - self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob] + self.neg_strategies = ["bg", "other", "border"] + self.neg_strategies_prob = [ + negative_bg_prob, + negative_other_prob, + negative_border_prob, + ] assert math.isclose(sum(self.neg_strategies_prob), 1.0) self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma) @@ -66,7 +84,7 @@ class MultiPointSampler(BasePointSampler): self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32) self._selected_masks = [[]] self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies} - self._neg_masks['required'] = [] + self._neg_masks["required"] = [] return gt_mask, pos_masks, neg_masks = self._sample_mask(sample) @@ -80,14 +98,16 @@ class MultiPointSampler(BasePointSampler): if len(sample) <= len(self._selected_masks): neg_mask_other = neg_mask_bg else: - neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()), - np.logical_not(binary_gt_mask)) + neg_mask_other = np.logical_and( + np.logical_not(sample.get_background_mask()), + np.logical_not(binary_gt_mask), + ) self._neg_masks = { - 'bg': neg_mask_bg, - 'other': neg_mask_other, - 'border': neg_mask_border, - 'required': neg_masks + "bg": neg_mask_bg, + "other": neg_mask_other, + "border": neg_mask_border, + "required": neg_masks, } def _sample_mask(self, sample: DSample): @@ -104,7 +124,11 @@ class MultiPointSampler(BasePointSampler): 
         pos_segments = []
         neg_segments = []
         for obj_id in random_ids:
-            obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample)
+            (
+                obj_gt_mask,
+                obj_pos_segments,
+                obj_neg_segments,
+            ) = self._sample_from_masks_layer(obj_id, sample)
             if gt_mask is None:
                 gt_mask = obj_gt_mask
             else:
@@ -123,35 +147,45 @@
 
         if not self.use_hierarchy:
             node_mask = sample.get_object_mask(obj_id)
-            gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
+            gt_mask = (
+                sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
+            )
             return gt_mask, [node_mask], []
 
         def _select_node(node_id):
             node_info = objs_tree[node_id]
-            if not node_info['children'] or random.random() < 0.5:
+            if not node_info["children"] or random.random() < 0.5:
                 return node_id
-            return _select_node(random.choice(node_info['children']))
+            return _select_node(random.choice(node_info["children"]))
 
         selected_node = _select_node(obj_id)
         node_info = objs_tree[selected_node]
         node_mask = sample.get_object_mask(selected_node)
-        gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask
+        gt_mask = (
+            sample.get_soft_object_mask(selected_node)
+            if self.soft_targets
+            else node_mask
+        )
         pos_mask = node_mask.copy()
 
         negative_segments = []
-        if node_info['parent'] is not None and node_info['parent'] in objs_tree:
-            parent_mask = sample.get_object_mask(node_info['parent'])
-            negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask)))
-
-        for child_id in node_info['children']:
-            if objs_tree[child_id]['area'] / node_info['area'] < 0.10:
+        if node_info["parent"] is not None and node_info["parent"] in objs_tree:
+            parent_mask = sample.get_object_mask(node_info["parent"])
+            negative_segments.append(
+                np.logical_and(parent_mask, np.logical_not(node_mask))
+            )
+
+        for child_id in node_info["children"]:
+            if objs_tree[child_id]["area"] / node_info["area"] < 0.10:
                 child_mask = sample.get_object_mask(child_id)
                 pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
 
-        if node_info['children']:
-            max_disabled_children = min(len(node_info['children']), 3)
+        if node_info["children"]:
+            max_disabled_children = min(len(node_info["children"]), 3)
             num_disabled_children = np.random.randint(0, max_disabled_children + 1)
-            disabled_children = random.sample(node_info['children'], num_disabled_children)
+            disabled_children = random.sample(
+                node_info["children"], num_disabled_children
+            )
 
             for child_id in disabled_children:
                 child_mask = sample.get_object_mask(child_id)
@@ -167,24 +201,32 @@
 
     def sample_points(self):
         assert self._selected_mask is not None
-        pos_points = self._multi_mask_sample_points(self._selected_masks,
-                                                    is_negative=[False] * len(self._selected_masks),
-                                                    with_first_click=self.first_click_center)
-
-        neg_strategy = [(self._neg_masks[k], prob)
-                        for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)]
-        neg_masks = self._neg_masks['required'] + [neg_strategy]
-        neg_points = self._multi_mask_sample_points(neg_masks,
-                                                    is_negative=[False] * len(self._neg_masks['required']) + [True])
+        pos_points = self._multi_mask_sample_points(
+            self._selected_masks,
+            is_negative=[False] * len(self._selected_masks),
+            with_first_click=self.first_click_center,
+        )
+
+        neg_strategy = [
+            (self._neg_masks[k], prob)
+            for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)
+        ]
+        neg_masks = self._neg_masks["required"] + [neg_strategy]
+        neg_points = self._multi_mask_sample_points(
+            neg_masks, is_negative=[False] * len(self._neg_masks["required"]) + [True]
+        )
 
         return pos_points + neg_points
 
-    def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False):
-        selected_masks = selected_masks[:self.max_num_points]
+    def _multi_mask_sample_points(
+        self, selected_masks, is_negative, with_first_click=False
+    ):
+        selected_masks = selected_masks[: self.max_num_points]
 
         each_obj_points = [
-            self._sample_points(mask, is_negative=is_negative[i],
-                                with_first_click=with_first_click)
+            self._sample_points(
+                mask, is_negative=is_negative[i], with_first_click=with_first_click
+            )
             for i, mask in enumerate(selected_masks)
         ]
         each_obj_points = [x for x in each_obj_points if len(x) > 0]
@@ -200,17 +242,27 @@
 
             aggregated_masks_with_prob = []
             for indx, x in enumerate(selected_masks):
-                if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)):
+                if (
+                    isinstance(x, (list, tuple))
+                    and x
+                    and isinstance(x[0], (list, tuple))
+                ):
                     for t, prob in x:
-                        aggregated_masks_with_prob.append((t, prob / len(selected_masks)))
+                        aggregated_masks_with_prob.append(
+                            (t, prob / len(selected_masks))
+                        )
                 else:
                     aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
 
-            other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True)
+            other_points_union = self._sample_points(
+                aggregated_masks_with_prob, is_negative=True
+            )
             if len(other_points_union) + len(points) <= self.max_num_points:
                 points.extend(other_points_union)
             else:
-                points.extend(random.sample(other_points_union, self.max_num_points - len(points)))
+                points.extend(
+                    random.sample(other_points_union, self.max_num_points - len(points))
+                )
 
             if len(points) < self.max_num_points:
                 points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
@@ -219,9 +271,13 @@
 
     def _sample_points(self, mask, is_negative=False, with_first_click=False):
         if is_negative:
-            num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs)
+            num_points = np.random.choice(
+                np.arange(self.max_num_points + 1), p=self._neg_probs
+            )
         else:
-            num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs)
+            num_points = 1 + np.random.choice(
+                np.arange(self.max_num_points), p=self._pos_probs
+            )
 
         indices_probs = None
         if isinstance(mask, (list, tuple)):
@@ -237,9 +293,13 @@
             first_click = with_first_click and j == 0 and indices_probs is None
 
             if first_click:
-                point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob)
+                point_indices = get_point_candidates(
+                    mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob
+                )
             elif indices_probs:
-                point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs)
+                point_indices_indx = np.random.choice(
+                    np.arange(len(indices)), p=indices_probs
+                )
                 point_indices = indices[point_indices_indx][0]
             else:
                 point_indices = indices
@@ -247,7 +307,9 @@
             num_indices = len(point_indices)
             if num_indices > 0:
                 point_indx = 0 if first_click else 100
-                click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx]
+                click = point_indices[np.random.randint(0, num_indices)].tolist() + [
+                    point_indx
+                ]
                 points.append(click)
 
         return points
@@ -257,8 +319,9 @@
             return mask
 
         kernel = np.ones((3, 3), np.uint8)
-        eroded_mask = cv2.erode(mask.astype(np.uint8),
-                                kernel, iterations=self.positive_erode_iters).astype(np.bool)
+        eroded_mask = cv2.erode(
+            mask.astype(np.uint8), kernel, iterations=self.positive_erode_iters
+        ).astype(bool)  # np.bool is removed in modern NumPy; use the builtin
 
         if eroded_mask.sum() > 10:
             return eroded_mask
@@ -291,7 +354,7 @@ def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
     if full_prob > 0 and random.random() < full_prob:
         return obj_mask
 
-    padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant')
+    padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), "constant")
 
     dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
     if k > 0:
diff --git a/isegm/data/sample.py b/isegm/data/sample.py
index d57794ca405a02f2e3a317b87efdeb94352cf138..1355f213d6516a7a79e591a5c98d4675b446f51c 100644
--- a/isegm/data/sample.py
+++ b/isegm/data/sample.py
@@ -1,13 +1,22 @@
-import numpy as np
 from copy import deepcopy
-from isegm.utils.misc import get_labels_with_sizes
-from isegm.data.transforms import remove_image_only_transforms
+
+import numpy as np
 from albumentations import ReplayCompose
 
+from isegm.data.transforms import remove_image_only_transforms
+from isegm.utils.misc import get_labels_with_sizes
+
 
 class DSample:
-    def __init__(self, image, encoded_masks, objects=None,
-                 objects_ids=None, ignore_ids=None, sample_id=None):
+    def __init__(
+        self,
+        image,
+        encoded_masks,
+        objects=None,
+        objects_ids=None,
+        ignore_ids=None,
+        sample_id=None,
+    ):
         self.image = image
         self.sample_id = sample_id
 
@@ -24,9 +33,9 @@ class DSample:
             self._objects = dict()
             for indx, obj_mapping in enumerate(objects_ids):
                 self._objects[indx] = {
-                    'parent': None,
-                    'mapping': obj_mapping,
-                    'children': []
+                    "parent": None,
+                    "mapping": obj_mapping,
+                    "children": [],
                 }
 
             if ignore_ids:
@@ -44,10 +53,10 @@ class DSample:
     def augment(self, augmentator):
         self.reset_augmentation()
         aug_output = augmentator(image=self.image, mask=self._encoded_masks)
-        self.image = aug_output['image']
-        self._encoded_masks = aug_output['mask']
+        self.image = aug_output["image"]
+        self._encoded_masks = aug_output["mask"]
 
-        aug_replay = aug_output.get('replay', None)
+        aug_replay = aug_output.get("replay", None)
         if aug_replay:
             assert len(self._ignored_regions) == 0
             mask_replay = remove_image_only_transforms(aug_replay)
@@ -69,15 +78,15 @@ class DSample:
         self._soft_mask_aug = None
 
     def remove_small_objects(self, min_area):
-        if self._objects and not 'area' in list(self._objects.values())[0]:
+        if self._objects and "area" not in list(self._objects.values())[0]:
             self._compute_objects_areas()
 
         for obj_id, obj_info in list(self._objects.items()):
-            if obj_info['area'] < min_area:
+            if obj_info["area"] < min_area:
                 self._remove_object(obj_id)
 
     def get_object_mask(self, obj_id):
-        layer_indx, mask_id = self._objects[obj_id]['mapping']
+        layer_indx, mask_id = self._objects[obj_id]["mapping"]
         obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
         if self._ignored_regions:
             for layer_indx, mask_id in self._ignored_regions:
@@ -89,9 +98,13 @@ class DSample:
     def get_soft_object_mask(self, obj_id):
         assert self._soft_mask_aug is not None
         original_encoded_masks = self._original_data[1]
-        layer_indx, mask_id = self._objects[obj_id]['mapping']
-        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32)
-        obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image']
+        layer_indx, mask_id = self._objects[obj_id]["mapping"]
+        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
+            np.float32
+        )
+        obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
+            "image"
+        ]
         return np.clip(obj_mask, 0, 1)
 
     def get_background_mask(self):
@@ -108,20 +121,28 @@ class DSample:
 
     @property
     def root_objects(self):
-        return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None]
+        return [
+            obj_id
+            for obj_id, obj_info in self._objects.items()
+            if obj_info["parent"] is None
+        ]
 
     def _compute_objects_areas(self):
-        inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()}
+        inverse_index = {
+            node["mapping"]: node_id for node_id, node in self._objects.items()
+        }
         ignored_regions_keys = set(self._ignored_regions)
 
         for layer_indx in range(self._encoded_masks.shape[2]):
-            objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx])
+            objects_ids, objects_areas = get_labels_with_sizes(
+                self._encoded_masks[:, :, layer_indx]
+            )
             for obj_id, obj_area in zip(objects_ids, objects_areas):
                 inv_key = (layer_indx, obj_id)
                 if inv_key in ignored_regions_keys:
                     continue
                 try:
-                    self._objects[inverse_index[inv_key]]['area'] = obj_area
+                    self._objects[inverse_index[inv_key]]["area"] = obj_area
                     del inverse_index[inv_key]
                 except KeyError:
                     layer = self._encoded_masks[:, :, layer_indx]
@@ -129,18 +150,20 @@ class DSample:
                 self._encoded_masks[:, :, layer_indx] = layer
 
         for obj_id in inverse_index.values():
-            self._objects[obj_id]['area'] = 0
+            self._objects[obj_id]["area"] = 0
 
     def _remove_object(self, obj_id):
         obj_info = self._objects[obj_id]
-        obj_parent = obj_info['parent']
-        for child_id in obj_info['children']:
-            self._objects[child_id]['parent'] = obj_parent
+        obj_parent = obj_info["parent"]
+        for child_id in obj_info["children"]:
+            self._objects[child_id]["parent"] = obj_parent
 
         if obj_parent is not None:
-            parent_children = self._objects[obj_parent]['children']
+            parent_children = self._objects[obj_parent]["children"]
             parent_children = [x for x in parent_children if x != obj_id]
-            self._objects[obj_parent]['children'] = parent_children + obj_info['children']
+            self._objects[obj_parent]["children"] = (
+                parent_children + obj_info["children"]
+            )
 
         del self._objects[obj_id]
diff --git a/isegm/data/transforms.py b/isegm/data/transforms.py
index 0a3fd67f6969ba7e120d03ce85b67a8b4651281d..04fa075a0da4996b9c4dfaf7546a3bf79d2f9f84 100644
--- a/isegm/data/transforms.py
+++ b/isegm/data/transforms.py
@@ -1,28 +1,40 @@
-import cv2
 import random
-import numpy as np
 
+import cv2
+import numpy as np
+from albumentations import DualTransform, ImageOnlyTransform
+from albumentations.augmentations import functional as F
 from albumentations.core.serialization import SERIALIZABLE_REGISTRY
-from albumentations import ImageOnlyTransform, DualTransform
 from albumentations.core.transforms_interface import to_tuple
-from albumentations.augmentations import functional as F
-from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes
+
+from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask,
+                              get_labels_with_sizes)
 
 
 class UniformRandomResize(DualTransform):
-    def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
+    def __init__(
+        self,
+        scale_range=(0.9, 1.1),
+        interpolation=cv2.INTER_LINEAR,
+        always_apply=False,
+        p=1,
+    ):
         super().__init__(always_apply, p)
         self.scale_range = scale_range
         self.interpolation = interpolation
 
     def get_params_dependent_on_targets(self, params):
         scale = random.uniform(*self.scale_range)
-        height = int(round(params['image'].shape[0] * scale))
-        width = int(round(params['image'].shape[1] * scale))
-        return {'new_height': height, 'new_width': width}
+        height = int(round(params["image"].shape[0] * scale))
+        width = int(round(params["image"].shape[1] * scale))
+        return {"new_height": height, "new_width": width}
 
-    def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params):
-        return F.resize(img, height=new_height, width=new_width, interpolation=interpolation)
+    def apply(
+        self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params
+    ):
+        return F.resize(
+            img, height=new_height, width=new_width, interpolation=interpolation
+        )
 
     def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
         scale_x = new_width / params["cols"]
@@ -39,16 +51,16 @@ class UniformRandomResize(DualTransform):
 
 class ZoomIn(DualTransform):
     def __init__(
-            self,
-            height,
-            width,
-            bbox_jitter=0.1,
-            expansion_ratio=1.4,
-            min_crop_size=200,
-            min_area=100,
-            always_resize=False,
-            always_apply=False,
-            p=0.5,
+        self,
+        height,
+        width,
+        bbox_jitter=0.1,
+        expansion_ratio=1.4,
+        min_crop_size=200,
+        min_area=100,
+        always_resize=False,
+        always_apply=False,
+        p=0.5,
     ):
         super(ZoomIn, self).__init__(always_apply, p)
         self.height = height
@@ -66,7 +78,7 @@ class ZoomIn(DualTransform):
             return img
 
         rmin, rmax, cmin, cmax = bbox
-        img = img[rmin:rmax + 1, cmin:cmax + 1]
+        img = img[rmin : rmax + 1, cmin : cmax + 1]
         img = F.resize(img, height=self.height, width=self.width)
 
         return img
@@ -74,12 +86,16 @@ class ZoomIn(DualTransform):
     def apply_to_mask(self, mask, selected_object, bbox, **params):
         if selected_object is None:
             if self.always_resize:
-                mask = F.resize(mask, height=self.height, width=self.width,
-                                interpolation=cv2.INTER_NEAREST)
+                mask = F.resize(
+                    mask,
+                    height=self.height,
+                    width=self.width,
+                    interpolation=cv2.INTER_NEAREST,
+                )
             return mask
 
         rmin, rmax, cmin, cmax = bbox
-        mask = mask[rmin:rmax + 1, cmin:cmax + 1]
+        mask = mask[rmin : rmax + 1, cmin : cmax + 1]
         if isinstance(selected_object, tuple):
             layer_indx, mask_id = selected_object
             obj_mask = mask[:, :, layer_indx] == mask_id
@@ -90,25 +106,34 @@ class ZoomIn(DualTransform):
             new_mask = mask.copy()
             new_mask[np.logical_not(obj_mask)] = 0
 
-        new_mask = F.resize(new_mask, height=self.height, width=self.width,
-                            interpolation=cv2.INTER_NEAREST)
+        new_mask = F.resize(
+            new_mask,
+            height=self.height,
+            width=self.width,
+            interpolation=cv2.INTER_NEAREST,
+        )
         return new_mask
 
     def get_params_dependent_on_targets(self, params):
-        instances = params['mask']
+        instances = params["mask"]
 
         is_mask_layer = len(instances.shape) > 2
         candidates = []
         if is_mask_layer:
             for layer_indx in range(instances.shape[2]):
                 labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
-                candidates.extend([(layer_indx, obj_id)
-                                   for obj_id, area in zip(labels, areas)
-                                   if area > self.min_area])
+                candidates.extend(
+                    [
+                        (layer_indx, obj_id)
+                        for obj_id, area in zip(labels, areas)
+                        if area > self.min_area
+                    ]
+                )
         else:
             labels, areas = get_labels_with_sizes(instances)
-            candidates = [obj_id for obj_id, area in zip(labels, areas)
-                          if area > self.min_area]
+            candidates = [
+                obj_id for obj_id, area in zip(labels, areas) if area > self.min_area
+            ]
 
         selected_object = None
         bbox = None
@@ -131,10 +156,7 @@ class ZoomIn(DualTransform):
             bbox = self._jitter_bbox(bbox)
             bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
 
-        return {
-            'selected_object': selected_object,
-            'bbox': bbox
-        }
+        return {"selected_object": selected_object, "bbox": bbox}
return {"selected_object": selected_object, "bbox": bbox} def _jitter_bbox(self, bbox): rmin, rmax, cmin, cmax = bbox @@ -158,21 +180,28 @@ class ZoomIn(DualTransform): return ["mask"] def get_transform_init_args_names(self): - return ("height", "width", "bbox_jitter", - "expansion_ratio", "min_crop_size", "min_area", "always_resize") + return ( + "height", + "width", + "bbox_jitter", + "expansion_ratio", + "min_crop_size", + "min_area", + "always_resize", + ) def remove_image_only_transforms(sdict): - if not 'transforms' in sdict: + if "transforms" not in sdict: return sdict keep_transforms = [] - for tdict in sdict['transforms']: - cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']] - if 'transforms' in tdict: + for tdict in sdict["transforms"]: + cls = SERIALIZABLE_REGISTRY[tdict["__class_fullname__"]] + if "transforms" in tdict: keep_transforms.append(remove_image_only_transforms(tdict)) elif not issubclass(cls, ImageOnlyTransform): keep_transforms.append(tdict) - sdict['transforms'] = keep_transforms + sdict["transforms"] = keep_transforms return sdict diff --git a/isegm/engine/optimizer.py b/isegm/engine/optimizer.py index fd03d8cfc368ee6807fce420ad73e0024a5b6401..5dee7e8caf438cf8c5e09426d0ca97b365f0a466 100644 --- a/isegm/engine/optimizer.py +++ b/isegm/engine/optimizer.py @@ -1,27 +1,29 @@ -import torch import math + +import torch + from isegm.utils.log import logger def get_optimizer(model, opt_name, opt_kwargs): params = [] - base_lr = opt_kwargs['lr'] + base_lr = opt_kwargs["lr"] for name, param in model.named_parameters(): - param_group = {'params': [param]} + param_group = {"params": [param]} if not param.requires_grad: params.append(param_group) continue - if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0): + if not math.isclose(getattr(param, "lr_mult", 1.0), 1.0): logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.') - param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult + param_group["lr"] = param_group.get("lr", base_lr) * param.lr_mult params.append(param_group) optimizer = { - 'sgd': torch.optim.SGD, - 'adam': torch.optim.Adam, - 'adamw': torch.optim.AdamW + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "adamw": torch.optim.AdamW, }[opt_name.lower()](params, **opt_kwargs) return optimizer diff --git a/isegm/engine/trainer.py b/isegm/engine/trainer.py index ba56323dbc0e909ba0c48025bf620868ff90cfb9..87c5cb6f75b2d5d178f5219462b9271835b889dc 100644 --- a/isegm/engine/trainer.py +++ b/isegm/engine/trainer.py @@ -1,40 +1,48 @@ +import logging import os import random -import logging -from copy import deepcopy from collections import defaultdict +from copy import deepcopy import cv2 -import torch import numpy as np -from tqdm import tqdm +import torch from torch.utils.data import DataLoader +from tqdm import tqdm -from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg -from isegm.utils.vis import draw_probmap, draw_points +from isegm.utils.distributed import (get_dp_wrapper, get_sampler, + reduce_loss_dict) +from isegm.utils.log import SummaryWriterAvg, TqdmToLogger, logger from isegm.utils.misc import save_checkpoint from isegm.utils.serialization import get_config_repr -from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict +from isegm.utils.vis import draw_points, draw_probmap + from .optimizer import get_optimizer class ISTrainer(object): - def __init__(self, model, cfg, model_cfg, loss_cfg, - trainset, valset, - optimizer='adam', - optimizer_params=None, - image_dump_interval=200, -
checkpoint_interval=10, - tb_dump_period=25, - max_interactive_points=0, - lr_scheduler=None, - metrics=None, - additional_val_metrics=None, - net_inputs=('images', 'points'), - max_num_next_clicks=0, - click_models=None, - prev_mask_drop_prob=0.0, - ): + def __init__( + self, + model, + cfg, + model_cfg, + loss_cfg, + trainset, + valset, + optimizer="adam", + optimizer_params=None, + image_dump_interval=200, + checkpoint_interval=10, + tb_dump_period=25, + max_interactive_points=0, + lr_scheduler=None, + metrics=None, + additional_val_metrics=None, + net_inputs=("images", "points"), + max_num_next_clicks=0, + click_models=None, + prev_mask_drop_prob=0.0, + ): self.cfg = cfg self.model_cfg = model_cfg self.max_interactive_points = max_interactive_points @@ -60,35 +68,44 @@ class ISTrainer(object): self.checkpoint_interval = checkpoint_interval self.image_dump_interval = image_dump_interval - self.task_prefix = '' + self.task_prefix = "" self.sw = None self.trainset = trainset self.valset = valset - logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.') - logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.') + logger.info( + f"Dataset of {trainset.get_samples_number()} samples was loaded for training." + ) + logger.info( + f"Dataset of {valset.get_samples_number()} samples was loaded for validation." + ) self.train_data = DataLoader( - trainset, cfg.batch_size, + trainset, + cfg.batch_size, sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed), - drop_last=True, pin_memory=True, - num_workers=cfg.workers + drop_last=True, + pin_memory=True, + num_workers=cfg.workers, ) self.val_data = DataLoader( - valset, cfg.val_batch_size, + valset, + cfg.val_batch_size, sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed), - drop_last=True, pin_memory=True, - num_workers=cfg.workers + drop_last=True, + pin_memory=True, + num_workers=cfg.workers, ) self.optim = get_optimizer(model, optimizer, optimizer_params) model = self._load_weights(model) if cfg.multi_gpu: - model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids, - output_device=cfg.gpu_ids[0]) + model = get_dp_wrapper(cfg.distributed)( + model, device_ids=cfg.gpu_ids, output_device=cfg.gpu_ids[0] + ) if self.is_master: logger.info(model) @@ -96,7 +113,7 @@ class ISTrainer(object): self.device = cfg.device self.net = model.to(self.device) - self.lr = optimizer_params['lr'] + self.lr = optimizer_params["lr"] if lr_scheduler is not None: self.lr_scheduler = lr_scheduler(optimizer=self.optim) @@ -117,8 +134,8 @@ class ISTrainer(object): if start_epoch is None: start_epoch = self.cfg.start_epoch - logger.info(f'Starting Epoch: {start_epoch}') - logger.info(f'Total Epochs: {num_epochs}') + logger.info(f"Starting Epoch: {start_epoch}") + logger.info(f"Total Epochs: {num_epochs}") for epoch in range(start_epoch, num_epochs): self.training(epoch) if validation: @@ -126,15 +143,21 @@ class ISTrainer(object): def training(self, epoch): if self.sw is None and self.is_master: - self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), - flush_secs=10, dump_period=self.tb_dump_period) + self.sw = SummaryWriterAvg( + log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, + dump_period=self.tb_dump_period, + ) if self.cfg.distributed: self.train_data.sampler.set_epoch(epoch) - log_prefix = 'Train' + self.task_prefix.capitalize() - tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\ - if self.is_master else self.train_data + 
log_prefix = "Train" + self.task_prefix.capitalize() + tbar = ( + tqdm(self.train_data, file=self.tqdm_out, ncols=100) + if self.is_master + else self.train_data + ) for metric in self.train_metrics: metric.reset_epoch_stats() @@ -144,67 +167,109 @@ class ISTrainer(object): for i, batch_data in enumerate(tbar): global_step = epoch * len(self.train_data) + i - loss, losses_logging, splitted_batch_data, outputs = \ - self.batch_forward(batch_data) + loss, losses_logging, splitted_batch_data, outputs = self.batch_forward( + batch_data + ) self.optim.zero_grad() loss.backward() self.optim.step() - losses_logging['overall'] = loss + losses_logging["overall"] = loss reduce_loss_dict(losses_logging) - train_loss += losses_logging['overall'].item() + train_loss += losses_logging["overall"].item() if self.is_master: for loss_name, loss_value in losses_logging.items(): - self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', - value=loss_value.item(), - global_step=global_step) + self.sw.add_scalar( + tag=f"{log_prefix}Losses/{loss_name}", + value=loss_value.item(), + global_step=global_step, + ) for k, v in self.loss_cfg.items(): - if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0: - v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step) - - if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0: - self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train') - - self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate', - value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1], - global_step=global_step) - - tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}') + if ( + "_loss" in k + and hasattr(v, "log_states") + and self.loss_cfg.get(k + "_weight", 0.0) > 0 + ): + v.log_states(self.sw, f"{log_prefix}Losses/{k}", global_step) + + if ( + self.image_dump_interval > 0 + and global_step % self.image_dump_interval == 0 + ): + self.save_visualization( + splitted_batch_data, outputs, global_step, prefix="train" + ) + + self.sw.add_scalar( + tag=f"{log_prefix}States/learning_rate", + value=self.lr + if not hasattr(self, "lr_scheduler") + else self.lr_scheduler.get_lr()[-1], + global_step=global_step, + ) + + tbar.set_description( + f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}" + ) for metric in self.train_metrics: - metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + metric.log_states( + self.sw, f"{log_prefix}Metrics/{metric.name}", global_step + ) if self.is_master: for metric in self.train_metrics: - self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', - value=metric.get_epoch_value(), - global_step=epoch, disable_avg=True) - - save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix, - epoch=None, multi_gpu=self.cfg.multi_gpu) + self.sw.add_scalar( + tag=f"{log_prefix}Metrics/{metric.name}", + value=metric.get_epoch_value(), + global_step=epoch, + disable_avg=True, + ) + + save_checkpoint( + self.net, + self.cfg.CHECKPOINTS_PATH, + prefix=self.task_prefix, + epoch=None, + multi_gpu=self.cfg.multi_gpu, + ) if isinstance(self.checkpoint_interval, (list, tuple)): - checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1] + checkpoint_interval = [ + x for x in self.checkpoint_interval if x[0] <= epoch + ][-1][1] else: checkpoint_interval = self.checkpoint_interval if epoch % checkpoint_interval == 0: - save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, 
prefix=self.task_prefix, - epoch=epoch, multi_gpu=self.cfg.multi_gpu) - - if hasattr(self, 'lr_scheduler'): + save_checkpoint( + self.net, + self.cfg.CHECKPOINTS_PATH, + prefix=self.task_prefix, + epoch=epoch, + multi_gpu=self.cfg.multi_gpu, + ) + + if hasattr(self, "lr_scheduler"): self.lr_scheduler.step() def validation(self, epoch): if self.sw is None and self.is_master: - self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), - flush_secs=10, dump_period=self.tb_dump_period) - - log_prefix = 'Val' + self.task_prefix.capitalize() - tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data + self.sw = SummaryWriterAvg( + log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, + dump_period=self.tb_dump_period, + ) + + log_prefix = "Val" + self.task_prefix.capitalize() + tbar = ( + tqdm(self.val_data, file=self.tqdm_out, ncols=100) + if self.is_master + else self.val_data + ) for metric in self.val_metrics: metric.reset_epoch_stats() @@ -215,29 +280,45 @@ class ISTrainer(object): self.net.eval() for i, batch_data in enumerate(tbar): global_step = epoch * len(self.val_data) + i - loss, batch_losses_logging, splitted_batch_data, outputs = \ - self.batch_forward(batch_data, validation=True) - - batch_losses_logging['overall'] = loss + ( + loss, + batch_losses_logging, + splitted_batch_data, + outputs, + ) = self.batch_forward(batch_data, validation=True) + + batch_losses_logging["overall"] = loss reduce_loss_dict(batch_losses_logging) for loss_name, loss_value in batch_losses_logging.items(): losses_logging[loss_name].append(loss_value.item()) - val_loss += batch_losses_logging['overall'].item() + val_loss += batch_losses_logging["overall"].item() if self.is_master: - tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}') + tbar.set_description( + f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}" + ) for metric in self.val_metrics: - metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + metric.log_states( + self.sw, f"{log_prefix}Metrics/{metric.name}", global_step + ) if self.is_master: for loss_name, loss_values in losses_logging.items(): - self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(), - global_step=epoch, disable_avg=True) + self.sw.add_scalar( + tag=f"{log_prefix}Losses/{loss_name}", + value=np.array(loss_values).mean(), + global_step=epoch, + disable_avg=True, + ) for metric in self.val_metrics: - self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(), - global_step=epoch, disable_avg=True) + self.sw.add_scalar( + tag=f"{log_prefix}Metrics/{metric.name}", + value=metric.get_epoch_value(), + global_step=epoch, + disable_avg=True, + ) def batch_forward(self, batch_data, validation=False): metrics = self.val_metrics if validation else self.train_metrics @@ -245,8 +326,16 @@ class ISTrainer(object): with torch.set_grad_enabled(not validation): batch_data = {k: v.to(self.device) for k, v in batch_data.items()} - image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points'] - orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone() + image, gt_mask, points = ( + batch_data["images"], + batch_data["instances"], + batch_data["points"], + ) + orig_image, orig_gt_mask, orig_points = ( + image.clone(), + gt_mask.clone(), + points.clone(), + ) prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :] @@ -261,44 +350,79 @@ class ISTrainer(object): 
if not validation: self.net.eval() - if self.click_models is None or click_indx >= len(self.click_models): + if self.click_models is None or click_indx >= len( + self.click_models + ): eval_model = self.net else: eval_model = self.click_models[click_indx] - net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image - prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) + net_input = ( + torch.cat((image, prev_output), dim=1) + if self.net.with_prev_mask + else image + ) + prev_output = torch.sigmoid( + eval_model(net_input, points)["instances"] + ) - points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1) + points = get_next_points( + prev_output, orig_gt_mask, points, click_indx + 1 + ) if not validation: self.net.train() - if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None: - zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob + if ( + self.net.with_prev_mask + and self.prev_mask_drop_prob > 0 + and last_click_indx is not None + ): + zero_mask = ( + np.random.random(size=prev_output.size(0)) + < self.prev_mask_drop_prob + ) prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask]) - batch_data['points'] = points + batch_data["points"] = points - net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + net_input = ( + torch.cat((image, prev_output), dim=1) + if self.net.with_prev_mask + else image + ) output = self.net(net_input, points) loss = 0.0 - loss = self.add_loss('instance_loss', loss, losses_logging, validation, - lambda: (output['instances'], batch_data['instances'])) - loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation, - lambda: (output['instances_aux'], batch_data['instances'])) + loss = self.add_loss( + "instance_loss", + loss, + losses_logging, + validation, + lambda: (output["instances"], batch_data["instances"]), + ) + loss = self.add_loss( + "instance_aux_loss", + loss, + losses_logging, + validation, + lambda: (output["instances_aux"], batch_data["instances"]), + ) if self.is_master: with torch.no_grad(): for m in metrics: - m.update(*(output.get(x) for x in m.pred_outputs), - *(batch_data[x] for x in m.gt_outputs)) + m.update( + *(output.get(x) for x in m.pred_outputs), + *(batch_data[x] for x in m.gt_outputs), + ) return loss, losses_logging, batch_data, output - def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs): + def add_loss( + self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs + ): loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg - loss_weight = loss_cfg.get(loss_name + '_weight', 0.0) + loss_weight = loss_cfg.get(loss_name + "_weight", 0.0) if loss_weight > 0.0: loss_criterion = loss_cfg.get(loss_name) loss = loss_criterion(*lambda_loss_inputs()) @@ -316,18 +440,23 @@ class ISTrainer(object): if not output_images_path.exists(): output_images_path.mkdir(parents=True) - image_name_prefix = f'{global_step:06d}' + image_name_prefix = f"{global_step:06d}" def _save_image(suffix, image): - cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'), - image, [cv2.IMWRITE_JPEG_QUALITY, 85]) + cv2.imwrite( + str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"), + image, + [cv2.IMWRITE_JPEG_QUALITY, 85], + ) - images = splitted_batch_data['images'] - points = splitted_batch_data['points'] - instance_masks = splitted_batch_data['instances'] + images = 
splitted_batch_data["images"] + points = splitted_batch_data["points"] + instance_masks = splitted_batch_data["instances"] gt_instance_masks = instance_masks.cpu().numpy() - predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy() + predicted_instance_masks = ( + torch.sigmoid(outputs["instances"]).detach().cpu().numpy() + ) points = points.detach().cpu().numpy() image_blob, points = images[0], points[0] @@ -337,15 +466,21 @@ class ISTrainer(object): image = image_blob.cpu().numpy() * 255 image = image.transpose((1, 2, 0)) - image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0)) - image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255)) + image_with_points = draw_points( + image, points[: self.max_interactive_points], (0, 255, 0) + ) + image_with_points = draw_points( + image_with_points, points[self.max_interactive_points :], (0, 0, 255) + ) gt_mask[gt_mask < 0] = 0.25 gt_mask = draw_probmap(gt_mask) predicted_mask = draw_probmap(predicted_mask) - viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8) + viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype( + np.uint8 + ) - _save_image('instance_segmentation', viz_image[:, :, ::-1]) + _save_image("instance_segmentation", viz_image[:, :, ::-1]) def _load_weights(self, net): if self.cfg.weights is not None: @@ -355,11 +490,13 @@ class ISTrainer(object): else: raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'") elif self.cfg.resume_exp is not None: - checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth')) + checkpoints = list( + self.cfg.CHECKPOINTS_PATH.glob(f"{self.cfg.resume_prefix}*.pth") + ) assert len(checkpoints) == 1 checkpoint_path = checkpoints[0] - logger.info(f'Load checkpoint from path: {checkpoint_path}') + logger.info(f"Load checkpoint from path: {checkpoint_path}") load_weights(net, str(checkpoint_path)) return net @@ -376,8 +513,8 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): fn_mask = np.logical_and(gt, pred < pred_thresh) fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) - fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) - fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8) + fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8) num_points = points.size(1) // 2 points = points.clone() @@ -408,6 +545,6 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): def load_weights(model, path_to_weights): current_state_dict = model.state_dict() - new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict'] + new_state_dict = torch.load(path_to_weights, map_location="cpu")["state_dict"] current_state_dict.update(new_state_dict) model.load_state_dict(current_state_dict) diff --git a/isegm/inference/clicker.py b/isegm/inference/clicker.py index 8789e117b139cd8f99914892022176b774698b2a..5fc731c9edfab79ba57ee45b84e8e89be33d100d 100644 --- a/isegm/inference/clicker.py +++ b/isegm/inference/clicker.py @@ -1,10 +1,13 @@ -import numpy as np from copy import deepcopy + import cv2 +import numpy as np class Clicker(object): - def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0): + def __init__( + self, gt_mask=None, init_clicks=None, ignore_label=-1, 
click_indx_offset=0 + ): self.click_indx_offset = click_indx_offset if gt_mask is not None: self.gt_mask = gt_mask == 1 @@ -27,12 +30,18 @@ class Clicker(object): return self.clicks_list[:clicks_limit] def _get_next_click(self, pred_mask, padding=True): - fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) - fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) + fn_mask = np.logical_and( + np.logical_and(self.gt_mask, np.logical_not(pred_mask)), + self.not_ignore_mask, + ) + fp_mask = np.logical_and( + np.logical_and(np.logical_not(self.gt_mask), pred_mask), + self.not_ignore_mask, + ) if padding: - fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') - fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) diff --git a/isegm/inference/evaluation.py b/isegm/inference/evaluation.py index ef46e40849a9890151bda1aa1e9b2b55814dfce6..dcf1aba638e239fb7d5c7e4ff1fead94ecc7943c 100644 --- a/isegm/inference/evaluation.py +++ b/isegm/inference/evaluation.py @@ -20,8 +20,9 @@ def evaluate_dataset(dataset, predictor, **kwargs): for index in tqdm(range(len(dataset)), leave=False): sample = dataset.get_sample(index) - _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, predictor, - sample_id=index, **kwargs) + _, sample_ious, _ = evaluate_sample( + sample.image, sample.gt_mask, predictor, sample_id=index, **kwargs + ) all_ious.append(sample_ious) end_time = time() elapsed_time = end_time - start_time @@ -29,9 +30,17 @@ def evaluate_dataset(dataset, predictor, **kwargs): return all_ious, elapsed_time -def evaluate_sample(image, gt_mask, predictor, max_iou_thr, - pred_thr=0.49, min_clicks=1, max_clicks=20, - sample_id=None, callback=None): +def evaluate_sample( + image, + gt_mask, + predictor, + max_iou_thr, + pred_thr=0.49, + min_clicks=1, + max_clicks=20, + sample_id=None, + callback=None, +): clicker = Clicker(gt_mask=gt_mask) pred_mask = np.zeros_like(gt_mask) ious_list = [] @@ -45,7 +54,14 @@ def evaluate_sample(image, gt_mask, predictor, max_iou_thr, pred_mask = pred_probs > pred_thr if callback is not None: - callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + callback( + image, + gt_mask, + pred_probs, + sample_id, + click_indx, + clicker.clicks_list, + ) iou = utils.get_iou(gt_mask, pred_mask) ious_list.append(iou) diff --git a/isegm/inference/predictors/__init__.py b/isegm/inference/predictors/__init__.py index 1e5a4f7b58fa6234c898d42f43e91a42669308cd..f6b4be9aec13b05adde44f44a2189f71a46b2be9 100644 --- a/isegm/inference/predictors/__init__.py +++ b/isegm/inference/predictors/__init__.py @@ -1,27 +1,31 @@ -from .base import BasePredictor -from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor -from .brs_functors import InputOptimizer, ScaleBiasOptimizer from isegm.inference.transforms import ZoomIn from isegm.model.is_hrnet_model import HRNetModel +from .base import BasePredictor +from .brs import (FeatureBRSPredictor, HRNetFeatureBRSPredictor, + InputBRSPredictor) +from .brs_functors import InputOptimizer, ScaleBiasOptimizer + -def get_predictor(net, brs_mode, device, - prob_thresh=0.49, - with_flip=True, - zoom_in_params=dict(), - predictor_params=None, - 
brs_opt_func_params=None, - lbfgs_params=None): +def get_predictor( + net, + brs_mode, + device, + prob_thresh=0.49, + with_flip=True, + zoom_in_params=dict(), + predictor_params=None, + brs_opt_func_params=None, + lbfgs_params=None, +): lbfgs_params_ = { - 'm': 20, - 'factr': 0, - 'pgtol': 1e-8, - 'maxfun': 20, + "m": 20, + "factr": 0, + "pgtol": 1e-8, + "maxfun": 20, } - predictor_params_ = { - 'optimize_after_n_clicks': 1 - } + predictor_params_ = {"optimize_after_n_clicks": 1} if zoom_in_params is not None: zoom_in = ZoomIn(**zoom_in_params) @@ -30,68 +34,86 @@ def get_predictor(net, brs_mode, device, if lbfgs_params is not None: lbfgs_params_.update(lbfgs_params) - lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] + lbfgs_params_["maxiter"] = 2 * lbfgs_params_["maxfun"] if brs_opt_func_params is None: brs_opt_func_params = dict() if isinstance(net, (list, tuple)): - assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode." + assert brs_mode == "NoBRS", "Multi-stage models support only NoBRS mode." - if brs_mode == 'NoBRS': + if brs_mode == "NoBRS": if predictor_params is not None: predictor_params_.update(predictor_params) - predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) - elif brs_mode.startswith('f-BRS'): - predictor_params_.update({ - 'net_clicks_limit': 8, - }) + predictor = BasePredictor( + net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_ + ) + elif brs_mode.startswith("f-BRS"): + predictor_params_.update( + { + "net_clicks_limit": 8, + } + ) if predictor_params is not None: predictor_params_.update(predictor_params) insertion_mode = { - 'f-BRS-A': 'after_c4', - 'f-BRS-B': 'after_aspp', - 'f-BRS-C': 'after_deeplab' + "f-BRS-A": "after_c4", + "f-BRS-B": "after_aspp", + "f-BRS-C": "after_deeplab", }[brs_mode] - opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, - with_flip=with_flip, - optimizer_params=lbfgs_params_, - **brs_opt_func_params) + opt_functor = ScaleBiasOptimizer( + prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params + ) if isinstance(net, HRNetModel): FeaturePredictor = HRNetFeatureBRSPredictor - insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] + insertion_mode = {"after_c4": "A", "after_aspp": "A", "after_deeplab": "C"}[ + insertion_mode + ] else: FeaturePredictor = FeatureBRSPredictor - predictor = FeaturePredictor(net, device, - opt_functor=opt_functor, - with_flip=with_flip, - insertion_mode=insertion_mode, - zoom_in=zoom_in, - **predictor_params_) - elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': - use_dmaps = brs_mode == 'DistMap-BRS' - - predictor_params_.update({ - 'net_clicks_limit': 5, - }) + predictor = FeaturePredictor( + net, + device, + opt_functor=opt_functor, + with_flip=with_flip, + insertion_mode=insertion_mode, + zoom_in=zoom_in, + **predictor_params_ + ) + elif brs_mode == "RGB-BRS" or brs_mode == "DistMap-BRS": + use_dmaps = brs_mode == "DistMap-BRS" + + predictor_params_.update( + { + "net_clicks_limit": 5, + } + ) if predictor_params is not None: predictor_params_.update(predictor_params) - opt_functor = InputOptimizer(prob_thresh=prob_thresh, - with_flip=with_flip, - optimizer_params=lbfgs_params_, - **brs_opt_func_params) - - predictor = InputBRSPredictor(net, device, - optimize_target='dmaps' if use_dmaps else 'rgb', - opt_functor=opt_functor, - with_flip=with_flip, - zoom_in=zoom_in, - **predictor_params_) + opt_functor = InputOptimizer( 
+ prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params + ) + + predictor = InputBRSPredictor( + net, + device, + optimize_target="dmaps" if use_dmaps else "rgb", + opt_functor=opt_functor, + with_flip=with_flip, + zoom_in=zoom_in, + **predictor_params_ + ) else: raise NotImplementedError diff --git a/isegm/inference/predictors/base.py b/isegm/inference/predictors/base.py index 870311726adfc0d5a6600f590f834a973dbefce0..55de2319cf5268e6c9edc636fd143b4057c1a2f4 100644 --- a/isegm/inference/predictors/base.py +++ b/isegm/inference/predictors/base.py @@ -1,16 +1,22 @@ import torch import torch.nn.functional as F from torchvision import transforms -from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide + +from isegm.inference.transforms import (AddHorizontalFlip, LimitLongestSide, + SigmoidForPred) class BasePredictor(object): - def __init__(self, model, device, - net_clicks_limit=None, - with_flip=False, - zoom_in=None, - max_size=None, - **kwargs): + def __init__( + self, + model, + device, + net_clicks_limit=None, + with_flip=False, + zoom_in=None, + max_size=None, + **kwargs + ): self.with_flip = with_flip self.net_clicks_limit = net_clicks_limit self.original_image = None @@ -48,7 +54,12 @@ class BasePredictor(object): clicks_list = clicker.get_clicks() if self.click_models is not None: - model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1 + model_indx = ( + min( + clicker.click_indx_offset + len(clicks_list), len(self.click_models) + ) + - 1 + ) if model_indx != self.model_indx: self.model_indx = model_indx self.net = self.click_models[model_indx] @@ -56,15 +67,16 @@ class BasePredictor(object): input_image = self.original_image if prev_mask is None: prev_mask = self.prev_prediction - if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask: + if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask: input_image = torch.cat((input_image, prev_mask), dim=1) image_nd, clicks_lists, is_image_changed = self.apply_transforms( input_image, [clicks_list] ) pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) - prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, - size=image_nd.size()[2:]) + prediction = F.interpolate( + pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:] + ) for t in reversed(self.transforms): prediction = t.inv_transform(prediction) @@ -77,7 +89,7 @@ class BasePredictor(object): def _get_prediction(self, image_nd, clicks_lists, is_image_changed): points_nd = self.get_points_nd(clicks_lists) - return self.net(image_nd, points_nd)['instances'] + return self.net(image_nd, points_nd)["instances"] def _get_transform_states(self): return [x.get_state() for x in self.transforms] @@ -97,30 +109,43 @@ class BasePredictor(object): def get_points_nd(self, clicks_lists): total_clicks = [] - num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] - num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_pos_clicks = [ + sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists + ] + num_neg_clicks = [ + len(clicks_list) - num_pos + for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks) + ] num_max_points = max(num_pos_clicks + num_neg_clicks) if self.net_clicks_limit is not None: num_max_points = min(self.net_clicks_limit, num_max_points) num_max_points = max(1, 
num_max_points) for clicks_list in clicks_lists: - clicks_list = clicks_list[:self.net_clicks_limit] - pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] - pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] - - neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] - neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + clicks_list = clicks_list[: self.net_clicks_limit] + pos_clicks = [ + click.coords_and_indx for click in clicks_list if click.is_positive + ] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [ + (-1, -1, -1) + ] + + neg_clicks = [ + click.coords_and_indx for click in clicks_list if not click.is_positive + ] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [ + (-1, -1, -1) + ] total_clicks.append(pos_clicks + neg_clicks) return torch.tensor(total_clicks, device=self.device) def get_states(self): return { - 'transform_states': self._get_transform_states(), - 'prev_prediction': self.prev_prediction.clone() + "transform_states": self._get_transform_states(), + "prev_prediction": self.prev_prediction.clone(), } def set_states(self, states): - self._set_transform_states(states['transform_states']) - self.prev_prediction = states['prev_prediction'] + self._set_transform_states(states["transform_states"]) + self.prev_prediction = states["prev_prediction"] diff --git a/isegm/inference/predictors/brs.py b/isegm/inference/predictors/brs.py index 910e3fd52471c39fe56668575765adcc00393d3d..a10746877b22527e50ac7a352f29492f8d8b4af5 100644 --- a/isegm/inference/predictors/brs.py +++ b/isegm/inference/predictors/brs.py @@ -1,6 +1,6 @@ +import numpy as np import torch import torch.nn.functional as F -import numpy as np from scipy.optimize import fmin_l_bfgs_b from .base import BasePredictor @@ -21,8 +21,12 @@ class BRSBasePredictor(BasePredictor): self.input_data = None def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1): - pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) - neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + pos_clicks_map = np.zeros( + (len(clicks_lists), 1) + image_shape, dtype=np.float32 + ) + neg_clicks_map = np.zeros( + (len(clicks_lists), 1) + image_shape, dtype=np.float32 + ) for list_indx, clicks_list in enumerate(clicks_lists): for click in clicks_list: @@ -43,24 +47,29 @@ class BRSBasePredictor(BasePredictor): return pos_clicks_map, neg_clicks_map def get_states(self): - return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data} + return { + "transform_states": self._get_transform_states(), + "opt_data": self.opt_data, + } def set_states(self, states): - self._set_transform_states(states['transform_states']) - self.opt_data = states['opt_data'] + self._set_transform_states(states["transform_states"]) + self.opt_data = states["opt_data"] class FeatureBRSPredictor(BRSBasePredictor): - def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs): + def __init__( + self, model, device, opt_functor, insertion_mode="after_deeplab", **kwargs + ): super().__init__(model, device, opt_functor=opt_functor, **kwargs) self.insertion_mode = insertion_mode self._c1_features = None - if self.insertion_mode == 'after_deeplab': + if self.insertion_mode == "after_deeplab": self.num_channels = model.feature_extractor.ch - elif self.insertion_mode == 'after_c4': + elif self.insertion_mode == "after_c4": 
self.num_channels = model.feature_extractor.aspp_in_channels - elif self.insertion_mode == 'after_aspp': + elif self.insertion_mode == "after_aspp": self.num_channels = model.feature_extractor.ch + 32 else: raise NotImplementedError @@ -72,10 +81,17 @@ class FeatureBRSPredictor(BRSBasePredictor): num_clicks = len(clicks_lists[0]) bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] - if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + if ( + self.opt_data is None + or self.opt_data.shape[0] // (2 * self.num_channels) != bs + ): self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) - if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + if ( + num_clicks <= self.net_clicks_limit + or is_image_changed + or self.input_data is None + ): self.input_data = self._get_head_input(image_nd, points_nd) def get_prediction_logits(scale, bias): @@ -87,24 +103,39 @@ class FeatureBRSPredictor(BRSBasePredictor): scaled_backbone_features = self.input_data * scale scaled_backbone_features = scaled_backbone_features + bias - if self.insertion_mode == 'after_c4': + if self.insertion_mode == "after_c4": x = self.net.feature_extractor.aspp(scaled_backbone_features) - x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:], - align_corners=True) + x = F.interpolate( + x, + mode="bilinear", + size=self._c1_features.size()[2:], + align_corners=True, + ) x = torch.cat((x, self._c1_features), dim=1) scaled_backbone_features = self.net.feature_extractor.head(x) - elif self.insertion_mode == 'after_aspp': - scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features) + elif self.insertion_mode == "after_aspp": + scaled_backbone_features = self.net.feature_extractor.head( + scaled_backbone_features + ) pred_logits = self.net.head(scaled_backbone_features) - pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', - align_corners=True) + pred_logits = F.interpolate( + pred_logits, + size=image_nd.size()[2:], + mode="bilinear", + align_corners=True, + ) return pred_logits - self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + self.opt_functor.init_click( + get_prediction_logits, pos_mask, neg_mask, self.device + ) if num_clicks > self.optimize_after_n_clicks: - opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, - **self.opt_functor.optimizer_params) + opt_result = fmin_l_bfgs_b( + func=self.opt_functor, + x0=self.opt_data, + **self.opt_functor.optimizer_params + ) self.opt_data = opt_result[0] with torch.no_grad(): @@ -125,37 +156,45 @@ class FeatureBRSPredictor(BRSBasePredictor): if self.net.rgb_conv is not None: x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) additional_features = None - elif hasattr(self.net, 'maps_transform'): + elif hasattr(self.net, "maps_transform"): x = image_nd additional_features = self.net.maps_transform(coord_features) - if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp': - c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features) + if self.insertion_mode == "after_c4" or self.insertion_mode == "after_aspp": + c1, _, c3, c4 = self.net.feature_extractor.backbone( + x, additional_features + ) c1 = self.net.feature_extractor.skip_project(c1) - if self.insertion_mode == 'after_aspp': + if self.insertion_mode == "after_aspp": x = self.net.feature_extractor.aspp(c4) - x = F.interpolate(x, size=c1.size()[2:], 
mode='bilinear', align_corners=True) + x = F.interpolate( + x, size=c1.size()[2:], mode="bilinear", align_corners=True + ) x = torch.cat((x, c1), dim=1) backbone_features = x else: backbone_features = c4 self._c1_features = c1 else: - backbone_features = self.net.feature_extractor(x, additional_features)[0] + backbone_features = self.net.feature_extractor(x, additional_features)[ + 0 + ] return backbone_features class HRNetFeatureBRSPredictor(BRSBasePredictor): - def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs): + def __init__(self, model, device, opt_functor, insertion_mode="A", **kwargs): super().__init__(model, device, opt_functor=opt_functor, **kwargs) self.insertion_mode = insertion_mode self._c1_features = None - if self.insertion_mode == 'A': - self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8]) - elif self.insertion_mode == 'C': + if self.insertion_mode == "A": + self.num_channels = sum( + k * model.feature_extractor.width for k in [1, 2, 4, 8] + ) + elif self.insertion_mode == "C": self.num_channels = 2 * model.feature_extractor.ocr_width else: raise NotImplementedError @@ -166,10 +205,17 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor): num_clicks = len(clicks_lists[0]) bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] - if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + if ( + self.opt_data is None + or self.opt_data.shape[0] // (2 * self.num_channels) != bs + ): self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) - if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + if ( + num_clicks <= self.net_clicks_limit + or is_image_changed + or self.input_data is None + ): self.input_data = self._get_head_input(image_nd, points_nd) def get_prediction_logits(scale, bias): @@ -181,29 +227,44 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor): scaled_backbone_features = self.input_data * scale scaled_backbone_features = scaled_backbone_features + bias - if self.insertion_mode == 'A': + if self.insertion_mode == "A": if self.net.feature_extractor.ocr_width > 0: - out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features) - feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features) + out_aux = self.net.feature_extractor.aux_head( + scaled_backbone_features + ) + feats = self.net.feature_extractor.conv3x3_ocr( + scaled_backbone_features + ) context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) feats = self.net.feature_extractor.ocr_distri_head(feats, context) else: feats = scaled_backbone_features pred_logits = self.net.feature_extractor.cls_head(feats) - elif self.insertion_mode == 'C': - pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features) + elif self.insertion_mode == "C": + pred_logits = self.net.feature_extractor.cls_head( + scaled_backbone_features + ) else: raise NotImplementedError - pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', - align_corners=True) + pred_logits = F.interpolate( + pred_logits, + size=image_nd.size()[2:], + mode="bilinear", + align_corners=True, + ) return pred_logits - self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + self.opt_functor.init_click( + get_prediction_logits, pos_mask, neg_mask, self.device + ) if num_clicks > self.optimize_after_n_clicks: - opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, - **self.opt_functor.optimizer_params) + 
opt_result = fmin_l_bfgs_b( + func=self.opt_functor, + x0=self.opt_data, + **self.opt_functor.optimizer_params + ) self.opt_data = opt_result[0] with torch.no_grad(): @@ -224,20 +285,24 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor): if self.net.rgb_conv is not None: x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) additional_features = None - elif hasattr(self.net, 'maps_transform'): + elif hasattr(self.net, "maps_transform"): x = image_nd additional_features = self.net.maps_transform(coord_features) - feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features) + feats = self.net.feature_extractor.compute_hrnet_feats( + x, additional_features + ) - if self.insertion_mode == 'A': + if self.insertion_mode == "A": backbone_features = feats - elif self.insertion_mode == 'C': + elif self.insertion_mode == "C": out_aux = self.net.feature_extractor.aux_head(feats) feats = self.net.feature_extractor.conv3x3_ocr(feats) context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) - backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context) + backbone_features = self.net.feature_extractor.ocr_distri_head( + feats, context + ) else: raise NotImplementedError @@ -245,7 +310,7 @@ class HRNetFeatureBRSPredictor(BRSBasePredictor): class InputBRSPredictor(BRSBasePredictor): - def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs): + def __init__(self, model, device, opt_functor, optimize_target="rgb", **kwargs): super().__init__(model, device, opt_functor=opt_functor, **kwargs) self.optimize_target = optimize_target @@ -255,21 +320,28 @@ class InputBRSPredictor(BRSBasePredictor): num_clicks = len(clicks_lists[0]) if self.opt_data is None or is_image_changed: - if self.optimize_target == 'dmaps': - opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch + if self.optimize_target == "dmaps": + opt_channels = ( + self.net.coord_feature_ch - 1 + if self.net.with_prev_mask + else self.net.coord_feature_ch + ) else: opt_channels = 3 bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] - self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), - device=self.device, dtype=torch.float32) + self.opt_data = torch.zeros( + (bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), + device=self.device, + dtype=torch.float32, + ) def get_prediction_logits(opt_bias): input_image, prev_mask = self.net.prepare_input(image_nd) dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd) - if self.optimize_target == 'rgb': + if self.optimize_target == "rgb": input_image = input_image + opt_bias - elif self.optimize_target == 'dmaps': + elif self.optimize_target == "dmaps": if self.net.with_prev_mask: dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias else: @@ -277,25 +349,44 @@ class InputBRSPredictor(BRSBasePredictor): if self.net.rgb_conv is not None: x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1)) - if self.optimize_target == 'all': + if self.optimize_target == "all": x = x + opt_bias coord_features = None - elif hasattr(self.net, 'maps_transform'): + elif hasattr(self.net, "maps_transform"): x = input_image coord_features = self.net.maps_transform(dmaps) - pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances'] - pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) + pred_logits = self.net.backbone_forward(x, 
coord_features=coord_features)[ + "instances" + ] + pred_logits = F.interpolate( + pred_logits, + size=image_nd.size()[2:], + mode="bilinear", + align_corners=True, + ) return pred_logits - self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device, - shape=self.opt_data.shape) + self.opt_functor.init_click( + get_prediction_logits, + pos_mask, + neg_mask, + self.device, + shape=self.opt_data.shape, + ) if num_clicks > self.optimize_after_n_clicks: - opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(), - **self.opt_functor.optimizer_params) - - self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device) + opt_result = fmin_l_bfgs_b( + func=self.opt_functor, + x0=self.opt_data.cpu().numpy().ravel(), + **self.opt_functor.optimizer_params + ) + + self.opt_data = ( + torch.from_numpy(opt_result[0]) + .view(self.opt_data.shape) + .to(self.device) + ) with torch.no_grad(): if self.opt_functor.best_prediction is not None: diff --git a/isegm/inference/predictors/brs_functors.py b/isegm/inference/predictors/brs_functors.py index f919e13c6c9edb6a9eb7c4afc37933db7b303c12..024c12f9419d5a33a7845700a2a90d4aab5ec429 100644 --- a/isegm/inference/predictors/brs_functors.py +++ b/isegm/inference/predictors/brs_functors.py @@ -1,19 +1,23 @@ -import torch import numpy as np +import torch from isegm.model.metrics import _compute_iou + from .brs_losses import BRSMaskLoss class BaseOptimizer: - def __init__(self, optimizer_params, - prob_thresh=0.49, - reg_weight=1e-3, - min_iou_diff=0.01, - brs_loss=BRSMaskLoss(), - with_flip=False, - flip_average=False, - **kwargs): + def __init__( + self, + optimizer_params, + prob_thresh=0.49, + reg_weight=1e-3, + min_iou_diff=0.01, + brs_loss=BRSMaskLoss(), + with_flip=False, + flip_average=False, + **kwargs + ): self.brs_loss = brs_loss self.optimizer_params = optimizer_params self.prob_thresh = prob_thresh @@ -51,7 +55,10 @@ class BaseOptimizer: if self.with_flip and self.flip_average: result, result_flipped = torch.chunk(result, 2, dim=0) result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) - pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] + pos_mask, neg_mask = ( + pos_mask[: result.shape[0]], + neg_mask[: result.shape[0]], + ) loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) loss = loss + reg_loss @@ -99,11 +106,13 @@ class ScaleBiasOptimizer(BaseOptimizer): def unpack_opt_params(self, opt_params): scale, bias = torch.chunk(opt_params, 2, dim=0) - reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) + reg_loss = self.reg_weight * ( + torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2) + ) - if self.scale_act == 'tanh': + if self.scale_act == "tanh": scale = torch.tanh(scale) - elif self.scale_act == 'sin': + elif self.scale_act == "sin": scale = torch.sin(scale) return (1 + scale, bias), reg_loss diff --git a/isegm/inference/predictors/brs_losses.py b/isegm/inference/predictors/brs_losses.py index ea98824356cf5a4d09094fb92c13ee8d8dfe15dc..d885027e48cb8a5d6bed3cfdd31838e6cf2a2668 100644 --- a/isegm/inference/predictors/brs_losses.py +++ b/isegm/inference/predictors/brs_losses.py @@ -10,13 +10,13 @@ class BRSMaskLoss(torch.nn.Module): def forward(self, result, pos_mask, neg_mask): pos_diff = (1 - result) * pos_mask - pos_target = torch.sum(pos_diff ** 2) + pos_target = torch.sum(pos_diff**2) pos_target = pos_target / (torch.sum(pos_mask) + self._eps) neg_diff = result * 
neg_mask - neg_target = torch.sum(neg_diff ** 2) + neg_target = torch.sum(neg_diff**2) neg_target = neg_target / (torch.sum(neg_mask) + self._eps) - + loss = pos_target + neg_target with torch.no_grad(): @@ -42,8 +42,10 @@ class OracleMaskLoss(torch.nn.Module): gt_mask = self.gt_mask.to(result.device) if self.predictor.object_roi is not None: r1, r2, c1, c2 = self.predictor.object_roi[:4] - gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] - gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) + gt_mask = gt_mask[:, :, r1 : r2 + 1, c1 : c2 + 1] + gt_mask = torch.nn.functional.interpolate( + gt_mask, result.size()[2:], mode="bilinear", align_corners=True + ) if result.shape[0] == 2: gt_mask_flipped = torch.flip(gt_mask, dims=[3]) diff --git a/isegm/inference/transforms/__init__.py b/isegm/inference/transforms/__init__.py index cbd54e38a2f84b3fef481672a7ceab070eb01b82..1c140a4ec3a48b14c36b8be672fbc787105d1171 100644 --- a/isegm/inference/transforms/__init__.py +++ b/isegm/inference/transforms/__init__.py @@ -1,5 +1,5 @@ from .base import SigmoidForPred +from .crops import Crops from .flip import AddHorizontalFlip -from .zoom_in import ZoomIn from .limit_longest_side import LimitLongestSide -from .crops import Crops +from .zoom_in import ZoomIn diff --git a/isegm/inference/transforms/crops.py b/isegm/inference/transforms/crops.py index 428d977295e2ff973b5aa1bf0a0c955df1235614..20f7fc887aaf191ebec225c5bea6d5a175be6c75 100644 --- a/isegm/inference/transforms/crops.py +++ b/isegm/inference/transforms/crops.py @@ -1,10 +1,11 @@ import math +from typing import List -import torch import numpy as np -from typing import List +import torch from isegm.inference.clicker import Click + from .base import BaseTransform @@ -33,17 +34,24 @@ class Crops(BaseTransform): image_crops = [] for dy in self.y_offsets: for dx in self.x_offsets: - self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 - image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] + self._counts[dy : dy + self.crop_height, dx : dx + self.crop_width] += 1 + image_crop = image_nd[ + :, :, dy : dy + self.crop_height, dx : dx + self.crop_width + ] image_crops.append(image_crop) image_crops = torch.cat(image_crops, dim=0) - self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) + self._counts = torch.tensor( + self._counts, device=image_nd.device, dtype=torch.float32 + ) clicks_list = clicks_lists[0] clicks_lists = [] for dy in self.y_offsets: for dx in self.x_offsets: - crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] + crop_clicks = [ + x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) + for x in clicks_list + ] clicks_lists.append(crop_clicks) return image_crops, clicks_lists @@ -52,13 +60,16 @@ class Crops(BaseTransform): if self._counts is None: return prob_map - new_prob_map = torch.zeros((1, 1, *self._counts.shape), - dtype=prob_map.dtype, device=prob_map.device) + new_prob_map = torch.zeros( + (1, 1, *self._counts.shape), dtype=prob_map.dtype, device=prob_map.device + ) crop_indx = 0 for dy in self.y_offsets: for dx in self.x_offsets: - new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] + new_prob_map[ + 0, 0, dy : dy + self.crop_height, dx : dx + self.crop_width + ] += prob_map[crop_indx, 0] crop_indx += 1 new_prob_map = torch.div(new_prob_map, self._counts) diff --git a/isegm/inference/transforms/flip.py 
b/isegm/inference/transforms/flip.py index 373640ebe153ae8a53c136c72f13e0c14aa788ec..485cf41aed48b232d4c46c7d8c44234d35fea6c9 100644 --- a/isegm/inference/transforms/flip.py +++ b/isegm/inference/transforms/flip.py @@ -1,7 +1,9 @@ +from typing import List + import torch -from typing import List from isegm.inference.clicker import Click + from .base import BaseTransform @@ -13,8 +15,10 @@ class AddHorizontalFlip(BaseTransform): image_width = image_nd.shape[3] clicks_lists_flipped = [] for clicks_list in clicks_lists: - clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1)) - for click in clicks_list] + clicks_list_flipped = [ + click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1)) + for click in clicks_list + ] clicks_lists_flipped.append(clicks_list_flipped) clicks_lists = clicks_lists + clicks_lists_flipped diff --git a/isegm/inference/transforms/zoom_in.py b/isegm/inference/transforms/zoom_in.py index 04b576a3e351aa7ad723fd447b309615648bc55d..fd86f294f0f3846dc8dc658224560554d34ece0a 100644 --- a/isegm/inference/transforms/zoom_in.py +++ b/isegm/inference/transforms/zoom_in.py @@ -1,19 +1,24 @@ +from typing import List + import torch -from typing import List from isegm.inference.clicker import Click -from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox +from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask, + get_bbox_iou) + from .base import BaseTransform class ZoomIn(BaseTransform): - def __init__(self, - target_size=400, - skip_clicks=1, - expansion_ratio=1.4, - min_crop_size=200, - recompute_thresh_iou=0.5, - prob_thresh=0.50): + def __init__( + self, + target_size=400, + skip_clicks=1, + expansion_ratio=1.4, + min_crop_size=200, + recompute_thresh_iou=0.5, + prob_thresh=0.50, + ): super().__init__() self.target_size = target_size self.min_crop_size = min_crop_size @@ -41,8 +46,12 @@ class ZoomIn(BaseTransform): if self._prev_probs is not None: current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] if current_pred_mask.sum() > 0: - current_object_roi = get_object_roi(current_pred_mask, clicks_list, - self.expansion_ratio, self.min_crop_size) + current_object_roi = get_object_roi( + current_pred_mask, + clicks_list, + self.expansion_ratio, + self.min_crop_size, + ) if current_object_roi is None: if self.skip_clicks >= 0: @@ -55,7 +64,10 @@ class ZoomIn(BaseTransform): update_object_roi = True elif not check_object_roi(self._object_roi, clicks_list): update_object_roi = True - elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: + elif ( + get_bbox_iou(current_object_roi, self._object_roi) + < self.recompute_thresh_iou + ): update_object_roi = True if update_object_roi: @@ -73,12 +85,18 @@ class ZoomIn(BaseTransform): assert prob_map.shape[0] == 1 rmin, rmax, cmin, cmax = self._object_roi - prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), - mode='bilinear', align_corners=True) + prob_map = torch.nn.functional.interpolate( + prob_map, + size=(rmax - rmin + 1, cmax - cmin + 1), + mode="bilinear", + align_corners=True, + ) if self._prev_probs is not None: - new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) - new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map + new_prob_map = torch.zeros( + *self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype + ) + new_prob_map[:, :, rmin : rmax + 1, cmin : cmax + 1] = prob_map else: 
new_prob_map = prob_map @@ -87,24 +105,46 @@ class ZoomIn(BaseTransform): return new_prob_map def check_possible_recalculation(self): - if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: + if ( + self._prev_probs is None + or self._object_roi is not None + or self.skip_clicks > 0 + ): return False pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] if pred_mask.sum() > 0: - possible_object_roi = get_object_roi(pred_mask, [], - self.expansion_ratio, self.min_crop_size) - image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) + possible_object_roi = get_object_roi( + pred_mask, [], self.expansion_ratio, self.min_crop_size + ) + image_roi = ( + 0, + self._input_image_shape[2] - 1, + 0, + self._input_image_shape[3] - 1, + ) if get_bbox_iou(possible_object_roi, image_roi) < 0.50: return True return False def get_state(self): roi_image = self._roi_image.cpu() if self._roi_image is not None else None - return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed + return ( + self._input_image_shape, + self._object_roi, + self._prev_probs, + roi_image, + self.image_changed, + ) def set_state(self, state): - self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state + ( + self._input_image_shape, + self._object_roi, + self._prev_probs, + self._roi_image, + self.image_changed, + ) = state def reset(self): self._input_image_shape = None @@ -157,9 +197,13 @@ def get_roi_image_nd(image_nd, object_roi, target_size): new_width = int(round(width * scale)) with torch.no_grad(): - roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] - roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), - mode='bilinear', align_corners=True) + roi_image_nd = image_nd[:, :, rmin : rmax + 1, cmin : cmax + 1] + roi_image_nd = torch.nn.functional.interpolate( + roi_image_nd, + size=(new_height, new_width), + mode="bilinear", + align_corners=True, + ) return roi_image_nd diff --git a/isegm/inference/utils.py b/isegm/inference/utils.py index 7102d4075c9812de6a26b21a8a8946c44c3ddb3f..62826e6d6d85f42f9c54ab20636ef37100d732c8 100644 --- a/isegm/inference/utils.py +++ b/isegm/inference/utils.py @@ -1,10 +1,11 @@ from datetime import timedelta from pathlib import Path -import torch import numpy as np +import torch -from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset +from isegm.data.datasets import (BerkeleyDataset, DavisDataset, GrabCutDataset, + PascalVocDataset, SBDEvaluationDataset) from isegm.utils.serialization import load_model @@ -20,7 +21,7 @@ def get_time_metrics(all_ious, elapsed_time): def load_is_model(checkpoint, device, **kwargs): if isinstance(checkpoint, (str, Path)): - state_dict = torch.load(checkpoint, map_location='cpu') + state_dict = torch.load(checkpoint, map_location="cpu") else: state_dict = checkpoint @@ -34,8 +35,8 @@ def load_is_model(checkpoint, device, **kwargs): def load_single_is_model(state_dict, device, **kwargs): - model = load_model(state_dict['config'], **kwargs) - model.load_state_dict(state_dict['state_dict'], strict=False) + model = load_model(state_dict["config"], **kwargs) + model.load_state_dict(state_dict["state_dict"], strict=False) for param in model.parameters(): param.requires_grad = False @@ -46,19 +47,19 @@ def load_single_is_model(state_dict, device, **kwargs): def get_dataset(dataset_name, cfg): - if dataset_name == 'GrabCut': 
+ if dataset_name == "GrabCut": dataset = GrabCutDataset(cfg.GRABCUT_PATH) - elif dataset_name == 'Berkeley': + elif dataset_name == "Berkeley": dataset = BerkeleyDataset(cfg.BERKELEY_PATH) - elif dataset_name == 'DAVIS': + elif dataset_name == "DAVIS": dataset = DavisDataset(cfg.DAVIS_PATH) - elif dataset_name == 'SBD': + elif dataset_name == "SBD": dataset = SBDEvaluationDataset(cfg.SBD_PATH) - elif dataset_name == 'SBD_Train': - dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train') - elif dataset_name == 'PascalVOC': - dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split='test') - elif dataset_name == 'COCO_MVal': + elif dataset_name == "SBD_Train": + dataset = SBDEvaluationDataset(cfg.SBD_PATH, split="train") + elif dataset_name == "PascalVOC": + dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split="test") + elif dataset_name == "COCO_MVal": dataset = DavisDataset(cfg.COCO_MVAL_PATH) else: dataset = None @@ -70,8 +71,12 @@ def get_iou(gt_mask, pred_mask, ignore_label=-1): ignore_gt_mask_inv = gt_mask != ignore_label obj_gt_mask = gt_mask == 1 - intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() - union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + intersection = np.logical_and( + np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv + ).sum() + union = np.logical_and( + np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv + ).sum() return intersection / union @@ -84,8 +89,9 @@ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): noc_list = [] over_max_list = [] for iou_thr in iou_thrs: - scores_arr = np.array([_get_noc(iou_arr, iou_thr) - for iou_arr in all_ious], dtype=np.int) + scores_arr = np.array( + [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=int + ) score = scores_arr.mean() over_max = (scores_arr == max_clicks).sum() @@ -98,46 +104,58 @@ def find_checkpoint(weights_folder, checkpoint_name): weights_folder = Path(weights_folder) - if ':' in checkpoint_name: - model_name, checkpoint_name = checkpoint_name.split(':') - models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] + if ":" in checkpoint_name: + model_name, checkpoint_name = checkpoint_name.split(":") + models_candidates = [ + x for x in weights_folder.glob(f"{model_name}*") if x.is_dir() + ] assert len(models_candidates) == 1 model_folder = models_candidates[0] else: model_folder = weights_folder - if checkpoint_name.endswith('.pth'): + if checkpoint_name.endswith(".pth"): if Path(checkpoint_name).exists(): checkpoint_path = checkpoint_name else: checkpoint_path = weights_folder / checkpoint_name else: - model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) + model_checkpoints = list(model_folder.rglob(f"{checkpoint_name}*.pth")) assert len(model_checkpoints) == 1 checkpoint_path = model_checkpoints[0] return str(checkpoint_path) -def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, - n_clicks=20, model_name=None): - table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' - f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' - f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' - f'{"SPC,s":^7}|{"Time":^9}|') +def get_results_table( + noc_list, + over_max_list, + brs_type, + dataset_name, + mean_spc, + elapsed_time, + n_clicks=20, + model_name=None, +): + table_header = ( + f'|{"BRS Type":^13}|{"Dataset":^11}|' +
f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' + f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' + f'{"SPC,s":^7}|{"Time":^9}|' + ) row_width = len(table_header) - header = f'Eval results for model: {model_name}\n' if model_name is not None else '' - header += '-' * row_width + '\n' - header += table_header + '\n' + '-' * row_width + header = f"Eval results for model: {model_name}\n" if model_name is not None else "" + header += "-" * row_width + "\n" + header += table_header + "\n" + "-" * row_width eval_time = str(timedelta(seconds=int(elapsed_time))) - table_row = f'|{brs_type:^13}|{dataset_name:^11}|' - table_row += f'{noc_list[0]:^9.2f}|' - table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' - table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' - table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' - table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' - table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' - - return header, table_row \ No newline at end of file + table_row = f"|{brs_type:^13}|{dataset_name:^11}|" + table_row += f"{noc_list[0]:^9.2f}|" + table_row += f"{noc_list[1]:^9.2f}|" if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f"{noc_list[2]:^9.2f}|" if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f"{over_max_list[1]:^9}|" if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f"{over_max_list[2]:^9}|" if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|" + + return header, table_row diff --git a/isegm/model/initializer.py b/isegm/model/initializer.py index 470c7df4659bc1e80ceec80a170b3b2e0302fb84..89d0e8856d6a434f08842aa73fe44dc3b54de597 100644 --- a/isegm/model/initializer.py +++ b/isegm/model/initializer.py @@ -1,6 +1,6 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np class Initializer(object): @@ -9,24 +9,37 @@ class Initializer(object): self.gamma = gamma def __call__(self, m): - if getattr(m, '__initialized', False): + if getattr(m, "__initialized", False): return - if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, - nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, - nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: + if ( + isinstance( + m, + ( + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.GroupNorm, + nn.SyncBatchNorm, + ), + ) + or "BatchNorm" in m.__class__.__name__ + ): if m.weight is not None: self._init_gamma(m.weight.data) if m.bias is not None: self._init_beta(m.bias.data) else: - if getattr(m, 'weight', None) is not None: + if getattr(m, "weight", None) is not None: self._init_weight(m.weight.data) - if getattr(m, 'bias', None) is not None: + if getattr(m, "bias", None) is not None: self._init_bias(m.bias.data) if self.local_init: - object.__setattr__(m, '__initialized', True) + object.__setattr__(m, "__initialized", True) def _init_weight(self, data): nn.init.uniform_(data, -0.07, 0.07) @@ -71,13 +84,15 @@ class Bilinear(Initializer): center = scale - 0.5 * (1 + kernel_size % 2) og = np.ogrid[:kernel_size, :kernel_size] - kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) + kernel = (1 - np.abs(og[0] - center) / scale) * ( + 1 - np.abs(og[1] - center) / scale + ) return torch.tensor(kernel, dtype=torch.float32) class XavierGluon(Initializer): - def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, 
**kwargs): + def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3, **kwargs): super().__init__(**kwargs) self.rnd_type = rnd_type @@ -87,19 +102,19 @@ class XavierGluon(Initializer): def _init_weight(self, arr): fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) - if self.factor_type == 'avg': + if self.factor_type == "avg": factor = (fan_in + fan_out) / 2.0 - elif self.factor_type == 'in': + elif self.factor_type == "in": factor = fan_in - elif self.factor_type == 'out': + elif self.factor_type == "out": factor = fan_out else: - raise ValueError('Incorrect factor type') + raise ValueError("Incorrect factor type") scale = np.sqrt(self.magnitude / factor) - if self.rnd_type == 'uniform': + if self.rnd_type == "uniform": nn.init.uniform_(arr, -scale, scale) - elif self.rnd_type == 'gaussian': + elif self.rnd_type == "gaussian": nn.init.normal_(arr, 0, scale) else: - raise ValueError('Unknown random type') + raise ValueError("Unknown random type") diff --git a/isegm/model/is_deeplab_model.py b/isegm/model/is_deeplab_model.py index 45fa55364d14d129889fce083a791be1e48a35c9..79fc9c2b4156454b5f352353260ec86ee22a6c64 100644 --- a/isegm/model/is_deeplab_model.py +++ b/isegm/model/is_deeplab_model.py @@ -1,25 +1,44 @@ import torch.nn as nn +from isegm.model.modifiers import LRMult from isegm.utils.serialization import serialize + from .is_model import ISModel -from .modeling.deeplab_v3 import DeepLabV3Plus from .modeling.basic_blocks import SepConvHead -from isegm.model.modifiers import LRMult +from .modeling.deeplab_v3 import DeepLabV3Plus class DeeplabModel(ISModel): @serialize - def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, - backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs): + def __init__( + self, + backbone="resnet50", + deeplab_ch=256, + aspp_dropout=0.5, + backbone_norm_layer=None, + backbone_lr_mult=0.1, + norm_layer=nn.BatchNorm2d, + **kwargs + ): super().__init__(norm_layer=norm_layer, **kwargs) - self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout, - norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer) + self.feature_extractor = DeepLabV3Plus( + backbone=backbone, + ch=deeplab_ch, + project_dropout=aspp_dropout, + norm_layer=norm_layer, + backbone_norm_layer=backbone_norm_layer, + ) self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult)) - self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, - num_layers=2, norm_layer=norm_layer) + self.head = SepConvHead( + 1, + in_channels=deeplab_ch, + mid_channels=deeplab_ch // 2, + num_layers=2, + norm_layer=norm_layer, + ) def backbone_forward(self, image, coord_features=None): backbone_features = self.feature_extractor(image, coord_features) - return {'instances': self.head(backbone_features[0])} + return {"instances": self.head(backbone_features[0])} diff --git a/isegm/model/is_hrnet_model.py b/isegm/model/is_hrnet_model.py index b8a82e746adf49e44d7ff011bef3c7cb105ae4cb..bf6d76152fa4aaf90b1f373115e8e6bc86732b04 100644 --- a/isegm/model/is_hrnet_model.py +++ b/isegm/model/is_hrnet_model.py @@ -1,19 +1,32 @@ import torch.nn as nn +from isegm.model.modifiers import LRMult from isegm.utils.serialization import serialize + from .is_model import ISModel from .modeling.hrnet_ocr import HighResolutionNet -from isegm.model.modifiers import LRMult class HRNetModel(ISModel): @serialize - def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1, - 
norm_layer=nn.BatchNorm2d, **kwargs): + def __init__( + self, + width=48, + ocr_width=256, + small=False, + backbone_lr_mult=0.1, + norm_layer=nn.BatchNorm2d, + **kwargs + ): super().__init__(norm_layer=norm_layer, **kwargs) - self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small, - num_classes=1, norm_layer=norm_layer) + self.feature_extractor = HighResolutionNet( + width=width, + ocr_width=ocr_width, + small=small, + num_classes=1, + norm_layer=norm_layer, + ) self.feature_extractor.apply(LRMult(backbone_lr_mult)) if ocr_width > 0: self.feature_extractor.ocr_distri_head.apply(LRMult(1.0)) @@ -23,4 +36,4 @@ class HRNetModel(ISModel): def backbone_forward(self, image, coord_features=None): net_outputs = self.feature_extractor(image, coord_features) - return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]} + return {"instances": net_outputs[0], "instances_aux": net_outputs[1]} diff --git a/isegm/model/is_model.py b/isegm/model/is_model.py index f6555401b15a0f72c252745da726beaa602e6231..ef1bf8b9398272039281f9d40b59992a793eee6b 100644 --- a/isegm/model/is_model.py +++ b/isegm/model/is_model.py @@ -1,17 +1,27 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np -from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize from isegm.model.modifiers import LRMult +from isegm.model.ops import BatchImageNormalize, DistMaps, ScaleLayer class ISModel(nn.Module): - def __init__(self, use_rgb_conv=True, with_aux_output=False, - norm_radius=260, use_disks=False, cpu_dist_maps=False, - clicks_groups=None, with_prev_mask=False, use_leaky_relu=False, - binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d, - norm_mean_std=([.485, .456, .406], [.229, .224, .225])): + def __init__( + self, + use_rgb_conv=True, + with_aux_output=False, + norm_radius=260, + use_disks=False, + cpu_dist_maps=False, + clicks_groups=None, + with_prev_mask=False, + use_leaky_relu=False, + binary_prev_mask=False, + conv_extend=False, + norm_layer=nn.BatchNorm2d, + norm_mean_std=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ): super().__init__() self.with_aux_output = with_aux_output self.clicks_groups = clicks_groups @@ -28,35 +38,64 @@ class ISModel(nn.Module): if use_rgb_conv: rgb_conv_layers = [ - nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1), + nn.Conv2d( + in_channels=3 + self.coord_feature_ch, + out_channels=6 + self.coord_feature_ch, + kernel_size=1, + ), norm_layer(6 + self.coord_feature_ch), - nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), - nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1) + nn.LeakyReLU(negative_slope=0.2) + if use_leaky_relu + else nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1 + ), ] self.rgb_conv = nn.Sequential(*rgb_conv_layers) elif conv_extend: self.rgb_conv = None - self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64, - kernel_size=3, stride=2, padding=1) + self.maps_transform = nn.Conv2d( + in_channels=self.coord_feature_ch, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + ) self.maps_transform.apply(LRMult(0.1)) else: self.rgb_conv = None mt_layers = [ - nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1), - nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), - nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, 
stride=2, padding=1), - ScaleLayer(init_value=0.05, lr_mult=1) + nn.Conv2d( + in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1 + ), + nn.LeakyReLU(negative_slope=0.2) + if use_leaky_relu + else nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1 + ), + ScaleLayer(init_value=0.05, lr_mult=1), ] self.maps_transform = nn.Sequential(*mt_layers) if self.clicks_groups is not None: self.dist_maps = nn.ModuleList() for click_radius in self.clicks_groups: - self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0, - cpu_mode=cpu_dist_maps, use_disks=use_disks)) + self.dist_maps.append( + DistMaps( + norm_radius=click_radius, + spatial_scale=1.0, + cpu_mode=cpu_dist_maps, + use_disks=use_disks, + ) + ) else: - self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, - cpu_mode=cpu_dist_maps, use_disks=use_disks) + self.dist_maps = DistMaps( + norm_radius=norm_radius, + spatial_scale=1.0, + cpu_mode=cpu_dist_maps, + use_disks=use_disks, + ) def forward(self, image, points): image, prev_mask = self.prepare_input(image) @@ -69,11 +108,19 @@ class ISModel(nn.Module): coord_features = self.maps_transform(coord_features) outputs = self.backbone_forward(image, coord_features) - outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:], - mode='bilinear', align_corners=True) + outputs["instances"] = nn.functional.interpolate( + outputs["instances"], + size=image.size()[2:], + mode="bilinear", + align_corners=True, + ) if self.with_aux_output: - outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:], - mode='bilinear', align_corners=True) + outputs["instances_aux"] = nn.functional.interpolate( + outputs["instances_aux"], + size=image.size()[2:], + mode="bilinear", + align_corners=True, + ) return outputs @@ -93,8 +140,13 @@ class ISModel(nn.Module): def get_coord_features(self, image, prev_mask, points): if self.clicks_groups is not None: - points_groups = split_points_by_order(points, groups=(2,) + (1, ) * (len(self.clicks_groups) - 2) + (-1,)) - coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)] + points_groups = split_points_by_order( + points, groups=(2,) + (1,) * (len(self.clicks_groups) - 2) + (-1,) + ) + coord_features = [ + dist_map(image, pg) + for dist_map, pg in zip(self.dist_maps, points_groups) + ] coord_features = torch.cat(coord_features, dim=1) else: coord_features = self.dist_maps(image, points) @@ -112,8 +164,7 @@ def split_points_by_order(tpoints: torch.Tensor, groups): num_points = points.shape[1] // 2 groups = [x if x > 0 else num_points for x in groups] - group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) - for x in groups] + group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) for x in groups] last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int) for group_indx, group_size in enumerate(groups): @@ -127,7 +178,9 @@ def split_points_by_order(tpoints: torch.Tensor, groups): continue is_negative = int(pindx >= num_points) - if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click + if group_id >= num_groups or ( + group_id == 0 and is_negative + ): # disable negative first click group_id = num_groups - 1 new_point_indx = last_point_indx_group[bindx, group_id, is_negative] @@ -135,7 +188,9 @@ def split_points_by_order(tpoints: torch.Tensor, groups): group_points[group_id][bindx, new_point_indx, :] = 
point - group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) - for x in group_points] + group_points = [ + torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) + for x in group_points + ] return group_points diff --git a/isegm/model/losses.py b/isegm/model/losses.py index b90f18f31e7718cf6c79a267be0ccb0d99797325..95fb24f5aaf052d19e6be094bafa013d56e7d93e 100644 --- a/isegm/model/losses.py +++ b/isegm/model/losses.py @@ -7,10 +7,20 @@ from isegm.utils import misc class NormalizedFocalLossSigmoid(nn.Module): - def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, - from_sigmoid=False, detach_delimeter=True, - batch_axis=0, weight=None, size_average=True, - ignore_label=-1): + def __init__( + self, + axis=-1, + alpha=0.25, + gamma=2, + max_mult=-1, + eps=1e-12, + from_sigmoid=False, + detach_delimeter=True, + batch_axis=0, + weight=None, + size_average=True, + ignore_label=-1, + ): super(NormalizedFocalLossSigmoid, self).__init__() self._axis = axis self._alpha = alpha @@ -34,8 +44,12 @@ class NormalizedFocalLossSigmoid(nn.Module): if not self._from_logits: pred = torch.sigmoid(pred) - alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) - pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + alpha = torch.where( + one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight + ) + pt = torch.where( + sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred) + ) beta = (1 - pt) ** self._gamma @@ -49,37 +63,69 @@ class NormalizedFocalLossSigmoid(nn.Module): beta = torch.clamp_max(beta, self._max_mult) with torch.no_grad(): - ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() - sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() + ignore_area = ( + torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))) + .cpu() + .numpy() + ) + sample_mult = ( + torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() + ) if np.any(ignore_area == 0): - self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() + self._k_sum = ( + 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() + ) beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1) beta_pmax = beta_pmax.mean().item() self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax - loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = ( + -alpha + * beta + * torch.log( + torch.min( + pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device) + ) + ) + ) loss = self._weight * (loss * sample_weight) if self._size_average: - bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) - loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) + bsum = torch.sum( + sample_weight, + dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis), + ) + loss = torch.sum( + loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis) + ) / (bsum + self._eps) else: - loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + loss = torch.sum( + loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis) + ) return loss def log_states(self, sw, name, global_step): - sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) - sw.add_scalar(tag=name + '_m', value=self._m_max, 
global_step=global_step) + sw.add_scalar(tag=name + "_k", value=self._k_sum, global_step=global_step) + sw.add_scalar(tag=name + "_m", value=self._m_max, global_step=global_step) class FocalLoss(nn.Module): - def __init__(self, axis=-1, alpha=0.25, gamma=2, - from_logits=False, batch_axis=0, - weight=None, num_class=None, - eps=1e-9, size_average=True, scale=1.0, - ignore_label=-1): + def __init__( + self, + axis=-1, + alpha=0.25, + gamma=2, + from_logits=False, + batch_axis=0, + weight=None, + num_class=None, + eps=1e-9, + size_average=True, + scale=1.0, + ignore_label=-1, + ): super(FocalLoss, self).__init__() self._axis = axis self._alpha = alpha @@ -101,19 +147,38 @@ class FocalLoss(nn.Module): if not self._from_logits: pred = torch.sigmoid(pred) - alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) - pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + alpha = torch.where( + one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight + ) + pt = torch.where( + sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred) + ) beta = (1 - pt) ** self._gamma - loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = ( + -alpha + * beta + * torch.log( + torch.min( + pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device) + ) + ) + ) loss = self._weight * (loss * sample_weight) if self._size_average: - tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) - loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) + tsum = torch.sum( + sample_weight, + dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis), + ) + loss = torch.sum( + loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis) + ) / (tsum + self._eps) else: - loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + loss = torch.sum( + loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis) + ) return self._scale * loss @@ -131,8 +196,9 @@ class SoftIoU(nn.Module): if not self._from_sigmoid: pred = torch.sigmoid(pred) - loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \ - / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8) + loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) / ( + torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8 + ) return loss @@ -154,8 +220,12 @@ class SigmoidBinaryCrossEntropyLoss(nn.Module): loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) else: eps = 1e-12 - loss = -(torch.log(pred + eps) * label - + torch.log(1. - pred + eps) * (1. 
- label)) + loss = -( + torch.log(pred + eps) * label + + torch.log(1.0 - pred + eps) * (1.0 - label) + ) loss = self._weight * (loss * sample_weight) - return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + return torch.mean( + loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis) + ) diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py index a572dcd97ed2dac222fa51a33657aa5b403dbb2a..9f2e889ee18fe135881c53d7d28cf14fa6b21926 100644 --- a/isegm/model/metrics.py +++ b/isegm/model/metrics.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch from isegm.utils import misc @@ -27,9 +27,17 @@ class TrainMetric(object): class AdaptiveIoU(TrainMetric): - def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, - ignore_label=-1, from_logits=True, - pred_output='instances', gt_output='instances'): + def __init__( + self, + init_thresh=0.4, + thresh_step=0.025, + thresh_beta=0.99, + iou_beta=0.9, + ignore_label=-1, + from_logits=True, + pred_output="instances", + gt_output="instances", + ): super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) self._ignore_label = ignore_label self._from_logits = from_logits @@ -59,7 +67,9 @@ class AdaptiveIoU(TrainMetric): max_iou = temp_iou best_thresh = t - self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh + self._iou_thresh = ( + self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh + ) self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou self._epoch_iou_sum += max_iou self._epoch_batch_count += 1 @@ -75,8 +85,14 @@ class AdaptiveIoU(TrainMetric): self._epoch_batch_count = 0 def log_states(self, sw, tag_prefix, global_step): - sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) - sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) + sw.add_scalar( + tag=tag_prefix + "_ema_iou", value=self._ema_iou, global_step=global_step + ) + sw.add_scalar( + tag=tag_prefix + "_iou_thresh", + value=self._iou_thresh, + global_step=global_step, + ) @property def iou_thresh(self): @@ -88,8 +104,18 @@ def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) - union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() - intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + union = ( + torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims) + .detach() + .cpu() + .numpy() + ) + intersection = ( + torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims) + .detach() + .cpu() + .numpy() + ) nonzero = union > 0 iou = intersection[nonzero] / union[nonzero] diff --git a/isegm/model/modeling/basic_blocks.py b/isegm/model/modeling/basic_blocks.py index 13753e85353ed9250aa3888ab2e715350b1b2c50..e5840a0920f3ceae73df25cf3c17bfffb50958ac 100644 --- a/isegm/model/modeling/basic_blocks.py +++ b/isegm/model/modeling/basic_blocks.py @@ -4,18 +4,28 @@ from isegm.model import ops class ConvHead(nn.Module): - def __init__(self, out_channels, in_channels=32, num_layers=1, - kernel_size=3, padding=1, - norm_layer=nn.BatchNorm2d): + def __init__( + self, + out_channels, + in_channels=32, + num_layers=1, + kernel_size=3, + padding=1, + norm_layer=nn.BatchNorm2d, + ): 
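        # Illustrative sketch (editorial, not part of the diff; values reuse the
        # defaults above): ConvHead(out_channels=1, in_channels=32, num_layers=2)
        # assembles the equivalent of
        #
        #     nn.Sequential(
        #         nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(32),
        #         nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(32),
        #         nn.Conv2d(32, 1, 1, padding=0),
        #     )
        #
        # i.e. num_layers conv/ReLU/norm blocks followed by a 1x1 projection.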
super(ConvHead, self).__init__() convhead = [] for i in range(num_layers): - convhead.extend([ - nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), - nn.ReLU(), - norm_layer(in_channels) if norm_layer is not None else nn.Identity() - ]) + convhead.extend( + [ + nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), + nn.ReLU(), + norm_layer(in_channels) + if norm_layer is not None + else nn.Identity(), + ] + ) convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) self.convhead = nn.Sequential(*convhead) @@ -25,25 +35,43 @@ class ConvHead(nn.Module): class SepConvHead(nn.Module): - def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, - kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, - norm_layer=nn.BatchNorm2d): + def __init__( + self, + num_outputs, + in_channels, + mid_channels, + num_layers=1, + kernel_size=3, + padding=1, + dropout_ratio=0.0, + dropout_indx=0, + norm_layer=nn.BatchNorm2d, + ): super(SepConvHead, self).__init__() sepconvhead = [] for i in range(num_layers): sepconvhead.append( - SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, - out_channels=mid_channels, - dw_kernel=kernel_size, dw_padding=padding, - norm_layer=norm_layer, activation='relu') + SeparableConv2d( + in_channels=in_channels if i == 0 else mid_channels, + out_channels=mid_channels, + dw_kernel=kernel_size, + dw_padding=padding, + norm_layer=norm_layer, + activation="relu", + ) ) if dropout_ratio > 0 and dropout_indx == i: sepconvhead.append(nn.Dropout(dropout_ratio)) sepconvhead.append( - nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) + nn.Conv2d( + in_channels=mid_channels, + out_channels=num_outputs, + kernel_size=1, + padding=0, + ) ) self.layers = nn.Sequential(*sepconvhead) @@ -55,16 +83,34 @@ class SepConvHead(nn.Module): class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, - activation=None, use_bias=False, norm_layer=None): + def __init__( + self, + in_channels, + out_channels, + dw_kernel, + dw_padding, + dw_stride=1, + activation=None, + use_bias=False, + norm_layer=None, + ): super(SeparableConv2d, self).__init__() _activation = ops.select_activation_function(activation) self.body = nn.Sequential( - nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, - padding=dw_padding, bias=use_bias, groups=in_channels), - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), + nn.Conv2d( + in_channels, + in_channels, + kernel_size=dw_kernel, + stride=dw_stride, + padding=dw_padding, + bias=use_bias, + groups=in_channels, + ), + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias + ), norm_layer(out_channels) if norm_layer is not None else nn.Identity(), - _activation() + _activation(), ) def forward(self, x): diff --git a/isegm/model/modeling/deeplab_v3.py b/isegm/model/modeling/deeplab_v3.py index 8219a4ef18048a0fc79fdf3e5b603af7eac03892..500a935c42681b04879a35bd47c564242499d8f6 100644 --- a/isegm/model/modeling/deeplab_v3.py +++ b/isegm/model/modeling/deeplab_v3.py @@ -1,21 +1,26 @@ from contextlib import ExitStack import torch -from torch import nn import torch.nn.functional as F +from torch import nn + +from isegm.model import ops from .basic_blocks import SeparableConv2d from .resnet import ResNetBackbone -from isegm.model import ops class DeepLabV3Plus(nn.Module): - def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, - 
backbone_norm_layer=None, - ch=256, - project_dropout=0.5, - inference_mode=False, - **kwargs): + def __init__( + self, + backbone="resnet50", + norm_layer=nn.BatchNorm2d, + backbone_norm_layer=None, + ch=256, + project_dropout=0.5, + inference_mode=False, + **kwargs + ): super(DeepLabV3Plus, self).__init__() if backbone_norm_layer is None: backbone_norm_layer = norm_layer @@ -29,28 +34,44 @@ class DeepLabV3Plus(nn.Module): self.skip_project_in_channels = 256 # layer 1 out_channels self._kwargs = kwargs - if backbone == 'resnet34': + if backbone == "resnet34": self.aspp_in_channels = 512 self.skip_project_in_channels = 64 - self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, - norm_layer=self.backbone_norm_layer, **kwargs) + self.backbone = ResNetBackbone( + backbone=self.backbone_name, + pretrained_base=False, + norm_layer=self.backbone_norm_layer, + **kwargs + ) - self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, - norm_layer=self.norm_layer) - self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) - self.aspp = _ASPP(in_channels=self.aspp_in_channels, - atrous_rates=[12, 24, 36], - out_channels=ch, - project_dropout=project_dropout, - norm_layer=self.norm_layer) + self.head = _DeepLabHead( + in_channels=ch + 32, + mid_channels=ch, + out_channels=ch, + norm_layer=self.norm_layer, + ) + self.skip_project = _SkipProject( + self.skip_project_in_channels, 32, norm_layer=self.norm_layer + ) + self.aspp = _ASPP( + in_channels=self.aspp_in_channels, + atrous_rates=[12, 24, 36], + out_channels=ch, + project_dropout=project_dropout, + norm_layer=self.norm_layer, + ) if inference_mode: self.set_prediction_mode() def load_pretrained_weights(self): - pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, - norm_layer=self.backbone_norm_layer, **self._kwargs) + pretrained = ResNetBackbone( + backbone=self.backbone_name, + pretrained_base=True, + norm_layer=self.backbone_norm_layer, + **self._kwargs + ) backbone_state_dict = self.backbone.state_dict() pretrained_state_dict = pretrained.state_dict() @@ -74,11 +95,11 @@ class DeepLabV3Plus(nn.Module): c1 = self.skip_project(c1) x = self.aspp(c4) - x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) + x = F.interpolate(x, c1.size()[2:], mode="bilinear", align_corners=True) x = torch.cat((x, c1), dim=1) x = self.head(x) - return x, + return (x,) class _SkipProject(nn.Module): @@ -89,7 +110,7 @@ class _SkipProject(nn.Module): self.skip_project = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), norm_layer(out_channels), - _activation() + _activation(), ) def forward(self, x): @@ -97,15 +118,31 @@ class _SkipProject(nn.Module): class _DeepLabHead(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): + def __init__( + self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d + ): super(_DeepLabHead, self).__init__() self.block = nn.Sequential( - SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, - dw_padding=1, activation='relu', norm_layer=norm_layer), - SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, - dw_padding=1, activation='relu', norm_layer=norm_layer), - nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) + SeparableConv2d( + in_channels=in_channels, + out_channels=mid_channels, + dw_kernel=3, + dw_padding=1, + 
activation="relu", + norm_layer=norm_layer, + ), + SeparableConv2d( + in_channels=mid_channels, + out_channels=mid_channels, + dw_kernel=3, + dw_padding=1, + activation="relu", + norm_layer=norm_layer, + ), + nn.Conv2d( + in_channels=mid_channels, out_channels=out_channels, kernel_size=1 + ), ) def forward(self, x): @@ -113,14 +150,25 @@ class _DeepLabHead(nn.Module): class _ASPP(nn.Module): - def __init__(self, in_channels, atrous_rates, out_channels=256, - project_dropout=0.5, norm_layer=nn.BatchNorm2d): + def __init__( + self, + in_channels, + atrous_rates, + out_channels=256, + project_dropout=0.5, + norm_layer=nn.BatchNorm2d, + ): super(_ASPP, self).__init__() b0 = nn.Sequential( - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), norm_layer(out_channels), - nn.ReLU() + nn.ReLU(), ) rate1, rate2, rate3 = tuple(atrous_rates) @@ -132,10 +180,14 @@ class _ASPP(nn.Module): self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) project = [ - nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, - kernel_size=1, bias=False), + nn.Conv2d( + in_channels=5 * out_channels, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), norm_layer(out_channels), - nn.ReLU() + nn.ReLU(), ] if project_dropout > 0: project.append(nn.Dropout(project_dropout)) @@ -153,24 +205,33 @@ class _AsppPooling(nn.Module): self.gap = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, - kernel_size=1, bias=False), + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), norm_layer(out_channels), - nn.ReLU() + nn.ReLU(), ) def forward(self, x): pool = self.gap(x) - return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) + return F.interpolate(pool, x.size()[2:], mode="bilinear", align_corners=True) def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): block = nn.Sequential( - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, - kernel_size=3, padding=atrous_rate, - dilation=atrous_rate, bias=False), + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=atrous_rate, + dilation=atrous_rate, + bias=False, + ), norm_layer(out_channels), - nn.ReLU() + nn.ReLU(), ) return block diff --git a/isegm/model/modeling/hrnet_ocr.py b/isegm/model/modeling/hrnet_ocr.py index d386ee0d376df2d498ef3c05f743caaf83374273..709a2f29c42f7898d86ee3439d72505315ecd403 100644 --- a/isegm/model/modeling/hrnet_ocr.py +++ b/isegm/model/modeling/hrnet_ocr.py @@ -1,19 +1,30 @@ import os + import numpy as np import torch -import torch.nn as nn import torch._utils +import torch.nn as nn import torch.nn.functional as F -from .ocr import SpatialOCR_Module, SpatialGather_Module + +from .ocr import SpatialGather_Module, SpatialOCR_Module from .resnetv1b import BasicBlockV1b, BottleneckV1b relu_inplace = True class HighResolutionModule(nn.Module): - def __init__(self, num_branches, blocks, num_blocks, num_inchannels, - num_channels, fuse_method,multi_scale_output=True, - norm_layer=nn.BatchNorm2d, align_corners=True): + def __init__( + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True, + norm_layer=nn.BatchNorm2d, + align_corners=True, + ): super(HighResolutionModule, self).__init__() self._check_branches(num_branches, num_blocks, 
num_inchannels, num_channels) @@ -26,48 +37,67 @@ class HighResolutionModule(nn.Module): self.multi_scale_output = multi_scale_output self.branches = self._make_branches( - num_branches, blocks, num_blocks, num_channels) + num_branches, blocks, num_blocks, num_channels + ) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=relu_inplace) def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): if num_branches != len(num_blocks): - error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( - num_branches, len(num_blocks)) + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( + num_branches, len(num_blocks) + ) raise ValueError(error_msg) if num_branches != len(num_channels): - error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( - num_branches, len(num_channels)) + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( + num_branches, len(num_channels) + ) raise ValueError(error_msg) if num_branches != len(num_inchannels): - error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( - num_branches, len(num_inchannels)) + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( + num_branches, len(num_inchannels) + ) raise ValueError(error_msg) - def _make_one_branch(self, branch_index, block, num_blocks, num_channels, - stride=1): + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None - if stride != 1 or \ - self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + if ( + stride != 1 + or self.num_inchannels[branch_index] + != num_channels[branch_index] * block.expansion + ): downsample = nn.Sequential( - nn.Conv2d(self.num_inchannels[branch_index], - num_channels[branch_index] * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), self.norm_layer(num_channels[branch_index] * block.expansion), ) layers = [] - layers.append(block(self.num_inchannels[branch_index], - num_channels[branch_index], stride, - downsample=downsample, norm_layer=self.norm_layer)) - self.num_inchannels[branch_index] = \ - num_channels[branch_index] * block.expansion + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + norm_layer=self.norm_layer, + ) + ) + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): - layers.append(block(self.num_inchannels[branch_index], - num_channels[branch_index], - norm_layer=self.norm_layer)) + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + norm_layer=self.norm_layer, + ) + ) return nn.Sequential(*layers) @@ -75,8 +105,7 @@ class HighResolutionModule(nn.Module): branches = [] for i in range(num_branches): - branches.append( - self._make_one_branch(i, block, num_blocks, num_channels)) + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) @@ -91,12 +120,17 @@ class HighResolutionModule(nn.Module): fuse_layer = [] for j in range(num_branches): if j > i: - fuse_layer.append(nn.Sequential( - nn.Conv2d(in_channels=num_inchannels[j], - out_channels=num_inchannels[i], - kernel_size=1, - bias=False), - self.norm_layer(num_inchannels[i]))) + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + in_channels=num_inchannels[j], + out_channels=num_inchannels[i], 
+ kernel_size=1, + bias=False, + ), + self.norm_layer(num_inchannels[i]), + ) + ) elif j == i: fuse_layer.append(None) else: @@ -104,19 +138,35 @@ class HighResolutionModule(nn.Module): for k in range(i - j): if k == i - j - 1: num_outchannels_conv3x3 = num_inchannels[i] - conv3x3s.append(nn.Sequential( - nn.Conv2d(num_inchannels[j], - num_outchannels_conv3x3, - kernel_size=3, stride=2, padding=1, bias=False), - self.norm_layer(num_outchannels_conv3x3))) + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + self.norm_layer(num_outchannels_conv3x3), + ) + ) else: num_outchannels_conv3x3 = num_inchannels[j] - conv3x3s.append(nn.Sequential( - nn.Conv2d(num_inchannels[j], - num_outchannels_conv3x3, - kernel_size=3, stride=2, padding=1, bias=False), - self.norm_layer(num_outchannels_conv3x3), - nn.ReLU(inplace=relu_inplace))) + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace), + ) + ) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) @@ -144,7 +194,9 @@ class HighResolutionModule(nn.Module): y = y + F.interpolate( self.fuse_layers[i][j](x[j]), size=[height_output, width_output], - mode='bilinear', align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) else: y = y + self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) @@ -153,8 +205,15 @@ class HighResolutionModule(nn.Module): class HighResolutionNet(nn.Module): - def __init__(self, width, num_classes, ocr_width=256, small=False, - norm_layer=nn.BatchNorm2d, align_corners=True): + def __init__( + self, + width, + num_classes, + ocr_width=256, + small=False, + norm_layer=nn.BatchNorm2d, + align_corners=True, + ): super(HighResolutionNet, self).__init__() self.norm_layer = norm_layer self.width = width @@ -170,40 +229,61 @@ class HighResolutionNet(nn.Module): num_blocks = 2 if small else 4 stage1_num_channels = 64 - self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + self.layer1 = self._make_layer( + BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks + ) stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels self.stage2_num_branches = 2 num_channels = [width, 2 * width] num_inchannels = [ - num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels)) + ] self.transition1 = self._make_transition_layer( - [stage1_out_channel], num_inchannels) + [stage1_out_channel], num_inchannels + ) self.stage2, pre_stage_channels = self._make_stage( - BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, - num_blocks=2 * [num_blocks], num_channels=num_channels) + BasicBlockV1b, + num_inchannels=num_inchannels, + num_modules=1, + num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], + num_channels=num_channels, + ) self.stage3_num_branches = 3 num_channels = [width, 2 * width, 4 * width] num_inchannels = [ - num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels)) + ] self.transition2 = self._make_transition_layer( - pre_stage_channels, num_inchannels) + pre_stage_channels, num_inchannels + ) 
self.stage3, pre_stage_channels = self._make_stage( - BasicBlockV1b, num_inchannels=num_inchannels, - num_modules=3 if small else 4, num_branches=self.stage3_num_branches, - num_blocks=3 * [num_blocks], num_channels=num_channels) + BasicBlockV1b, + num_inchannels=num_inchannels, + num_modules=3 if small else 4, + num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], + num_channels=num_channels, + ) self.stage4_num_branches = 4 num_channels = [width, 2 * width, 4 * width, 8 * width] num_inchannels = [ - num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels)) + ] self.transition3 = self._make_transition_layer( - pre_stage_channels, num_inchannels) + pre_stage_channels, num_inchannels + ) self.stage4, pre_stage_channels = self._make_stage( - BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + BasicBlockV1b, + num_inchannels=num_inchannels, + num_modules=2 if small else 3, num_branches=self.stage4_num_branches, - num_blocks=4 * [num_blocks], num_channels=num_channels) + num_blocks=4 * [num_blocks], + num_channels=num_channels, + ) last_inp_channels = np.int(np.sum(pre_stage_channels)) if self.ocr_width > 0: @@ -211,43 +291,77 @@ class HighResolutionNet(nn.Module): ocr_key_channels = self.ocr_width self.conv3x3_ocr = nn.Sequential( - nn.Conv2d(last_inp_channels, ocr_mid_channels, - kernel_size=3, stride=1, padding=1), + nn.Conv2d( + last_inp_channels, + ocr_mid_channels, + kernel_size=3, + stride=1, + padding=1, + ), norm_layer(ocr_mid_channels), nn.ReLU(inplace=relu_inplace), ) self.ocr_gather_head = SpatialGather_Module(num_classes) - self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, - key_channels=ocr_key_channels, - out_channels=ocr_mid_channels, - scale=1, - dropout=0.05, - norm_layer=norm_layer, - align_corners=align_corners) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners, + ) self.cls_head = nn.Conv2d( - ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) + ocr_mid_channels, + num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) self.aux_head = nn.Sequential( - nn.Conv2d(last_inp_channels, last_inp_channels, - kernel_size=1, stride=1, padding=0), + nn.Conv2d( + last_inp_channels, + last_inp_channels, + kernel_size=1, + stride=1, + padding=0, + ), norm_layer(last_inp_channels), nn.ReLU(inplace=relu_inplace), - nn.Conv2d(last_inp_channels, num_classes, - kernel_size=1, stride=1, padding=0, bias=True) + nn.Conv2d( + last_inp_channels, + num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), ) else: self.cls_head = nn.Sequential( - nn.Conv2d(last_inp_channels, last_inp_channels, - kernel_size=3, stride=1, padding=1), + nn.Conv2d( + last_inp_channels, + last_inp_channels, + kernel_size=3, + stride=1, + padding=1, + ), norm_layer(last_inp_channels), nn.ReLU(inplace=relu_inplace), - nn.Conv2d(last_inp_channels, num_classes, - kernel_size=1, stride=1, padding=0, bias=True) + nn.Conv2d( + last_inp_channels, + num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), ) - def _make_transition_layer( - self, num_channels_pre_layer, num_channels_cur_layer): + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) 
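        # Illustrative sketch (editorial, not part of the diff): for the
        # stage-1 -> stage-2 transition with the default width=48,
        # num_channels_pre_layer=[256] and num_channels_cur_layer=[48, 96],
        # so the loop below returns roughly
        #
        #     nn.ModuleList([
        #         nn.Sequential(nn.Conv2d(256, 48, 3, stride=1, padding=1, bias=False),
        #                       self.norm_layer(48), nn.ReLU(inplace=True)),
        #         nn.Sequential(nn.Sequential(nn.Conv2d(256, 96, 3, stride=2, padding=1, bias=False),
        #                                     self.norm_layer(96), nn.ReLU(inplace=True))),
        #     ])
        #
        # Plain 3x3 convs match channel widths across existing branches, while
        # strided 3x3 convs spawn each new lower-resolution branch.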
num_branches_pre = len(num_channels_pre_layer) @@ -255,28 +369,45 @@ class HighResolutionNet(nn.Module): for i in range(num_branches_cur): if i < num_branches_pre: if num_channels_cur_layer[i] != num_channels_pre_layer[i]: - transition_layers.append(nn.Sequential( - nn.Conv2d(num_channels_pre_layer[i], - num_channels_cur_layer[i], - kernel_size=3, - stride=1, - padding=1, - bias=False), - self.norm_layer(num_channels_cur_layer[i]), - nn.ReLU(inplace=relu_inplace))) + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace), + ) + ) else: transition_layers.append(None) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] - outchannels = num_channels_cur_layer[i] \ - if j == i - num_branches_pre else inchannels - conv3x3s.append(nn.Sequential( - nn.Conv2d(inchannels, outchannels, - kernel_size=3, stride=2, padding=1, bias=False), - self.norm_layer(outchannels), - nn.ReLU(inplace=relu_inplace))) + outchannels = ( + num_channels_cur_layer[i] + if j == i - num_branches_pre + else inchannels + ) + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, + outchannels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace), + ) + ) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) @@ -285,24 +416,43 @@ class HighResolutionNet(nn.Module): downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), self.norm_layer(planes * block.expansion), ) layers = [] - layers.append(block(inplanes, planes, stride, - downsample=downsample, norm_layer=self.norm_layer)) + layers.append( + block( + inplanes, + planes, + stride, + downsample=downsample, + norm_layer=self.norm_layer, + ) + ) inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) return nn.Sequential(*layers) - def _make_stage(self, block, num_inchannels, - num_modules, num_branches, num_blocks, num_channels, - fuse_method='SUM', - multi_scale_output=True): + def _make_stage( + self, + block, + num_inchannels, + num_modules, + num_branches, + num_blocks, + num_channels, + fuse_method="SUM", + multi_scale_output=True, + ): modules = [] for i in range(num_modules): # multi_scale_output is only used last module @@ -311,15 +461,17 @@ class HighResolutionNet(nn.Module): else: reset_multi_scale_output = True modules.append( - HighResolutionModule(num_branches, - block, - num_blocks, - num_inchannels, - num_channels, - fuse_method, - reset_multi_scale_output, - norm_layer=self.norm_layer, - align_corners=self.align_corners) + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners, + ) ) num_inchannels = modules[-1].get_num_inchannels() @@ -387,30 +539,38 @@ class HighResolutionNet(nn.Module): def aggregate_hrnet_features(self, x): # Upsampling x0_h, x0_w = x[0].size(2), x[0].size(3) - x1 = F.interpolate(x[1], size=(x0_h, x0_w), - mode='bilinear', 
align_corners=self.align_corners) - x2 = F.interpolate(x[2], size=(x0_h, x0_w), - mode='bilinear', align_corners=self.align_corners) - x3 = F.interpolate(x[3], size=(x0_h, x0_w), - mode='bilinear', align_corners=self.align_corners) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners + ) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners + ) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners + ) return torch.cat([x[0], x1, x2, x3], 1) - def load_pretrained_weights(self, pretrained_path=''): + def load_pretrained_weights(self, pretrained_path=""): model_dict = self.state_dict() if not os.path.exists(pretrained_path): print(f'\nFile "{pretrained_path}" does not exist.') - print('You need to specify the correct path to the pre-trained weights.\n' - 'You can download the weights for HRNet from the repository:\n' - 'https://github.com/HRNet/HRNet-Image-Classification') + print( + "You need to specify the correct path to the pre-trained weights.\n" + "You can download the weights for HRNet from the repository:\n" + "https://github.com/HRNet/HRNet-Image-Classification" + ) exit(1) - pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) - pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in - pretrained_dict.items()} - - pretrained_dict = {k: v for k, v in pretrained_dict.items() - if k in model_dict.keys()} + pretrained_dict = torch.load(pretrained_path, map_location={"cuda:0": "cpu"}) + pretrained_dict = { + k.replace("last_layer", "aux_head").replace("model.", ""): v + for k, v in pretrained_dict.items() + } + + pretrained_dict = { + k: v for k, v in pretrained_dict.items() if k in model_dict.keys() + } model_dict.update(pretrained_dict) self.load_state_dict(model_dict) diff --git a/isegm/model/modeling/ocr.py b/isegm/model/modeling/ocr.py index df3b4f67959fc6a088b93ee7a34b15c1e07402df..4742db102ecbe51a922d512f5c568caa2c2f35c9 100644 --- a/isegm/model/modeling/ocr.py +++ b/isegm/model/modeling/ocr.py @@ -1,14 +1,14 @@ import torch -import torch.nn as nn import torch._utils +import torch.nn as nn import torch.nn.functional as F class SpatialGather_Module(nn.Module): """ - Aggregate the context features according to the initial - predicted probability distribution. - Employ the soft-weighted method to aggregate the context. + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. """ def __init__(self, cls_num=0, scale=1): @@ -22,8 +22,9 @@ class SpatialGather_Module(nn.Module): feats = feats.view(batch_size, feats.size(1), -1) feats = feats.permute(0, 2, 1) # batch x hw x c probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw - ocr_context = torch.matmul(probs, feats) \ - .permute(0, 2, 1).unsqueeze(3) # batch x k x c + ocr_context = ( + torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3) + ) # batch x k x c return ocr_context @@ -33,23 +34,26 @@ class SpatialOCR_Module(nn.Module): We aggregate the global object representation to update the representation for each pixel. 
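    An illustrative usage sketch (shapes and channel sizes assumed, matching
    how HighResolutionNet wires this block above): with pixel features feats
    of shape N x C x H x W and gathered object features proxy_feats of shape
    N x C x K x 1,

        ocr = SpatialOCR_Module(in_channels=512, key_channels=256, out_channels=512)
        out = ocr(feats, proxy_feats)  # N x 512 x H x W

    each pixel attends over the K object regions, and the gathered context is
    concatenated with the original features before a 1x1 projection.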
""" - def __init__(self, - in_channels, - key_channels, - out_channels, - scale=1, - dropout=0.1, - norm_layer=nn.BatchNorm2d, - align_corners=True): + def __init__( + self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + norm_layer=nn.BatchNorm2d, + align_corners=True, + ): super(SpatialOCR_Module, self).__init__() - self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, - norm_layer, align_corners) + self.object_context_block = ObjectAttentionBlock2D( + in_channels, key_channels, scale, norm_layer, align_corners + ) _in_channels = 2 * in_channels self.conv_bn_dropout = nn.Sequential( nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), - nn.Dropout2d(dropout) + nn.Dropout2d(dropout), ) def forward(self, feats, proxy_feats): @@ -61,7 +65,7 @@ class SpatialOCR_Module(nn.Module): class ObjectAttentionBlock2D(nn.Module): - ''' + """ The basic implementation for object context block Input: N X C X H X W @@ -72,14 +76,16 @@ class ObjectAttentionBlock2D(nn.Module): bn_type : specify the bn type Return: N X C X H X W - ''' - - def __init__(self, - in_channels, - key_channels, - scale=1, - norm_layer=nn.BatchNorm2d, - align_corners=True): + """ + + def __init__( + self, + in_channels, + key_channels, + scale=1, + norm_layer=nn.BatchNorm2d, + align_corners=True, + ): super(ObjectAttentionBlock2D, self).__init__() self.scale = scale self.in_channels = in_channels @@ -88,30 +94,66 @@ class ObjectAttentionBlock2D(nn.Module): self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) self.f_pixel = nn.Sequential( - nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), - nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0, bias=False), - nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) ) self.f_object = nn.Sequential( - nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), - nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0, bias=False), - nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) ) self.f_down = nn.Sequential( - nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, - kernel_size=1, stride=1, padding=0, bias=False), - nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + 
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), ) self.f_up = nn.Sequential( - nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, - kernel_size=1, stride=1, padding=0, bias=False), - nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.in_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)), ) def forward(self, x, proxy): @@ -126,7 +168,7 @@ class ObjectAttentionBlock2D(nn.Module): value = value.permute(0, 2, 1) sim_map = torch.matmul(query, key) - sim_map = (self.key_channels ** -.5) * sim_map + sim_map = (self.key_channels**-0.5) * sim_map sim_map = F.softmax(sim_map, dim=-1) # add bg context ... @@ -135,7 +177,11 @@ class ObjectAttentionBlock2D(nn.Module): context = context.view(batch_size, self.key_channels, *x.size()[2:]) context = self.f_up(context) if self.scale > 1: - context = F.interpolate(input=context, size=(h, w), - mode='bilinear', align_corners=self.align_corners) + context = F.interpolate( + input=context, + size=(h, w), + mode="bilinear", + align_corners=self.align_corners, + ) return context diff --git a/isegm/model/modeling/resnet.py b/isegm/model/modeling/resnet.py index 65fe949cef0035ba691ee319b25a0132d8ad37fe..fd413f68e159393a35f0e8c106dba4a29c9fd170 100644 --- a/isegm/model/modeling/resnet.py +++ b/isegm/model/modeling/resnet.py @@ -1,21 +1,32 @@ import torch + from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s class ResNetBackbone(torch.nn.Module): - def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): + def __init__( + self, backbone="resnet50", pretrained_base=True, dilated=True, **kwargs + ): super(ResNetBackbone, self).__init__() - if backbone == 'resnet34': - pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) - elif backbone == 'resnet50': - pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) - elif backbone == 'resnet101': - pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) - elif backbone == 'resnet152': - pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + if backbone == "resnet34": + pretrained = resnet34_v1b( + pretrained=pretrained_base, dilated=dilated, **kwargs + ) + elif backbone == "resnet50": + pretrained = resnet50_v1s( + pretrained=pretrained_base, dilated=dilated, **kwargs + ) + elif backbone == "resnet101": + pretrained = resnet101_v1s( + pretrained=pretrained_base, dilated=dilated, **kwargs + ) + elif backbone == "resnet152": + pretrained = resnet152_v1s( + pretrained=pretrained_base, dilated=dilated, **kwargs + ) else: - raise RuntimeError(f'unknown backbone: {backbone}') + raise RuntimeError(f"unknown backbone: {backbone}") self.conv1 = pretrained.conv1 self.bn1 = pretrained.bn1 @@ -31,9 +42,12 @@ class ResNetBackbone(torch.nn.Module): x = self.bn1(x) x = self.relu(x) if additional_features is not None: - x = x + torch.nn.functional.pad(additional_features, - [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], - mode='constant', value=0) + x = x + torch.nn.functional.pad( + additional_features, + [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], + mode="constant", + value=0, + ) x = self.maxpool(x) c1 = self.layer1(x) c2 = self.layer2(c1) diff --git a/isegm/model/modeling/resnetv1b.py b/isegm/model/modeling/resnetv1b.py index 
4ad24cef5bde19f2627cfd3f755636f37cfb39ac..3f1ab9979db184c258fe23a733c92fe33acda4c2 100644 --- a/isegm/model/modeling/resnetv1b.py +++ b/isegm/model/modeling/resnetv1b.py @@ -1,19 +1,42 @@ import torch import torch.nn as nn -GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' + +GLUON_RESNET_TORCH_HUB = "rwightman/pytorch-pretrained-gluonresnet" class BasicBlockV1b(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, - previous_dilation=1, norm_layer=nn.BatchNorm2d): + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + previous_dilation=1, + norm_layer=nn.BatchNorm2d, + ): super(BasicBlockV1b, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, bias=False) + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False, + ) self.bn1 = norm_layer(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, - padding=previous_dilation, dilation=previous_dilation, bias=False) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=previous_dilation, + dilation=previous_dilation, + bias=False, + ) self.bn2 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) @@ -42,17 +65,34 @@ class BasicBlockV1b(nn.Module): class BottleneckV1b(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, - previous_dilation=1, norm_layer=nn.BatchNorm2d): + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + previous_dilation=1, + norm_layer=nn.BatchNorm2d, + ): super(BottleneckV1b, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = norm_layer(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, bias=False) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False, + ) self.bn2 = norm_layer(planes) - self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) @@ -83,7 +123,7 @@ class BottleneckV1b(nn.Module): class ResNetV1b(nn.Module): - """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. + """Pre-trained ResNetV1b Model, which produces stride-8 feature maps at conv5. Parameters ---------- @@ -111,86 +151,198 @@ class ResNetV1b(nn.Module): - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
""" - def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, - avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): - self.inplanes = stem_width*2 if deep_stem else 64 + + def __init__( + self, + block, + layers, + classes=1000, + dilated=True, + deep_stem=False, + stem_width=32, + avg_down=False, + final_drop=0.0, + norm_layer=nn.BatchNorm2d, + ): + self.inplanes = stem_width * 2 if deep_stem else 64 super(ResNetV1b, self).__init__() if not deep_stem: - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False + ) else: self.conv1 = nn.Sequential( - nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + nn.Conv2d( + 3, stem_width, kernel_size=3, stride=2, padding=1, bias=False + ), norm_layer(stem_width), nn.ReLU(True), - nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + stem_width, + stem_width, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), norm_layer(stem_width), nn.ReLU(True), - nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) + nn.Conv2d( + stem_width, + 2 * stem_width, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(True) self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, - norm_layer=norm_layer) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, - norm_layer=norm_layer) + self.layer1 = self._make_layer( + block, 64, layers[0], avg_down=avg_down, norm_layer=norm_layer + ) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, avg_down=avg_down, norm_layer=norm_layer + ) if dilated: - self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, - avg_down=avg_down, norm_layer=norm_layer) - self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, - avg_down=avg_down, norm_layer=norm_layer) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=1, + dilation=2, + avg_down=avg_down, + norm_layer=norm_layer, + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=1, + dilation=4, + avg_down=avg_down, + norm_layer=norm_layer, + ) else: - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - avg_down=avg_down, norm_layer=norm_layer) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - avg_down=avg_down, norm_layer=norm_layer) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + avg_down=avg_down, + norm_layer=norm_layer, + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + avg_down=avg_down, + norm_layer=norm_layer, + ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.drop = None if final_drop > 0.0: self.drop = nn.Dropout(final_drop) self.fc = nn.Linear(512 * block.expansion, classes) - def _make_layer(self, block, planes, blocks, stride=1, dilation=1, - avg_down=False, norm_layer=nn.BatchNorm2d): + def _make_layer( + self, + block, + planes, + blocks, + stride=1, + dilation=1, + avg_down=False, + norm_layer=nn.BatchNorm2d, + ): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = [] if avg_down: if dilation == 1: downsample.append( - nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + nn.AvgPool2d( + 
kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False, + ) ) else: downsample.append( - nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) + nn.AvgPool2d( + kernel_size=1, + stride=1, + ceil_mode=True, + count_include_pad=False, + ) ) - downsample.extend([ - nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, - kernel_size=1, stride=1, bias=False), - norm_layer(planes * block.expansion) - ]) + downsample.extend( + [ + nn.Conv2d( + self.inplanes, + out_channels=planes * block.expansion, + kernel_size=1, + stride=1, + bias=False, + ), + norm_layer(planes * block.expansion), + ] + ) downsample = nn.Sequential(*downsample) else: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - norm_layer(planes * block.expansion) + nn.Conv2d( + self.inplanes, + out_channels=planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + norm_layer(planes * block.expansion), ) layers = [] if dilation in (1, 2): - layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, - previous_dilation=dilation, norm_layer=norm_layer)) + layers.append( + block( + self.inplanes, + planes, + stride, + dilation=1, + downsample=downsample, + previous_dilation=dilation, + norm_layer=norm_layer, + ) + ) elif dilation == 4: - layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, - previous_dilation=dilation, norm_layer=norm_layer)) + layers.append( + block( + self.inplanes, + planes, + stride, + dilation=2, + downsample=downsample, + previous_dilation=dilation, + norm_layer=norm_layer, + ) + ) else: raise RuntimeError("=> unknown dilation size: {}".format(dilation)) self.inplanes = planes * block.expansion for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, dilation=dilation, - previous_dilation=dilation, norm_layer=norm_layer)) + layers.append( + block( + self.inplanes, + planes, + dilation=dilation, + previous_dilation=dilation, + norm_layer=norm_layer, + ) + ) return nn.Sequential(*layers) @@ -229,8 +381,10 @@ def resnet34_v1b(pretrained=False, **kwargs): if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( - torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), - model_dict.keys() + torch.hub.load( + GLUON_RESNET_TORCH_HUB, "gluon_resnet34_v1b", pretrained=True + ).state_dict(), + model_dict.keys(), ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) @@ -238,12 +392,16 @@ def resnet34_v1b(pretrained=False, **kwargs): def resnet50_v1s(pretrained=False, **kwargs): - model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) + model = ResNetV1b( + BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs + ) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( - torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), - model_dict.keys() + torch.hub.load( + GLUON_RESNET_TORCH_HUB, "gluon_resnet50_v1s", pretrained=True + ).state_dict(), + model_dict.keys(), ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) @@ -251,12 +409,16 @@ def resnet50_v1s(pretrained=False, **kwargs): def resnet101_v1s(pretrained=False, **kwargs): - model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) + model = ResNetV1b( + 
BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs + ) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( - torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), - model_dict.keys() + torch.hub.load( + GLUON_RESNET_TORCH_HUB, "gluon_resnet101_v1s", pretrained=True + ).state_dict(), + model_dict.keys(), ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) @@ -264,12 +426,16 @@ def resnet101_v1s(pretrained=False, **kwargs): def resnet152_v1s(pretrained=False, **kwargs): - model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) + model = ResNetV1b( + BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs + ) if pretrained: model_dict = model.state_dict() filtered_orig_dict = _safe_state_dict_filtering( - torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), - model_dict.keys() + torch.hub.load( + GLUON_RESNET_TORCH_HUB, "gluon_resnet152_v1s", pretrained=True + ).state_dict(), + model_dict.keys(), ) model_dict.update(filtered_orig_dict) model.load_state_dict(model_dict) diff --git a/isegm/model/modifiers.py b/isegm/model/modifiers.py index 046221838069e90ae201b9169db159cc69c13244..c57eb92d5dd75d91e26fe9d5ff3ae82bfad4cacf 100644 --- a/isegm/model/modifiers.py +++ b/isegm/model/modifiers.py @@ -1,11 +1,9 @@ - - class LRMult(object): - def __init__(self, lr_mult=1.): + def __init__(self, lr_mult=1.0): self.lr_mult = lr_mult def __call__(self, m): - if getattr(m, 'weight', None) is not None: + if getattr(m, "weight", None) is not None: m.weight.lr_mult = self.lr_mult - if getattr(m, 'bias', None) is not None: + if getattr(m, "bias", None) is not None: m.bias.lr_mult = self.lr_mult diff --git a/isegm/model/ops.py b/isegm/model/ops.py index 9be9c73cbef7b83645af93e1fa7338fa6513a92b..bece8dad2922399710cd0dd98fdf2adc61623112 100644 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -1,14 +1,15 @@ +import numpy as np import torch from torch import nn as nn -import numpy as np + import isegm.model.initializer as initializer def select_activation_function(activation): if isinstance(activation, str): - if activation.lower() == 'relu': + if activation.lower() == "relu": return nn.ReLU - elif activation.lower() == 'softplus': + elif activation.lower() == "softplus": return nn.Softplus else: raise ValueError(f"Unknown activation type {activation}") @@ -24,14 +25,18 @@ class BilinearConvTranspose2d(nn.ConvTranspose2d): self.scale = scale super().__init__( - in_channels, out_channels, + in_channels, + out_channels, kernel_size=kernel_size, stride=scale, padding=1, groups=groups, - bias=False) + bias=False, + ) - self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) + self.apply( + initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups) + ) class DistMaps(nn.Module): @@ -43,29 +48,47 @@ class DistMaps(nn.Module): self.use_disks = use_disks if self.cpu_mode: from isegm.utils.cython import get_dist_maps + self._get_dist_maps = get_dist_maps def get_coord_features(self, points, batchsize, rows, cols): if self.cpu_mode: coords = [] for i in range(batchsize): - norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius - coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, - norm_delimeter)) - coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + norm_delimeter = ( + 1.0 if 
self.use_disks else self.spatial_scale * self.norm_radius + ) + coords.append( + self._get_dist_maps( + points[i].cpu().float().numpy(), rows, cols, norm_delimeter + ) + ) + coords = ( + torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + ) else: num_points = points.shape[1] // 2 points = points.view(-1, points.size(2)) points, points_order = torch.split(points, [2, 1], dim=1) invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 - row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) - col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + row_array = torch.arange( + start=0, end=rows, step=1, dtype=torch.float32, device=points.device + ) + col_array = torch.arange( + start=0, end=cols, step=1, dtype=torch.float32, device=points.device + ) coord_rows, coord_cols = torch.meshgrid(row_array, col_array) - coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) - - add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) + coords = ( + torch.stack((coord_rows, coord_cols), dim=0) + .unsqueeze(0) + .repeat(points.size(0), 1, 1, 1) + ) + + add_xy = (points * self.spatial_scale).view( + points.size(0), points.size(1), 1, 1 + ) coords.add_(-add_xy) if not self.use_disks: coords.div_(self.norm_radius * self.spatial_scale) diff --git a/isegm/utils/cython/__init__.py b/isegm/utils/cython/__init__.py index eb66bdbba883b9477bbc1a52d8355131d32a04cb..b04e42c5cf6afd9523f925d381c797cf9dd036b4 100644 --- a/isegm/utils/cython/__init__.py +++ b/isegm/utils/cython/__init__.py @@ -1,2 +1,2 @@ # noinspection PyUnresolvedReferences -from .dist_maps import get_dist_maps \ No newline at end of file +from .dist_maps import get_dist_maps diff --git a/isegm/utils/cython/_get_dist_maps.pyx b/isegm/utils/cython/_get_dist_maps.pyx index 779a7f02ad7c2ba25e68302c6fc6683cd4ab54f7..02e6b3d95dbcd7a834c0675fab54794f884ac4da 100644 --- a/isegm/utils/cython/_get_dist_maps.pyx +++ b/isegm/utils/cython/_get_dist_maps.pyx @@ -1,7 +1,8 @@ import numpy as np + cimport cython cimport numpy as np -from libc.stdlib cimport malloc, free +from libc.stdlib cimport free, malloc ctypedef struct qnode: int row diff --git a/isegm/utils/cython/dist_maps.py b/isegm/utils/cython/dist_maps.py index 8ffa1e3f25231cd7c48b66ef8ef5167235c3ea4e..bd7fc6719d10203a87c28caa78f756fcf3aab2ac 100644 --- a/isegm/utils/cython/dist_maps.py +++ b/isegm/utils/cython/dist_maps.py @@ -1,3 +1,5 @@ -import pyximport; pyximport.install(pyximport=True, language_level=3) +import pyximport + +pyximport.install(pyximport=True, language_level=3) # noinspection PyUnresolvedReferences -from ._get_dist_maps import get_dist_maps \ No newline at end of file +from ._get_dist_maps import get_dist_maps diff --git a/isegm/utils/distributed.py b/isegm/utils/distributed.py index a1e48f50500ee7440be035b17107573e86bb5d24..639ede61a25597045536b908c224afb7ae6d144b 100644 --- a/isegm/utils/distributed.py +++ b/isegm/utils/distributed.py @@ -10,7 +10,11 @@ def get_rank(): def synchronize(): - if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1: + if ( + not dist.is_available() + or not dist.is_initialized() + or dist.get_world_size() == 1 + ): return dist.barrier() @@ -58,10 +62,15 @@ def get_sampler(dataset, shuffle, distributed): def get_dp_wrapper(distributed): - class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel): + class 
DPWrapper( + torch.nn.parallel.DistributedDataParallel + if distributed + else torch.nn.DataParallel + ): def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self.module, name) + return DPWrapper diff --git a/isegm/utils/exp.py b/isegm/utils/exp.py index 1ff63ccb3524f06d76475f6b7a77058431b2fe14..ab267e08e3b91bd0a0c1c8b2a50c7727cab29139 100644 --- a/isegm/utils/exp.py +++ b/isegm/utils/exp.py @@ -1,16 +1,16 @@ import os -import sys -import shutil import pprint -from pathlib import Path +import shutil +import sys from datetime import datetime +from pathlib import Path -import yaml import torch +import yaml from easydict import EasyDict as edict -from .log import logger, add_logging -from .distributed import synchronize, get_world_size +from .distributed import get_world_size, synchronize +from .log import add_logging, logger def init_experiment(args, model_name): @@ -18,7 +18,9 @@ def init_experiment(args, model_name): ftree = get_model_family_tree(model_path, model_name=model_name) if ftree is None: - print('Models can only be located in the "models" directory in the root of the repository') + print( + 'Models can only be located in the "models" directory in the root of the repository' + ) sys.exit(1) cfg = load_config(model_path) @@ -27,37 +29,40 @@ def init_experiment(args, model_name): cfg.distributed = args.distributed cfg.local_rank = args.local_rank if cfg.distributed: - torch.distributed.init_process_group(backend='nccl', init_method='env://') + torch.distributed.init_process_group(backend="nccl", init_method="env://") if args.workers > 0: - torch.multiprocessing.set_start_method('forkserver', force=True) + torch.multiprocessing.set_start_method("forkserver", force=True) experiments_path = Path(cfg.EXPS_PATH) - exp_parent_path = experiments_path / '/'.join(ftree) + exp_parent_path = experiments_path / "/".join(ftree) exp_parent_path.mkdir(parents=True, exist_ok=True) if cfg.resume_exp: exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp) else: last_exp_indx = find_last_exp_indx(exp_parent_path) - exp_name = f'{last_exp_indx:03d}' + exp_name = f"{last_exp_indx:03d}" if cfg.exp_name: - exp_name += '_' + cfg.exp_name + exp_name += "_" + cfg.exp_name exp_path = exp_parent_path / exp_name synchronize() if cfg.local_rank == 0: exp_path.mkdir(parents=True) cfg.EXP_PATH = exp_path - cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' - cfg.VIS_PATH = exp_path / 'vis' - cfg.LOGS_PATH = exp_path / 'logs' + cfg.CHECKPOINTS_PATH = exp_path / "checkpoints" + cfg.VIS_PATH = exp_path / "vis" + cfg.LOGS_PATH = exp_path / "logs" if cfg.local_rank == 0: cfg.LOGS_PATH.mkdir(exist_ok=True) cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True) cfg.VIS_PATH.mkdir(exist_ok=True) - dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py')) + dst_script_path = exp_path / ( + model_path.stem + + datetime.strftime(datetime.today(), "_%Y-%m-%d-%H-%M-%S.py") + ) if args.temp_model_path: shutil.copy(args.temp_model_path, dst_script_path) os.remove(args.temp_model_path) @@ -66,40 +71,40 @@ def init_experiment(args, model_name): synchronize() - if cfg.gpus != '': - gpu_ids = [int(id) for id in cfg.gpus.split(',')] + if cfg.gpus != "": + gpu_ids = [int(id) for id in cfg.gpus.split(",")] else: gpu_ids = list(range(max(cfg.ngpus, get_world_size()))) - cfg.gpus = ','.join([str(id) for id in gpu_ids]) + cfg.gpus = ",".join([str(id) for id in gpu_ids]) cfg.gpu_ids = gpu_ids cfg.ngpus = len(gpu_ids) cfg.multi_gpu = cfg.ngpus > 1 
if cfg.distributed: - cfg.device = torch.device('cuda') + cfg.device = torch.device("cuda") cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]] torch.cuda.set_device(cfg.gpu_ids[0]) else: if cfg.multi_gpu: - os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus + os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus ngpus = torch.cuda.device_count() assert ngpus == cfg.ngpus - cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}') + cfg.device = torch.device(f"cuda:{cfg.gpu_ids[0]}") if cfg.local_rank == 0: - add_logging(cfg.LOGS_PATH, prefix='train_') - logger.info(f'Number of GPUs: {cfg.ngpus}') + add_logging(cfg.LOGS_PATH, prefix="train_") + logger.info(f"Number of GPUs: {cfg.ngpus}") if cfg.distributed: - logger.info(f'Multi-Process Multi-GPU Distributed Training') + logger.info("Multi-Process Multi-GPU Distributed Training") - logger.info('Run experiment with config:') + logger.info("Run experiment with config:") logger.info(pprint.pformat(cfg, indent=4)) return cfg -def get_model_family_tree(model_path, terminate_name='models', model_name=None): +def get_model_family_tree(model_path, terminate_name="models", model_name=None): if model_name is None: model_name = model_path.stem family_tree = [model_name] @@ -127,12 +132,14 @@ def find_last_exp_indx(exp_parent_path): def find_resume_exp(exp_parent_path, exp_pattern): - candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*')) + candidates = sorted(exp_parent_path.glob(f"{exp_pattern}*")) if len(candidates) == 0: - print(f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"') + print( + f'No experiments could be found that satisfy the pattern = "{exp_pattern}*"' + ) sys.exit(1) elif len(candidates) > 1: - print('More than one experiment found:') + print("More than one experiment found:") for x in candidates: print(x) sys.exit(1) @@ -152,7 +159,7 @@ def update_config(cfg, args): def load_config(model_path): model_name = model_path.stem - config_path = model_path.parent / (model_name + '.yml') + config_path = model_path.parent / (model_name + ".yml") if config_path.exists(): cfg = load_config_file(config_path) @@ -162,7 +169,7 @@ def load_config(model_path): cwd = Path.cwd() config_parent = config_path.parent.absolute() while len(config_parent.parents) > 0: - config_path = config_parent / 'config.yml' + config_path = config_parent / "config.yml" if config_path.exists(): local_config = load_config_file(config_path, model_name=model_name) @@ -176,12 +183,12 @@ def load_config(model_path): def load_config_file(config_path, model_name=None, return_edict=False): - with open(config_path, 'r') as f: + with open(config_path, "r") as f: cfg = yaml.safe_load(f) - if 'SUBCONFIGS' in cfg: - if model_name is not None and model_name in cfg['SUBCONFIGS']: - cfg.update(cfg['SUBCONFIGS'][model_name]) - del cfg['SUBCONFIGS'] + if "SUBCONFIGS" in cfg: + if model_name is not None and model_name in cfg["SUBCONFIGS"]: + cfg.update(cfg["SUBCONFIGS"][model_name]) + del cfg["SUBCONFIGS"] return edict(cfg) if return_edict else cfg diff --git a/isegm/utils/exp_imports/default.py b/isegm/utils/exp_imports/default.py index e78e21c85013af8ccd4c23d860c792bc40a2d822..2deda90b78724ba47795544ceb56c4752262050b 100644 --- a/isegm/utils/exp_imports/default.py +++ b/isegm/utils/exp_imports/default.py @@ -1,16 +1,16 @@ -import torch from functools import partial -from easydict import EasyDict as edict + +import torch from albumentations import * +from easydict import EasyDict as edict from isegm.data.datasets import * -from isegm.model.losses import * +from
isegm.data.points_sampler import MultiPointSampler from isegm.data.transforms import * from isegm.engine.trainer import ISTrainer -from isegm.model.metrics import AdaptiveIoU -from isegm.data.points_sampler import MultiPointSampler -from isegm.utils.log import logger from isegm.model import initializer - +from isegm.model.is_deeplab_model import DeeplabModel from isegm.model.is_hrnet_model import HRNetModel -from isegm.model.is_deeplab_model import DeeplabModel \ No newline at end of file +from isegm.model.losses import * +from isegm.model.metrics import AdaptiveIoU +from isegm.utils.log import logger diff --git a/isegm/utils/log.py b/isegm/utils/log.py index 1f9f8bdb4bdd74d72514db8cf9cecef51001a588..53ae2aaa1a1b30ef6b9e078ede5ae7cec722b142 100644 --- a/isegm/utils/log.py +++ b/isegm/utils/log.py @@ -1,13 +1,13 @@ import io -import time import logging +import time from datetime import datetime import numpy as np from torch.utils.tensorboard import SummaryWriter -LOGGER_NAME = 'root' -LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' +LOGGER_NAME = "root" +LOGGER_DATEFMT = "%Y-%m-%d %H:%M:%S" handler = logging.StreamHandler() @@ -17,12 +17,15 @@ logger.addHandler(handler) def add_logging(logs_path, prefix): - log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' + log_name = ( + prefix + datetime.strftime(datetime.today(), "%Y-%m-%d_%H-%M-%S") + ".log" + ) stdout_log_path = logs_path / log_name fh = logging.FileHandler(str(stdout_log_path)) - formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: %(message)s', - datefmt=LOGGER_DATEFMT) + formatter = logging.Formatter( + fmt="(%(levelname)s) %(asctime)s: %(message)s", datefmt=LOGGER_DATEFMT + ) fh.setFormatter(formatter) logger.addHandler(fh) @@ -30,7 +33,7 @@ def add_logging(logs_path, prefix): class TqdmToLogger(io.StringIO): logger = None level = None - buf = '' + buf = "" def __init__(self, logger, level=None, mininterval=5): super(TqdmToLogger, self).__init__() @@ -40,8 +43,8 @@ class TqdmToLogger(io.StringIO): self.last_time = 0 def write(self, buf): - self.buf = buf.strip('\r\n\t ') - + self.buf = buf.strip("\r\n\t ") + def flush(self): if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: self.logger.log(self.level, self.buf) @@ -64,8 +67,7 @@ class SummaryWriterAvg(SummaryWriter): avg_scalar.add(value) if avg_scalar.is_full(): - super().add_scalar(tag, avg_scalar.value, - global_step=global_step) + super().add_scalar(tag, avg_scalar.value, global_step=global_step) avg_scalar.reset() diff --git a/isegm/utils/misc.py b/isegm/utils/misc.py index 688c11e182f1aaea0f23d8e58811f713cf816da9..b342a18059b5f7694a190991bb8d1917fe54aa13 100644 --- a/isegm/utils/misc.py +++ b/isegm/utils/misc.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch from .log import logger @@ -12,25 +12,28 @@ def get_dims_with_exclusion(dim, exclude=None): return dims -def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False): +def save_checkpoint( + net, checkpoints_path, epoch=None, prefix="", verbose=True, multi_gpu=False +): if epoch is None: - checkpoint_name = 'last_checkpoint.pth' + checkpoint_name = "last_checkpoint.pth" else: - checkpoint_name = f'{epoch:03d}.pth' + checkpoint_name = f"{epoch:03d}.pth" if prefix: - checkpoint_name = f'{prefix}_{checkpoint_name}' + checkpoint_name = f"{prefix}_{checkpoint_name}" if not checkpoints_path.exists(): checkpoints_path.mkdir(parents=True) checkpoint_path = checkpoints_path / checkpoint_name if verbose: - 
logger.info(f'Save checkpoint to {str(checkpoint_path)}') + logger.info(f"Save checkpoint to {str(checkpoint_path)}") net = net.module if multi_gpu else net - torch.save({'state_dict': net.state_dict(), 'config': net._config}, str(checkpoint_path)) + torch.save( + {"state_dict": net.state_dict(), "config": net._config}, str(checkpoint_path) + ) def get_bbox_from_mask(mask): @@ -61,8 +64,12 @@ def expand_bbox(bbox, expand_ratio, min_crop_size=None): def clamp_bbox(bbox, rmin, rmax, cmin, cmax): - return (max(rmin, bbox[0]), min(rmax, bbox[1]), - max(cmin, bbox[2]), min(cmax, bbox[3])) + return ( + max(rmin, bbox[0]), + min(rmax, bbox[1]), + max(cmin, bbox[2]), + min(cmax, bbox[3]), + ) def get_bbox_iou(b1, b2): diff --git a/isegm/utils/serialization.py b/isegm/utils/serialization.py index c73935b9aa7e7f2f5a11c685c4192321da78c5f3..612862cc54d573c685f9355128ac3dab0fbe9db7 100644 --- a/isegm/utils/serialization.py +++ b/isegm/utils/serialization.py @@ -1,6 +1,7 @@ -from functools import wraps -from copy import deepcopy import inspect +from copy import deepcopy +from functools import wraps + import torch.nn as nn @@ -13,10 +14,7 @@ def serialize(init): for pname, value in zip(parameters[1:], args): params[pname] = value - config = { - 'class': get_classname(self.__class__), - 'params': dict() - } + config = {"class": get_classname(self.__class__), "params": dict()} specified_params = set(params.keys()) for pname, param in get_default_params(self.__class__).items(): @@ -24,38 +22,38 @@ def serialize(init): params[pname] = param.default for name, value in list(params.items()): - param_type = 'builtin' + param_type = "builtin" if inspect.isclass(value): - param_type = 'class' + param_type = "class" value = get_classname(value) - config['params'][name] = { - 'type': param_type, - 'value': value, - 'specified': name in specified_params + config["params"][name] = { + "type": param_type, + "value": value, + "specified": name in specified_params, } - setattr(self, '_config', config) + setattr(self, "_config", config) init(self, *args, **kwargs) return new_init def load_model(config, **kwargs): - model_class = get_class_from_str(config['class']) + model_class = get_class_from_str(config["class"]) model_default_params = get_default_params(model_class) model_args = dict() - for pname, param in config['params'].items(): - value = param['value'] - if param['type'] == 'class': + for pname, param in config["params"].items(): + value = param["value"] + if param["type"] == "class": value = get_class_from_str(value) - if pname not in model_default_params and not param['specified']: + if pname not in model_default_params and not param["specified"]: continue assert pname in model_default_params - if not param['specified'] and model_default_params[pname].default == value: + if not param["specified"] and model_default_params[pname].default == value: continue model_args[pname] = value @@ -66,14 +64,14 @@ def load_model(config, **kwargs): def get_config_repr(config): config_str = f'Model: {config["class"]}\n' - for pname, param in config['params'].items(): + for pname, param in config["params"].items(): value = param["value"] - if param['type'] == 'class': - value = value.split('.')[-1] - param_str = f'{pname:<22} = {str(value):<12}' - if not param['specified']: - param_str += ' (default)' - config_str += param_str + '\n' + if param["type"] == "class": + value = value.split(".")[-1] + param_str = f"{pname:<22} = {str(value):<12}" + if not param["specified"]: + param_str += " (default)" + config_str += param_str + "\n" return config_str @@ -100,8 +98,8 @@ def get_classname(cls): def get_class_from_str(class_str): - components = class_str.split('.') - mod = __import__('.'.join(components[:-1])) + components = class_str.split(".") + mod = __import__(".".join(components[:-1])) for comp in components[1:]: mod = getattr(mod, comp) return mod diff --git a/isegm/utils/vis.py b/isegm/utils/vis.py index 9790a4ca76e7768fd95980cd2d3a492800bfdd1e..ae2d57c48198b2d4d304aecb6813cbd77dc74503 100644 --- a/isegm/utils/vis.py +++ b/isegm/utils/vis.py @@ -4,8 +4,9 @@ import cv2 import numpy as np -def visualize_instances(imask, bg_color=255, - boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): +def visualize_instances( + imask, bg_color=255, boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8 +): num_objects = imask.max() + 1 palette = get_palette(num_objects) if bg_color is not None: @@ -31,9 +32,9 @@ def get_palette(num_cls): i = 0 while lab > 0: - palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) - palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) - palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) + palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i) + palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i) + palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i) i = i + 1 lab >>= 3 @@ -93,7 +94,9 @@ def blend_mask(image, mask, alpha=0.6): def get_boundaries(instances_masks, boundaries_width=1): - boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool) + boundaries = np.zeros( + (instances_masks.shape[0], instances_masks.shape[1]), dtype=bool + ) for obj_id in np.unique(instances_masks.flatten()): if obj_id == 0: @@ -101,15 +104,24 @@ def get_boundaries(instances_masks, boundaries_width=1): obj_mask = instances_masks == obj_id kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) - inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool) + inner_mask = cv2.erode( + obj_mask.astype(np.uint8), kernel, iterations=boundaries_width + ).astype(bool) obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) boundaries = np.logical_or(boundaries, obj_boundary) return boundaries - - -def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), - neg_color=(255, 0, 0), radius=4): + + +def draw_with_blend_and_clicks( + img, + mask=None, + alpha=0.6, + clicks_list=None, + pos_color=(0, 255, 0), + neg_color=(255, 0, 0), + radius=4, +): result = img.copy() if mask is not None: @@ -117,9 +129,11 @@ def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_ rgb_mask = palette[mask.astype(np.uint8)] mask_region = (mask > 0).astype(np.uint8) - result = result * (1 - mask_region[:, :, np.newaxis]) + \ - (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ - alpha * rgb_mask + result = ( + result * (1 - mask_region[:, :, np.newaxis]) + + (1 - alpha) * mask_region[:, :, np.newaxis] * result + + alpha * rgb_mask + ) result = result.astype(np.uint8) # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) @@ -132,4 +146,3 @@ def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_ result = draw_points(result, neg_points, neg_color, radius=radius) return result -