diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a699bc5b3c2e987102ca93e0ee28d601e0a93d02 --- /dev/null +++ b/app.py @@ -0,0 +1,7 @@ +import gradio as gr + +def greet(name): + return "Hello " + name + "!!" + +iface = gr.Interface(fn=greet, inputs="text", outputs="text") +iface.launch() \ No newline at end of file diff --git a/main_model.pt b/main_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..b2748ddd085f0aa9bd3356170180eda79f0392d8 --- /dev/null +++ b/main_model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6817c7bdd29a33ed9379f72d082390bb4052fb307744671834ff6c011cefd051 +size 485832489 diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..30a751e9476dfef230bbcea6e04963e8157ec561 --- /dev/null +++ b/predict.py @@ -0,0 +1,1256 @@ +import torch +from torch.nn import ( + Module, + Conv2d, + BatchNorm2d, + Identity, + UpsamplingBilinear2d, + Mish, + ReLU, + Sequential, +) +from torch.nn.functional import interpolate, grid_sample, pad +import numpy as np +from copy import deepcopy +import os, argparse, math +import tifffile as tif +from typing import Tuple, List, Mapping + +from monai.utils import ( + BlendMode, + PytorchPadMode, + convert_data_type, + ensure_tuple, + fall_back_tuple, + look_up_option, + convert_to_dst_type, +) +from monai.utils.misc import ensure_tuple_size, ensure_tuple_rep, issequenceiterable +from monai.networks.layers.convutils import gaussian_1d +from monai.networks.layers.simplelayers import separable_filtering + +from segmentation_models_pytorch import MAnet + +from skimage.io import imread as io_imread +from skimage.util.dtype import dtype_range +from skimage._shared.utils import _supported_float_type +from scipy.ndimage import find_objects, binary_fill_holes + + +########################### Data Loading Modules ######################################################### +DTYPE_RANGE = dtype_range.copy() +DTYPE_RANGE.update((d.__name__, limits) for d, limits in dtype_range.items()) +DTYPE_RANGE.update( + { + "uint10": (0, 2 ** 10 - 1), + "uint12": (0, 2 ** 12 - 1), + "uint14": (0, 2 ** 14 - 1), + "bool": dtype_range[bool], + "float": dtype_range[np.float64], + } +) + + +def _output_dtype(dtype_or_range, image_dtype): + if type(dtype_or_range) in [list, tuple, np.ndarray]: + # pair of values: always return float. + return _supported_float_type(image_dtype) + if type(dtype_or_range) == type: + # already a type: return it + return dtype_or_range + if dtype_or_range in DTYPE_RANGE: + # string key in DTYPE_RANGE dictionary + try: + # if it's a canonical numpy dtype, convert + return np.dtype(dtype_or_range).type + except TypeError: # uint10, uint12, uint14 + # otherwise, return uint16 + return np.uint16 + else: + raise ValueError( + "Incorrect value for out_range, should be a valid image data " + f"type or a pair of values, got {dtype_or_range}." 
+ ) + + +def intensity_range(image, range_values="image", clip_negative=False): + if range_values == "dtype": + range_values = image.dtype.type + + if range_values == "image": + i_min = np.min(image) + i_max = np.max(image) + elif range_values in DTYPE_RANGE: + i_min, i_max = DTYPE_RANGE[range_values] + if clip_negative: + i_min = 0 + else: + i_min, i_max = range_values + return i_min, i_max + + +def rescale_intensity(image, in_range="image", out_range="dtype"): + out_dtype = _output_dtype(out_range, image.dtype) + + imin, imax = map(float, intensity_range(image, in_range)) + omin, omax = map( + float, intensity_range(image, out_range, clip_negative=(imin >= 0)) + ) + image = np.clip(image, imin, imax) + + if imin != imax: + image = (image - imin) / (imax - imin) + return np.asarray(image * (omax - omin) + omin, dtype=out_dtype) + else: + return np.clip(image, omin, omax).astype(out_dtype) + + +def _normalize(img): + non_zero_vals = img[np.nonzero(img)] + percentiles = np.percentile(non_zero_vals, [0, 99.5]) + img_norm = rescale_intensity( + img, in_range=(percentiles[0], percentiles[1]), out_range="uint8" + ) + + return img_norm.astype(np.uint8) + + +def pred_transforms(filename): + # LoadImage + img = ( + tif.imread(filename) + if filename.endswith(".tif") or filename.endswith(".tiff") + else io_imread(filename) + ) + + if len(img.shape) == 2: + img = np.repeat(np.expand_dims(img, axis=-1), 3, axis=-1) + elif len(img.shape) == 3 and img.shape[-1] > 3: + img = img[:, :, :3] + + img = img.astype(np.float32) + img = _normalize(img) + img = np.moveaxis(img, -1, 0) + img = (img - img.min()) / (img.max() - img.min()) + + return torch.FloatTensor(img).unsqueeze(0) + + +################################################################################ + +########################### MODEL Architecture ################################# +class SegformerGH(MAnet): + def __init__( + self, + encoder_name: str = "mit_b5", + encoder_weights="imagenet", + decoder_channels=(256, 128, 64, 32, 32), + decoder_pab_channels=256, + in_channels: int = 3, + classes: int = 3, + ): + super(SegformerGH, self).__init__( + encoder_name=encoder_name, + encoder_weights=encoder_weights, + decoder_channels=decoder_channels, + decoder_pab_channels=decoder_pab_channels, + in_channels=in_channels, + classes=classes, + ) + + convert_relu_to_mish(self.encoder) + convert_relu_to_mish(self.decoder) + + self.cellprob_head = DeepSegmantationHead( + in_channels=decoder_channels[-1], out_channels=1, kernel_size=3, + ) + self.gradflow_head = DeepSegmantationHead( + in_channels=decoder_channels[-1], out_channels=2, kernel_size=3, + ) + + def forward(self, x): + """Sequentially pass `x` trough model`s encoder, decoder and heads""" + self.check_input_shape(x) + + features = self.encoder(x) + decoder_output = self.decoder(*features) + + gradflow_mask = self.gradflow_head(decoder_output) + cellprob_mask = self.cellprob_head(decoder_output) + + masks = torch.cat([gradflow_mask, cellprob_mask], dim=1) + + return masks + + +class DeepSegmantationHead(Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d_1 = Conv2d( + in_channels, + in_channels // 2, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + bn = BatchNorm2d(in_channels // 2) + conv2d_2 = Conv2d( + in_channels // 2, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + mish = Mish(inplace=True) + + upsampling = ( + UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else Identity() + ) + 
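+ # The head's output is left linear here: the single cell-probability channel is squashed with a sigmoid later in post_process(), and the two gradient-flow channels are used directly as flow fields.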
activation = Identity() + super().__init__(conv2d_1, mish, bn, conv2d_2, upsampling, activation) + + +def convert_relu_to_mish(model): + for child_name, child in model.named_children(): + if isinstance(child, ReLU): + setattr(model, child_name, Mish(inplace=True)) + else: + convert_relu_to_mish(child) + + +##################################################################################### + +########################### Sliding Window Inference ################################# +class GaussianFilter(Module): + def __init__( + self, spatial_dims, sigma, truncated=4.0, approx="erf", requires_grad=False, + ) -> None: + if issequenceiterable(sigma): + if len(sigma) != spatial_dims: # type: ignore + raise ValueError + else: + sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore + super().__init__() + self.sigma = [ + torch.nn.Parameter( + torch.as_tensor( + s, + dtype=torch.float, + device=s.device if isinstance(s, torch.Tensor) else None, + ), + requires_grad=requires_grad, + ) + for s in sigma # type: ignore + ] + self.truncated = truncated + self.approx = approx + for idx, param in enumerate(self.sigma): + self.register_parameter(f"kernel_sigma_{idx}", param) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _kernel = [ + gaussian_1d(s, truncated=self.truncated, approx=self.approx) + for s in self.sigma + ] + return separable_filtering(x=x, kernels=_kernel) + + +def compute_importance_map( + patch_size, mode=BlendMode.CONSTANT, sigma_scale=0.125, device="cpu" +): + mode = look_up_option(mode, BlendMode) + device = torch.device(device) + + center_coords = [i // 2 for i in patch_size] + sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size)) + sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)] + + importance_map = torch.zeros(patch_size, device=device) + importance_map[tuple(center_coords)] = 1 + pt_gaussian = GaussianFilter(len(patch_size), sigmas).to( + device=device, dtype=torch.float + ) + importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0)) + importance_map = importance_map.squeeze(0).squeeze(0) + importance_map = importance_map / torch.max(importance_map) + importance_map = importance_map.float() + + return importance_map + + +def first(iterable, default=None): + for i in iterable: + return i + + return default + + +def dense_patch_slices(image_size, patch_size, scan_interval): + num_spatial_dims = len(image_size) + patch_size = get_valid_patch_size(image_size, patch_size) + scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims) + + scan_num = [] + for i in range(num_spatial_dims): + if scan_interval[i] == 0: + scan_num.append(1) + else: + num = int(math.ceil(float(image_size[i]) / scan_interval[i])) + scan_dim = first( + d + for d in range(num) + if d * scan_interval[i] + patch_size[i] >= image_size[i] + ) + scan_num.append(scan_dim + 1 if scan_dim is not None else 1) + + starts = [] + for dim in range(num_spatial_dims): + dim_starts = [] + for idx in range(scan_num[dim]): + start_idx = idx * scan_interval[dim] + start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0) + dim_starts.append(start_idx) + starts.append(dim_starts) + out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T + return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] + + +def get_valid_patch_size(image_size, patch_size): + ndim = len(image_size) + patch_size_ = ensure_tuple_size(patch_size, ndim) + + # ensure patch size dimensions are not larger than image dimension, if a 
dimension is None or 0 use whole dimension + return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_)) + + +class Resize: + def __init__(self, spatial_size): + self.size_mode = "all" + self.spatial_size = spatial_size + + def __call__(self, img): + input_ndim = img.ndim - 1 # spatial ndim + output_ndim = len(ensure_tuple(self.spatial_size)) + + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + + spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + + if ( + tuple(img.shape[1:]) == spatial_size_ + ): # spatial shape is already the desired + return img + + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) + + resized = interpolate( + input=img_.unsqueeze(0), size=spatial_size_, mode="nearest", + ) + out, *_ = convert_to_dst_type(resized.squeeze(0), img) + return out + + +def sliding_window_inference( + inputs, + roi_size, + sw_batch_size, + predictor, + overlap, + mode=BlendMode.CONSTANT, + sigma_scale=0.125, + padding_mode=PytorchPadMode.CONSTANT, + cval=0.0, + sw_device=None, + device=None, + roi_weight_map=None, +): + compute_dtype = inputs.dtype + num_spatial_dims = len(inputs.shape) - 2 + batch_size, _, *image_size_ = inputs.shape + + roi_size = fall_back_tuple(roi_size, image_size_) + # in case that image size is smaller than roi size + image_size = tuple( + max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims) + ) + pad_size = [] + + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + + inputs = pad( + inputs, + pad=pad_size, + mode=look_up_option(padding_mode, PytorchPadMode).value, + value=cval, + ) + + scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) + + # Store all slices in list + slices = dense_patch_slices(image_size, roi_size, scan_interval) + num_win = len(slices) # number of windows per image + total_slices = num_win * batch_size # total number of windows + + # Create window-level importance map + valid_patch_size = get_valid_patch_size(image_size, roi_size) + if valid_patch_size == roi_size and (roi_weight_map is not None): + importance_map = roi_weight_map + else: + importance_map = compute_importance_map( + valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device + ) + + importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore + # handle non-positive weights + min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3) + importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to( + compute_dtype + ) + + # Perform predictions + dict_key, output_image_list, count_map_list = None, [], [] + _initialized_ss = -1 + is_tensor_output = ( + True # whether the predictor's output is a tensor (instead of dict/tuple) + ) + + # for each patch + for slice_g in range(0, total_slices, sw_batch_size): + slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) + unravel_slice = [ + [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + + list(slices[idx % num_win]) + for idx in slice_range + ] + window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to( + sw_device + ) + seg_prob_out = predictor(window_data) # batched patch segmentation + + # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. + seg_prob_tuple: Tuple[torch.Tensor, ...] 
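+ # The predictor output may be a tensor, a Mapping, or a sequence; each window result below is weighted by the (resized) importance map and accumulated into a full-resolution buffer, with a matching count map used afterwards to normalise the overlapping regions.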
+ if isinstance(seg_prob_out, torch.Tensor): + seg_prob_tuple = (seg_prob_out,) + elif isinstance(seg_prob_out, Mapping): + if dict_key is None: + dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys + seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key) + is_tensor_output = False + else: + seg_prob_tuple = ensure_tuple(seg_prob_out) + is_tensor_output = False + + # for each output in multi-output list + for ss, seg_prob in enumerate(seg_prob_tuple): + seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN + + # compute zoom scale: out_roi_size/in_roi_size + zoom_scale = [] + for axis, (img_s_i, out_w_i, in_w_i) in enumerate( + zip(image_size, seg_prob.shape[2:], window_data.shape[2:]) + ): + _scale = out_w_i / float(in_w_i) + + zoom_scale.append(_scale) + + if _initialized_ss < ss: # init. the ss-th buffer at the first iteration + # construct multi-resolution outputs + output_classes = seg_prob.shape[1] + output_shape = [batch_size, output_classes] + [ + int(image_size_d * zoom_scale_d) + for image_size_d, zoom_scale_d in zip(image_size, zoom_scale) + ] + # allocate memory to store the full output and the count for overlapping parts + output_image_list.append( + torch.zeros(output_shape, dtype=compute_dtype, device=device) + ) + count_map_list.append( + torch.zeros( + [1, 1] + output_shape[2:], dtype=compute_dtype, device=device + ) + ) + _initialized_ss += 1 + + # resizing the importance_map + resizer = Resize(spatial_size=seg_prob.shape[2:]) + + # store the result in the proper location of the full output. Apply weights from importance map. + for idx, original_idx in zip(slice_range, unravel_slice): + # zoom roi + original_idx_zoom = list( + original_idx + ) # 4D for 2D image, 5D for 3D image + for axis in range(2, len(original_idx_zoom)): + zoomed_start = original_idx[axis].start * zoom_scale[axis - 2] + zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2] + + original_idx_zoom[axis] = slice( + int(zoomed_start), int(zoomed_end), None + ) + importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to( + compute_dtype + ) + # store results and weights + output_image_list[ss][original_idx_zoom] += ( + importance_map_zoom * seg_prob[idx - slice_g] + ) + count_map_list[ss][original_idx_zoom] += ( + importance_map_zoom.unsqueeze(0) + .unsqueeze(0) + .expand(count_map_list[ss][original_idx_zoom].shape) + ) + + # account for any overlapping sections + for ss in range(len(output_image_list)): + output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to( + compute_dtype + ) + + # remove padding if image_size smaller than roi_size + for ss, output_i in enumerate(output_image_list): + zoom_scale = [ + seg_prob_map_shape_d / roi_size_d + for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size) + ] + + final_slicing: List[slice] = [] + for sp in range(num_spatial_dims): + slice_dim = slice( + pad_size[sp * 2], + image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2], + ) + slice_dim = slice( + int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])), + int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])), + ) + final_slicing.insert(0, slice_dim) + while len(final_slicing) < len(output_i.shape): + final_slicing.insert(0, slice(None)) + output_image_list[ss] = output_i[final_slicing] + + if dict_key is not None: # if output of predictor is a dict + final_output = dict(zip(dict_key, output_image_list)) + else: + final_output = tuple(output_image_list) # type: ignore + + return final_output[0] if 
is_tensor_output else final_output # type: ignore + + +def _get_scan_interval( + image_size, roi_size, num_spatial_dims: int, overlap: float +) -> Tuple[int, ...]: + scan_interval = [] + + for i in range(num_spatial_dims): + if roi_size[i] == image_size[i]: + scan_interval.append(int(roi_size[i])) + else: + interval = int(roi_size[i] * (1 - overlap)) + scan_interval.append(interval if interval > 0 else 1) + + return tuple(scan_interval) + + +##################################################################################### + +########################### Main Inference Functions ################################# +def post_process(pred_mask, device): + dP, cellprob = pred_mask[:2], 1 / (1 + np.exp(-pred_mask[-1])) + H, W = pred_mask.shape[-2], pred_mask.shape[-1] + + if np.prod(H * W) < (5000 * 5000): + pred_mask = compute_masks( + dP, + cellprob, + use_gpu=True, + flow_threshold=0.4, + device=device, + cellprob_threshold=0.4, + )[0] + + else: + print("\n[Whole Slide] Grid Prediction starting...") + roi_size = 2000 + + # Get patch grid by roi_size + if H % roi_size != 0: + n_H = H // roi_size + 1 + new_H = roi_size * n_H + else: + n_H = H // roi_size + new_H = H + + if W % roi_size != 0: + n_W = W // roi_size + 1 + new_W = roi_size * n_W + else: + n_W = W // roi_size + new_W = W + + # Allocate values on the grid + pred_pad = np.zeros((new_H, new_W), dtype=np.uint32) + dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32) + cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32) + + dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob + + for i in range(n_H): + for j in range(n_W): + print("Pred on Grid (%d, %d) processing..." % (i, j)) + dP_roi = dP_pad[ + :, + roi_size * i : roi_size * (i + 1), + roi_size * j : roi_size * (j + 1), + ] + cellprob_roi = cellprob_pad[ + roi_size * i : roi_size * (i + 1), + roi_size * j : roi_size * (j + 1), + ] + + pred_mask = compute_masks( + dP_roi, + cellprob_roi, + use_gpu=True, + flow_threshold=0.4, + device=device, + cellprob_threshold=0.4, + )[0] + + pred_pad[ + roi_size * i : roi_size * (i + 1), + roi_size * j : roi_size * (j + 1), + ] = pred_mask + + pred_mask = pred_pad[:H, :W] + + cell_idx, cell_sizes = np.unique(pred_mask, return_counts=True) + cell_idx, cell_sizes = cell_idx[1:], cell_sizes[1:] + cell_drop = np.where(cell_sizes < np.mean(cell_sizes) - 2.7 * np.std(cell_sizes)) + + for drop_cell in cell_idx[cell_drop]: + pred_mask[pred_mask == drop_cell] = 0 + + return pred_mask + + +def hflip(x): + """flip batch of images horizontally""" + return x.flip(3) + + +def vflip(x): + """flip batch of images vertically""" + return x.flip(2) + + +class DualTransform: + identity_param = None + + def __init__( + self, name: str, params, + ): + self.params = params + self.pname = name + + def apply_aug_image(self, image, *args, **params): + raise NotImplementedError + + def apply_deaug_mask(self, mask, *args, **params): + raise NotImplementedError + + +class HorizontalFlip(DualTransform): + """Flip images horizontally (left->right)""" + + identity_param = False + + def __init__(self): + super().__init__("apply", [False, True]) + + def apply_aug_image(self, image, apply=False, **kwargs): + if apply: + image = hflip(image) + return image + + def apply_deaug_mask(self, mask, apply=False, **kwargs): + if apply: + mask = hflip(mask) + return mask + + +class VerticalFlip(DualTransform): + """Flip images vertically (up->down)""" + + identity_param = False + + def __init__(self): + super().__init__("apply", [False, True]) + + def apply_aug_image(self, image, 
apply=False, **kwargs): + if apply: + image = vflip(image) + return image + + def apply_deaug_mask(self, mask, apply=False, **kwargs): + if apply: + mask = vflip(mask) + return mask + + +#################### GradFlow Modules ################################################## +from scipy.ndimage.filters import maximum_filter1d +import scipy.ndimage +import fastremap +from skimage import morphology + +from scipy.ndimage import mean + +torch_GPU = torch.device("cuda") +torch_CPU = torch.device("cpu") + + +def _extend_centers_gpu( + neighbors, centers, isneighbor, Ly, Lx, n_iter=200, device=torch.device("cuda") +): + if device is not None: + device = device + nimg = neighbors.shape[0] // 9 + pt = torch.from_numpy(neighbors).to(device) + + T = torch.zeros((nimg, Ly, Lx), dtype=torch.double, device=device) + meds = torch.from_numpy(centers.astype(int)).to(device).long() + isneigh = torch.from_numpy(isneighbor).to(device) + for i in range(n_iter): + T[:, meds[:, 0], meds[:, 1]] += 1 + Tneigh = T[:, pt[:, :, 0], pt[:, :, 1]] + Tneigh *= isneigh + T[:, pt[0, :, 0], pt[0, :, 1]] = Tneigh.mean(axis=1) + del meds, isneigh, Tneigh + T = torch.log(1.0 + T) + # gradient positions + grads = T[:, pt[[2, 1, 4, 3], :, 0], pt[[2, 1, 4, 3], :, 1]] + del pt + dy = grads[:, 0] - grads[:, 1] + dx = grads[:, 2] - grads[:, 3] + del grads + mu_torch = np.stack((dy.cpu().squeeze(), dx.cpu().squeeze()), axis=-2) + return mu_torch + + +def diameters(masks): + _, counts = np.unique(np.int32(masks), return_counts=True) + counts = counts[1:] + md = np.median(counts ** 0.5) + if np.isnan(md): + md = 0 + md /= (np.pi ** 0.5) / 2 + return md, counts ** 0.5 + + +def masks_to_flows_gpu(masks, device=None): + if device is None: + device = torch.device("cuda") + + Ly0, Lx0 = masks.shape + Ly, Lx = Ly0 + 2, Lx0 + 2 + + masks_padded = np.zeros((Ly, Lx), np.int64) + masks_padded[1:-1, 1:-1] = masks + + # get mask pixel neighbors + y, x = np.nonzero(masks_padded) + neighborsY = np.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), axis=0) + neighborsX = np.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), axis=0) + neighbors = np.stack((neighborsY, neighborsX), axis=-1) + + # get mask centers + slices = scipy.ndimage.find_objects(masks) + + centers = np.zeros((masks.max(), 2), "int") + for i, si in enumerate(slices): + if si is not None: + sr, sc = si + + ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1 + yi, xi = np.nonzero(masks[sr, sc] == (i + 1)) + yi = yi.astype(np.int32) + 1 # add padding + xi = xi.astype(np.int32) + 1 # add padding + ymed = np.median(yi) + xmed = np.median(xi) + imin = np.argmin((xi - xmed) ** 2 + (yi - ymed) ** 2) + xmed = xi[imin] + ymed = yi[imin] + centers[i, 0] = ymed + sr.start + centers[i, 1] = xmed + sc.start + + # get neighbor validator (not all neighbors are in same mask) + neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1]] + isneighbor = neighbor_masks == neighbor_masks[0] + ext = np.array( + [[sr.stop - sr.start + 1, sc.stop - sc.start + 1] for sr, sc in slices] + ) + n_iter = 2 * (ext.sum(axis=1)).max() + # run diffusion + mu = _extend_centers_gpu( + neighbors, centers, isneighbor, Ly, Lx, n_iter=n_iter, device=device + ) + + # normalize + mu /= 1e-20 + (mu ** 2).sum(axis=0) ** 0.5 + + # put into original image + mu0 = np.zeros((2, Ly0, Lx0)) + mu0[:, y - 1, x - 1] = mu + mu_c = np.zeros_like(mu0) + return mu0, mu_c + + +def masks_to_flows(masks, use_gpu=False, device=None): + if masks.max() == 0 or (masks != 0).sum() == 1: + # 
dynamics_logger.warning('empty masks!') + return np.zeros((2, *masks.shape), "float32") + + if use_gpu: + if use_gpu and device is None: + device = torch_GPU + elif device is None: + device = torch_CPU + masks_to_flows_device = masks_to_flows_gpu + + if masks.ndim == 3: + Lz, Ly, Lx = masks.shape + mu = np.zeros((3, Lz, Ly, Lx), np.float32) + for z in range(Lz): + mu0 = masks_to_flows_device(masks[z], device=device)[0] + mu[[1, 2], z] += mu0 + for y in range(Ly): + mu0 = masks_to_flows_device(masks[:, y], device=device)[0] + mu[[0, 2], :, y] += mu0 + for x in range(Lx): + mu0 = masks_to_flows_device(masks[:, :, x], device=device)[0] + mu[[0, 1], :, :, x] += mu0 + return mu + elif masks.ndim == 2: + mu, mu_c = masks_to_flows_device(masks, device=device) + return mu + + else: + raise ValueError("masks_to_flows only takes 2D or 3D arrays") + + +def steps2D_interp(p, dP, niter, use_gpu=False, device=None): + shape = dP.shape[1:] + if use_gpu: + if device is None: + device = torch_GPU + shape = ( + np.array(shape)[[1, 0]].astype("float") - 1 + ) # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1 + pt = ( + torch.from_numpy(p[[1, 0]].T).float().to(device).unsqueeze(0).unsqueeze(0) + ) # p is n_points by 2, so pt is [1 1 2 n_points] + im = ( + torch.from_numpy(dP[[1, 0]]).float().to(device).unsqueeze(0) + ) # covert flow numpy array to tensor on GPU, add dimension + # normalize pt between 0 and 1, normalize the flow + for k in range(2): + im[:, k, :, :] *= 2.0 / shape[k] + pt[:, :, :, k] /= shape[k] + + # normalize to between -1 and 1 + pt = pt * 2 - 1 + + # here is where the stepping happens + for t in range(niter): + # align_corners default is False, just added to suppress warning + dPt = grid_sample(im, pt, align_corners=False) + + for k in range(2): # clamp the final pixel locations + pt[:, :, :, k] = torch.clamp( + pt[:, :, :, k] + dPt[:, k, :, :], -1.0, 1.0 + ) + + # undo the normalization from before, reverse order of operations + pt = (pt + 1) * 0.5 + for k in range(2): + pt[:, :, :, k] *= shape[k] + + p = pt[:, :, :, [1, 0]].cpu().numpy().squeeze().T + return p + + else: + assert print("ho") + + +def follow_flows(dP, mask=None, niter=200, interp=True, use_gpu=True, device=None): + shape = np.array(dP.shape[1:]).astype(np.int32) + niter = np.uint32(niter) + + p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij") + p = np.array(p).astype(np.float32) + + inds = np.array(np.nonzero(np.abs(dP[0]) > 1e-3)).astype(np.int32).T + + if inds.ndim < 2 or inds.shape[0] < 5: + return p, None + + if not interp: + assert print("woo") + + else: + p_interp = steps2D_interp( + p[:, inds[:, 0], inds[:, 1]], dP, niter, use_gpu=use_gpu, device=device + ) + p[:, inds[:, 0], inds[:, 1]] = p_interp + + return p, inds + + +def flow_error(maski, dP_net, use_gpu=False, device=None): + if dP_net.shape[1:] != maski.shape: + print("ERROR: net flow is not same size as predicted masks") + return + + # flows predicted from estimated masks + dP_masks = masks_to_flows(maski, use_gpu=use_gpu, device=device) + # difference between predicted flows vs mask flows + flow_errors = np.zeros(maski.max()) + for i in range(dP_masks.shape[0]): + flow_errors += mean( + (dP_masks[i] - dP_net[i] / 5.0) ** 2, + maski, + index=np.arange(1, maski.max() + 1), + ) + + return flow_errors, dP_masks + + +def remove_bad_flow_masks(masks, flows, threshold=0.4, use_gpu=False, device=None): + merrors, _ = flow_error(masks, flows, use_gpu, device) + badi = 1 + (merrors > threshold).nonzero()[0] + masks[np.isin(masks, badi)] = 0 
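+ # At this point every label whose flow error (mean squared difference between flows recomputed from the mask and the network-predicted flows, see flow_error above) exceeded `threshold` has been zeroed out.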
+ return masks + + +def get_masks(p, iscell=None, rpad=20): + pflows = [] + edges = [] + shape0 = p.shape[1:] + dims = len(p) + + for i in range(dims): + pflows.append(p[i].flatten().astype("int32")) + edges.append(np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1)) + + h, _ = np.histogramdd(tuple(pflows), bins=edges) + hmax = h.copy() + for i in range(dims): + hmax = maximum_filter1d(hmax, 5, axis=i) + + seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10)) + Nmax = h[seeds] + isort = np.argsort(Nmax)[::-1] + for s in seeds: + s = s[isort] + + pix = list(np.array(seeds).T) + + shape = h.shape + if dims == 3: + expand = np.nonzero(np.ones((3, 3, 3))) + else: + expand = np.nonzero(np.ones((3, 3))) + for e in expand: + e = np.expand_dims(e, 1) + + for iter in range(5): + for k in range(len(pix)): + if iter == 0: + pix[k] = list(pix[k]) + newpix = [] + iin = [] + for i, e in enumerate(expand): + epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1 + epix = epix.flatten() + iin.append(np.logical_and(epix >= 0, epix < shape[i])) + newpix.append(epix) + iin = np.all(tuple(iin), axis=0) + for p in newpix: + p = p[iin] + newpix = tuple(newpix) + igood = h[newpix] > 2 + for i in range(dims): + pix[k][i] = newpix[i][igood] + if iter == 4: + pix[k] = tuple(pix[k]) + + M = np.zeros(h.shape, np.uint32) + for k in range(len(pix)): + M[pix[k]] = 1 + k + + for i in range(dims): + pflows[i] = pflows[i] + rpad + M0 = M[tuple(pflows)] + + # remove big masks + uniq, counts = fastremap.unique(M0, return_counts=True) + big = np.prod(shape0) * 0.9 + bigc = uniq[counts > big] + if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0): + M0 = fastremap.mask(M0, bigc) + fastremap.renumber(M0, in_place=True) # convenient to guarantee non-skipped labels + M0 = np.reshape(M0, shape0) + return M0 + +def fill_holes_and_remove_small_masks(masks, min_size=15): + """ fill holes in masks (2D/3D) and discard masks smaller than min_size (2D) + + fill holes in each mask using scipy.ndimage.morphology.binary_fill_holes + (might have issues at borders between cells, todo: check and fix) + + Parameters + ---------------- + masks: int, 2D or 3D array + labelled masks, 0=NO masks; 1,2,...=mask labels, + size [Ly x Lx] or [Lz x Ly x Lx] + min_size: int (optional, default 15) + minimum number of pixels per mask, can turn off with -1 + Returns + --------------- + masks: int, 2D or 3D array + masks with holes filled and masks smaller than min_size removed, + 0=NO masks; 1,2,...=mask labels, + size [Ly x Lx] or [Lz x Ly x Lx] + + """ + + slices = find_objects(masks) + j = 0 + for i,slc in enumerate(slices): + if slc is not None: + msk = masks[slc] == (i+1) + npix = msk.sum() + if min_size > 0 and npix < min_size: + masks[slc][msk] = 0 + elif npix > 0: + if msk.ndim==3: + for k in range(msk.shape[0]): + msk[k] = binary_fill_holes(msk[k]) + else: + msk = binary_fill_holes(msk) + masks[slc][msk] = (j+1) + j+=1 + return masks + +def compute_masks( + dP, + cellprob, + p=None, + niter=200, + cellprob_threshold=0.4, + flow_threshold=0.4, + interp=True, + resize=None, + use_gpu=False, + device=None, +): + """compute masks using dynamics from dP, cellprob, and boundary""" + + cp_mask = cellprob > cellprob_threshold + cp_mask = morphology.remove_small_holes(cp_mask, area_threshold=16) + cp_mask = morphology.remove_small_objects(cp_mask, min_size=16) + + if np.any(cp_mask): # mask at this point is a cell cluster binary map, not labels + # follow flows + if p is None: + p, inds = follow_flows( + dP * cp_mask / 5.0, + niter=niter, + 
interp=interp, + use_gpu=use_gpu, + device=device, + ) + if inds is None: + shape = resize if resize is not None else cellprob.shape + mask = np.zeros(shape, np.uint16) + p = np.zeros((len(shape), *shape), np.uint16) + return mask, p + + # calculate masks + mask = get_masks(p, iscell=cp_mask) + + # flow thresholding factored out of get_masks + shape0 = p.shape[1:] + if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0: + # make sure labels are unique at output of get_masks + mask = remove_bad_flow_masks( + mask, dP, threshold=flow_threshold, use_gpu=use_gpu, device=device + ) + + mask = fill_holes_and_remove_small_masks(mask, min_size=15) + + else: # nothing to compute, just make it compatible + shape = resize if resize is not None else cellprob.shape + mask = np.zeros(shape, np.uint16) + p = np.zeros((len(shape), *shape), np.uint16) + return mask, p + + return mask, p + +def main(args): + model = torch.load(args.model_path, map_location=args.device) + model.eval() + hflip_tta = HorizontalFlip() + vflip_tta = VerticalFlip() + + img_names = sorted(os.listdir(args.input_path)) + os.makedirs(args.output_path, exist_ok=True) + + for img_name in img_names: + print(f"Segmenting {img_name}") + img_path = os.path.join(args.input_path, img_name) + img_data = pred_transforms(img_path) + img_data = img_data.to(args.device) + img_size = img_data.shape[-1] * img_data.shape[-2] + + if img_size < 1150000 and 900000 < img_size: + overlap = 0.5 + else: + overlap = 0.6 + + with torch.no_grad(): + img0 = img_data + outputs0 = sliding_window_inference( + img0, + 512, + 4, + model, + padding_mode="reflect", + mode="gaussian", + overlap=overlap, + device="cpu", + ) + outputs0 = outputs0.cpu().squeeze() + + if img_size < 2000 * 2000: + + model.load_state_dict(torch.load(args.model_path2, map_location=args.device)) + model.eval() + + img2 = hflip_tta.apply_aug_image(img_data, apply=True) + outputs2 = sliding_window_inference( + img2, + 512, + 4, + model, + padding_mode="reflect", + mode="gaussian", + overlap=overlap, + device="cpu", + ) + outputs2 = hflip_tta.apply_deaug_mask(outputs2, apply=True) + outputs2 = outputs2.cpu().squeeze() + + outputs = torch.zeros_like(outputs0) + outputs[0] = (outputs0[0] + outputs2[0]) / 2 + outputs[1] = (outputs0[1] - outputs2[1]) / 2 + outputs[2] = (outputs0[2] + outputs2[2]) / 2 + + elif img_size < 5000*5000: + # Hflip TTA + img2 = hflip_tta.apply_aug_image(img_data, apply=True) + outputs2 = sliding_window_inference( + img2, + 512, + 4, + model, + padding_mode="reflect", + mode="gaussian", + overlap=overlap, + device="cpu", + ) + outputs2 = hflip_tta.apply_deaug_mask(outputs2, apply=True) + outputs2 = outputs2.cpu().squeeze() + img2 = img2.cpu() + + ################## + # # + # ensemble # + # # + ################## + + model.load_state_dict(torch.load(args.model_path2, map_location=args.device)) + model.eval() + + img1 = img_data + outputs1 = sliding_window_inference( + img1, + 512, + 4, + model, + padding_mode="reflect", + mode="gaussian", + overlap=overlap, + device="cpu", + ) + outputs1 = outputs1.cpu().squeeze() + + # Vflip TTA + img3 = vflip_tta.apply_aug_image(img_data, apply=True) + outputs3 = sliding_window_inference( + img3, + 512, + 4, + model, + padding_mode="reflect", + mode="gaussian", + overlap=overlap, + device="cpu", + ) + outputs3 = vflip_tta.apply_deaug_mask(outputs3, apply=True) + outputs3 = outputs3.cpu().squeeze() + img3 = img3.cpu() + + # Merge Results + outputs = torch.zeros_like(outputs0) + outputs[0] = (outputs0[0] + outputs1[0] +
outputs2[0] - outputs3[0]) / 4 + outputs[1] = (outputs0[1] + outputs1[1] - outputs2[1] + outputs3[1]) / 4 + outputs[2] = (outputs0[2] + outputs1[2] + outputs2[2] + outputs3[2]) / 4 + else: + outputs = outputs0 + + pred_mask = post_process(outputs.squeeze(0).cpu().numpy(), args.device) + + file_path = os.path.join( + args.output_path, img_name.split(".")[0] + "_label.tiff" + ) + + tif.imwrite(file_path, pred_mask, compression="zlib") + + +parser = argparse.ArgumentParser("Submission for Challenge", add_help=False) +parser.add_argument("--model_path", default="./model.pt", type=str) +parser.add_argument("--model_path2", default="./model_sec.pth", type=str) + +# Dataset parameters +parser.add_argument( + "-i", + "--input_path", + default="/workspace/inputs/", + type=str, + help="training data path; subfolders: images, labels", +) +parser.add_argument( + "-o", "--output_path", default="/workspace/outputs/", type=str, help="output path", +) +parser.add_argument("--device", default="cuda:0", type=str) + +args = parser.parse_args() + +if __name__ == "__main__": + print("Starting") + main(args) diff --git a/predict.sh b/predict.sh new file mode 100644 index 0000000000000000000000000000000000000000..62e2366e56d45f037b58d81de273fe38d94404b1 --- /dev/null +++ b/predict.sh @@ -0,0 +1 @@ +python predict.py -i "./inputs" -o "./outputs" --device "cuda:0" --model_path="./main_model.pt" --model_path2="./sub_model.pth" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a1af8561a0546f88b98cbf7a879243ea422fdbd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,83 @@ +backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work +beautifulsoup4 @ file:///opt/conda/conda-bld/beautifulsoup4_1650462163268/work +brotlipy==0.7.0 +certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi +cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work +chardet @ file:///tmp/build/80754af9/chardet_1607706768982/work +charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work +colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work +coloredlogs==15.0.1 +conda==4.13.0 +conda-build==3.21.9 +conda-content-trust @ file:///tmp/build/80754af9/conda-content-trust_1617045594566/work +conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1649105789509/work +cryptography @ file:///tmp/build/80754af9/cryptography_1652083456434/work +decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work +fastremap==1.13.3 +filelock @ file:///opt/conda/conda-bld/filelock_1647002191454/work +flatbuffers==22.9.24 +glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work +huggingface-hub==0.10.1 +humanfriendly==10.0 +idna @ file:///tmp/build/80754af9/idna_1637925883363/work +imagecodecs==2021.11.20 +imageio==2.22.2 +importlib-metadata==5.0.0 +itk==5.2.1.post1 +itk-core==5.2.1.post1 +itk-filtering==5.2.1.post1 +itk-io==5.2.1.post1 +itk-numerics==5.2.1.post1 +itk-registration==5.2.1.post1 +itk-segmentation==5.2.1.post1 +jedi @ file:///tmp/build/80754af9/jedi_1644299024593/work +Jinja2==2.10.1 +libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work +MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work +matplotlib-inline @ file:///tmp/build/80754af9/matplotlib-inline_1628242447089/work +mkl-fft==1.3.1 +mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work +mkl-service==2.4.0 +monai==0.9.0 +mpmath==1.2.1 +networkx==2.6.3 +numpy @ 
file:///opt/conda/conda-bld/numpy_and_numpy_base_1651563629415/work +onnxruntime-gpu==1.12.1 +opencv-python==4.6.0.66 +packaging==21.3 +parso @ file:///opt/conda/conda-bld/parso_1641458642106/work +pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work +pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work +Pillow==9.0.1 +pkginfo @ file:///tmp/build/80754af9/pkginfo_1643162084911/work +prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1633440160888/work +protobuf==4.21.8 +psutil @ file:///tmp/build/80754af9/psutil_1612298016854/work +ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl +pycosat==0.6.3 +pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work +Pygments @ file:///opt/conda/conda-bld/pygments_1644249106324/work +pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work +pyparsing==3.0.9 +PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work +pytz==2022.2.1 +PyWavelets==1.3.0 +PyYAML==6.0 +requests @ file:///opt/conda/conda-bld/requests_1641824580448/work +ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016701961/work +scikit-image==0.19.3 +scipy==1.7.2 +six @ file:///tmp/build/80754af9/six_1644875935023/work +soupsieve @ file:///tmp/build/80754af9/soupsieve_1636706018808/work +sympy==1.10.1 +tifffile==2021.11.2 +timm==0.6.11 +torch==1.12.1 +torchtext==0.13.1 +torchvision==0.13.1 +tqdm==4.64.1 +traitlets @ file:///tmp/build/80754af9/traitlets_1636710298902/work +typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work +urllib3 @ file:///opt/conda/conda-bld/urllib3_1643638302206/work +wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work +zipp==3.9.0 diff --git a/save_model.py b/save_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4538b09a36c6090870727e9309d3b6e32d84e6 --- /dev/null +++ b/save_model.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn + +from segmentation_models_pytorch import MAnet +from segmentation_models_pytorch.base.modules import Activation + + +class SegformerGH(MAnet): + def __init__( + self, + encoder_name: str = "mit_b5", + encoder_weights="imagenet", + decoder_channels=(256, 128, 64, 32, 32), + decoder_pab_channels=256, + in_channels: int = 3, + classes: int = 3, + ): + super(SegformerGH, self).__init__( + encoder_name=encoder_name, + encoder_weights=encoder_weights, + decoder_channels=decoder_channels, + decoder_pab_channels=decoder_pab_channels, + in_channels=in_channels, + classes=classes, + ) + + convert_relu_to_mish(self.encoder) + convert_relu_to_mish(self.decoder) + + self.cellprob_head = DeepSegmantationHead( + in_channels=decoder_channels[-1], out_channels=1, kernel_size=3, + ) + self.gradflow_head = DeepSegmantationHead( + in_channels=decoder_channels[-1], out_channels=2, kernel_size=3, + ) + + def forward(self, x): + """Sequentially pass `x` trough model`s encoder, decoder and heads""" + self.check_input_shape(x) + + features = self.encoder(x) + decoder_output = self.decoder(*features) + + gradflow_mask = self.gradflow_head(decoder_output) + cellprob_mask = self.cellprob_head(decoder_output) + + masks = torch.cat([gradflow_mask, cellprob_mask], dim=1) + + return masks + + +class DeepSegmantationHead(nn.Sequential): + def __init__( + self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1 + ): + conv2d_1 = nn.Conv2d( + in_channels, + in_channels // 2, + kernel_size=kernel_size, + 
padding=kernel_size // 2, + ) + bn = nn.BatchNorm2d(in_channels // 2) + conv2d_2 = nn.Conv2d( + in_channels // 2, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + mish = nn.Mish(inplace=True) + + upsampling = ( + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity() + ) + activation = Activation(activation) + super().__init__(conv2d_1, mish, bn, conv2d_2, upsampling, activation) + + +def convert_relu_to_mish(model): + for child_name, child in model.named_children(): + if isinstance(child, nn.ReLU): + setattr(model, child_name, nn.Mish(inplace=True)) + else: + convert_relu_to_mish(child) + + +if __name__ == "__main__": + model = SegformerGH( + encoder_name="mit_b5", + encoder_weights=None, + decoder_channels=(1024, 512, 256, 128, 64), + decoder_pab_channels=256, + in_channels=3, + classes=3, + ) + + model.load_state_dict(torch.load("./main_model.pth",map_location="cpu")) + torch.save(model, "main_model.pt") diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f34d0ad932b66e14a92026f8631f68a78283478a --- /dev/null +++ b/segmentation_models_pytorch/__init__.py @@ -0,0 +1,61 @@ +from . import datasets +from . import encoders +from . import decoders +from . import losses +from . import metrics + +from .decoders.unet import Unet +from .decoders.unetplusplus import UnetPlusPlus +from .decoders.manet import MAnet +from .decoders.linknet import Linknet +from .decoders.fpn import FPN +from .decoders.pspnet import PSPNet +from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus +from .decoders.pan import PAN + +from .__version__ import __version__ + +# some private imports for create_model function +from typing import Optional as _Optional +import torch as _torch + + +def create_model( + arch: str, + encoder_name: str = "resnet34", + encoder_weights: _Optional[str] = "imagenet", + in_channels: int = 3, + classes: int = 1, + **kwargs, +) -> _torch.nn.Module: + """Models entrypoint, allows to create any model architecture just with + parameters, without using its class + """ + + archs = [ + Unet, + UnetPlusPlus, + MAnet, + Linknet, + FPN, + PSPNet, + DeepLabV3, + DeepLabV3Plus, + PAN, + ] + archs_dict = {a.__name__.lower(): a for a in archs} + try: + model_class = archs_dict[arch.lower()] + except KeyError: + raise KeyError( + "Wrong architecture type `{}`. 
Available options are: {}".format( + arch, list(archs_dict.keys()), + ) + ) + return model_class( + encoder_name=encoder_name, + encoder_weights=encoder_weights, + in_channels=in_channels, + classes=classes, + **kwargs, + ) diff --git a/segmentation_models_pytorch/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6634d72c464b1d29e4d51b84547d230a1eb38bb Binary files /dev/null and b/segmentation_models_pytorch/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f32adb4dac1f3441f3732b36459e2afe87c9ff2d Binary files /dev/null and b/segmentation_models_pytorch/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/__pycache__/__version__.cpython-37.pyc b/segmentation_models_pytorch/__pycache__/__version__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1e21c78be0cec7e734ec953197a90bc205a1ff2 Binary files /dev/null and b/segmentation_models_pytorch/__pycache__/__version__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/__pycache__/__version__.cpython-39.pyc b/segmentation_models_pytorch/__pycache__/__version__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71424947b836332a7822146df93718f90744cf8a Binary files /dev/null and b/segmentation_models_pytorch/__pycache__/__version__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/__version__.py b/segmentation_models_pytorch/__version__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb0b13dc9d8c94355ae5d5fa4a9d45c5e70fc7c --- /dev/null +++ b/segmentation_models_pytorch/__version__.py @@ -0,0 +1,3 @@ +VERSION = (0, 3, 0) + +__version__ = ".".join(map(str, VERSION)) diff --git a/segmentation_models_pytorch/base/__init__.py b/segmentation_models_pytorch/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2be3ca0a342daa2005040e0e7ef48cc1edaf2ce --- /dev/null +++ b/segmentation_models_pytorch/base/__init__.py @@ -0,0 +1,11 @@ +from .model import SegmentationModel + +from .modules import ( + Conv2dReLU, + Attention, +) + +from .heads import ( + SegmentationHead, + ClassificationHead, +) diff --git a/segmentation_models_pytorch/base/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/base/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf0e484877eeabffa9d6073d173bb014875d8f35 Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/base/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc5aa15da4373236b4c5f86cbf29bd64ef2d9c28 Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/heads.cpython-37.pyc b/segmentation_models_pytorch/base/__pycache__/heads.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4429e6035f774378734a864aa6ea4f73d48cee38 Binary files /dev/null and 
b/segmentation_models_pytorch/base/__pycache__/heads.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/heads.cpython-39.pyc b/segmentation_models_pytorch/base/__pycache__/heads.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..509a3b79c64af1a0f179d3af872f5024a526303f Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/heads.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/initialization.cpython-37.pyc b/segmentation_models_pytorch/base/__pycache__/initialization.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e32523947c2098d76f14bc762adf48e4c5c193c Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/initialization.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/initialization.cpython-39.pyc b/segmentation_models_pytorch/base/__pycache__/initialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1791f7c13eececef87d48123cb87ec51624a98b Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/initialization.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/base/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c8fe4fe13b7db5723c9f3f14f320a28e1a3a303 Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/model.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/base/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f7519447a41c5323f8244466fa4eee78632fe8 Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/modules.cpython-37.pyc b/segmentation_models_pytorch/base/__pycache__/modules.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b463ab440bde8d68ceef52bfc2f1642c16d84a4 Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/modules.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/base/__pycache__/modules.cpython-39.pyc b/segmentation_models_pytorch/base/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61eb3fa8049f4bb628dac54bf19276690a815b27 Binary files /dev/null and b/segmentation_models_pytorch/base/__pycache__/modules.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/base/heads.py b/segmentation_models_pytorch/base/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc939cad427fb6b9edfb02c0e841a68de84a03a --- /dev/null +++ b/segmentation_models_pytorch/base/heads.py @@ -0,0 +1,34 @@ +import torch.nn as nn +from .modules import Activation + + +class SegmentationHead(nn.Sequential): + def __init__( + self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1 + ): + conv2d = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + upsampling = ( + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity() + ) + activation = Activation(activation) + super().__init__(conv2d, upsampling, activation) + + +class ClassificationHead(nn.Sequential): + def __init__( + self, in_channels, classes, 
pooling="avg", dropout=0.2, activation=None + ): + if pooling not in ("max", "avg"): + raise ValueError( + "Pooling should be one of ('max', 'avg'), got {}.".format(pooling) + ) + pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1) + flatten = nn.Flatten() + dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() + linear = nn.Linear(in_channels, classes, bias=True) + activation = Activation(activation) + super().__init__(pool, flatten, dropout, linear, activation) diff --git a/segmentation_models_pytorch/base/initialization.py b/segmentation_models_pytorch/base/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..9622130204a0172d43a5f32f4ade065e100f746e --- /dev/null +++ b/segmentation_models_pytorch/base/initialization.py @@ -0,0 +1,27 @@ +import torch.nn as nn + + +def initialize_decoder(module): + for m in module.modules(): + + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def initialize_head(module): + for m in module.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..36850e979a19b16a9b9d284accbd5d89127c309a --- /dev/null +++ b/segmentation_models_pytorch/base/model.py @@ -0,0 +1,64 @@ +import torch +from . import initialization as init + + +class SegmentationModel(torch.nn.Module): + def initialize(self): + init.initialize_decoder(self.decoder) + init.initialize_head(self.segmentation_head) + if self.classification_head is not None: + init.initialize_head(self.classification_head) + + def check_input_shape(self, x): + + h, w = x.shape[-2:] + output_stride = self.encoder.output_stride + if h % output_stride != 0 or w % output_stride != 0: + new_h = ( + (h // output_stride + 1) * output_stride + if h % output_stride != 0 + else h + ) + new_w = ( + (w // output_stride + 1) * output_stride + if w % output_stride != 0 + else w + ) + raise RuntimeError( + f"Wrong input shape height={h}, width={w}. Expected image height and width " + f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})." + ) + + def forward(self, x): + """Sequentially pass `x` trough model`s encoder, decoder and heads""" + + self.check_input_shape(x) + + features = self.encoder(x) + decoder_output = self.decoder(*features) + + masks = self.segmentation_head(decoder_output) + + if self.classification_head is not None: + labels = self.classification_head(features[-1]) + return masks, labels + + return masks + + @torch.no_grad() + def predict(self, x): + """Inference method. 
Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` + + Args: + x: 4D torch tensor with shape (batch_size, channels, height, width) + + Return: + prediction: 4D torch tensor with shape (batch_size, classes, height, width) + + """ + if self.training: + self.eval() + + x = self.forward(x) + + return x diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..096541fc248cfef434e1a9ffc6cfe1ad7f0acbe5 --- /dev/null +++ b/segmentation_models_pytorch/base/modules.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn + +try: + from inplace_abn import InPlaceABN +except ImportError: + InPlaceABN = None + + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + + if use_batchnorm == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " + + "To install see: https://github.com/mapillary/inplace_abn" + ) + + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + if use_batchnorm == "inplace": + bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) + relu = nn.Identity() + + elif use_batchnorm and use_batchnorm != "inplace": + bn = nn.BatchNorm2d(out_channels) + + else: + bn = nn.Identity() + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class SCSEModule(nn.Module): + def __init__(self, in_channels, reduction=16): + super().__init__() + self.cSE = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, in_channels // reduction, 1), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // reduction, in_channels, 1), + nn.Sigmoid(), + ) + self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) + + def forward(self, x): + return x * self.cSE(x) + x * self.sSE(x) + + +class ArgMax(nn.Module): + def __init__(self, dim=None): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.argmax(x, dim=self.dim) + + +class Clamp(nn.Module): + def __init__(self, min=0, max=1): + super().__init__() + self.min, self.max = min, max + + def forward(self, x): + return torch.clamp(x, self.min, self.max) + + +class Activation(nn.Module): + def __init__(self, name, **params): + + super().__init__() + + if name is None or name == "identity": + self.activation = nn.Identity(**params) + elif name == "sigmoid": + self.activation = nn.Sigmoid() + elif name == "softmax2d": + self.activation = nn.Softmax(dim=1, **params) + elif name == "softmax": + self.activation = nn.Softmax(**params) + elif name == "logsoftmax": + self.activation = nn.LogSoftmax(**params) + elif name == "tanh": + self.activation = nn.Tanh() + elif name == "argmax": + self.activation = ArgMax(**params) + elif name == "argmax2d": + self.activation = ArgMax(dim=1, **params) + elif name == "clamp": + self.activation = Clamp(**params) + elif callable(name): + self.activation = name(**params) + else: + raise ValueError( + f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/" + f"argmax/argmax2d/clamp/None; got {name}" + ) + + def forward(self, x): + return self.activation(x) + + +class Attention(nn.Module): + def __init__(self, name, **params): + super().__init__() + + if name is None: + self.attention = nn.Identity(**params) + elif name == "scse": + 
self.attention = SCSEModule(**params) + else: + raise ValueError("Attention {} is not implemented".format(name)) + + def forward(self, x): + return self.attention(x) diff --git a/segmentation_models_pytorch/datasets/__init__.py b/segmentation_models_pytorch/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8aba23b9d3faf1b2ccd0dc2655cf15639d2dc4a6 --- /dev/null +++ b/segmentation_models_pytorch/datasets/__init__.py @@ -0,0 +1 @@ +from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset diff --git a/segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5377cf2f337af094b42fe64a26c7834298d82c5 Binary files /dev/null and b/segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4519149ea09ac52f8fcbdbf09a477990d354a6fa Binary files /dev/null and b/segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-37.pyc b/segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8349f9f3e6247468e962b848d6a6a2d1754458ca Binary files /dev/null and b/segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-39.pyc b/segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..795288e0c2207ad7020bab1bf880e0bc431018bb Binary files /dev/null and b/segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/datasets/oxford_pet.py b/segmentation_models_pytorch/datasets/oxford_pet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8fbf7a6a2da96094e8b964bdeff62c83a39e38 --- /dev/null +++ b/segmentation_models_pytorch/datasets/oxford_pet.py @@ -0,0 +1,136 @@ +import os +import torch +import shutil +import numpy as np + +from PIL import Image +from tqdm import tqdm +from urllib.request import urlretrieve + + +class OxfordPetDataset(torch.utils.data.Dataset): + def __init__(self, root, mode="train", transform=None): + + assert mode in {"train", "valid", "test"} + + self.root = root + self.mode = mode + self.transform = transform + + self.images_directory = os.path.join(self.root, "images") + self.masks_directory = os.path.join(self.root, "annotations", "trimaps") + + self.filenames = self._read_split() # read train/valid/test splits + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + + filename = self.filenames[idx] + image_path = os.path.join(self.images_directory, filename + ".jpg") + mask_path = os.path.join(self.masks_directory, filename + ".png") + + image = np.array(Image.open(image_path).convert("RGB")) + + trimap = np.array(Image.open(mask_path)) + mask = self._preprocess_mask(trimap) + + sample = dict(image=image, mask=mask, trimap=trimap) + if self.transform is not None: + sample = self.transform(**sample) + + return sample + + @staticmethod + def 
_preprocess_mask(mask): + mask = mask.astype(np.float32) + mask[mask == 2.0] = 0.0 + mask[(mask == 1.0) | (mask == 3.0)] = 1.0 + return mask + + def _read_split(self): + split_filename = "test.txt" if self.mode == "test" else "trainval.txt" + split_filepath = os.path.join(self.root, "annotations", split_filename) + with open(split_filepath) as f: + split_data = f.read().strip("\n").split("\n") + filenames = [x.split(" ")[0] for x in split_data] + if self.mode == "train": # 90% for train + filenames = [x for i, x in enumerate(filenames) if i % 10 != 0] + elif self.mode == "valid": # 10% for validation + filenames = [x for i, x in enumerate(filenames) if i % 10 == 0] + return filenames + + @staticmethod + def download(root): + + # load images + filepath = os.path.join(root, "images.tar.gz") + download_url( + url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", + filepath=filepath, + ) + extract_archive(filepath) + + # load annotations + filepath = os.path.join(root, "annotations.tar.gz") + download_url( + url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", + filepath=filepath, + ) + extract_archive(filepath) + + +class SimpleOxfordPetDataset(OxfordPetDataset): + def __getitem__(self, *args, **kwargs): + + sample = super().__getitem__(*args, **kwargs) + + # resize images + image = np.array( + Image.fromarray(sample["image"]).resize((256, 256), Image.LINEAR) + ) + mask = np.array( + Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST) + ) + trimap = np.array( + Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST) + ) + + # convert to other format HWC -> CHW + sample["image"] = np.moveaxis(image, -1, 0) + sample["mask"] = np.expand_dims(mask, 0) + sample["trimap"] = np.expand_dims(trimap, 0) + + return sample + + +class TqdmUpTo(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + +def download_url(url, filepath): + directory = os.path.dirname(os.path.abspath(filepath)) + os.makedirs(directory, exist_ok=True) + if os.path.exists(filepath): + return + + with TqdmUpTo( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=os.path.basename(filepath), + ) as t: + urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None) + t.total = t.n + + +def extract_archive(filepath): + extract_dir = os.path.dirname(os.path.abspath(filepath)) + dst_dir = os.path.splitext(filepath)[0] + if not os.path.exists(dst_dir): + shutil.unpack_archive(filepath, extract_dir) diff --git a/segmentation_models_pytorch/decoders/__init__.py b/segmentation_models_pytorch/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d856221372a22dd2405fb3da91ec756c50fe071 Binary files /dev/null and b/segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31c52cfcfc5e4956954abee5cc00f0109d03dbaa Binary files /dev/null and b/segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-39.pyc differ diff 
--git a/segmentation_models_pytorch/decoders/deeplabv3/__init__.py b/segmentation_models_pytorch/decoders/deeplabv3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c384551686048e56e54bc043b3201cfb6f07902f --- /dev/null +++ b/segmentation_models_pytorch/decoders/deeplabv3/__init__.py @@ -0,0 +1 @@ +from .model import DeepLabV3, DeepLabV3Plus diff --git a/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c99aefb6e97352fc197879668171846331ed3fc Binary files /dev/null and b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f172dcef26bb9d6398df0f402a9349e92216ccd Binary files /dev/null and b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcea63624536e89ba6b668d06bdf9233e8236a06 Binary files /dev/null and b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..242ff8445f2f5bf469908a37834a14cadaef6f51 Binary files /dev/null and b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78573582976cb3a44c2fa3411d09af01d57bc444 Binary files /dev/null and b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f92de6e64f70c365e8c32c05d92f19c86a04aefc Binary files /dev/null and b/segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc37411a1af6cbb55933a1b0708250d0592fae7 --- /dev/null +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -0,0 +1,220 @@ +""" +BSD 3-Clause License + +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ["DeepLabV3Decoder"] + + +class DeepLabV3Decoder(nn.Sequential): + def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)): + super().__init__( + ASPP(in_channels, out_channels, atrous_rates), + nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + self.out_channels = out_channels + + def forward(self, *features): + return super().forward(features[-1]) + + +class DeepLabV3PlusDecoder(nn.Module): + def __init__( + self, + encoder_channels, + out_channels=256, + atrous_rates=(12, 24, 36), + output_stride=16, + ): + super().__init__() + if output_stride not in {8, 16}: + raise ValueError( + "Output stride should be 8 or 16, got {}.".format(output_stride) + ) + + self.out_channels = out_channels + self.output_stride = output_stride + + self.aspp = nn.Sequential( + ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True), + SeparableConv2d( + out_channels, out_channels, kernel_size=3, padding=1, bias=False + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + scale_factor = 2 if output_stride == 8 else 4 + self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor) + + highres_in_channels = encoder_channels[-4] + highres_out_channels = 48 # proposed by authors of paper + self.block1 = nn.Sequential( + nn.Conv2d( + highres_in_channels, highres_out_channels, kernel_size=1, bias=False + ), + nn.BatchNorm2d(highres_out_channels), + nn.ReLU(), + ) + self.block2 = nn.Sequential( + SeparableConv2d( + highres_out_channels + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, *features): + aspp_features = self.aspp(features[-1]) + aspp_features = self.up(aspp_features) + high_res_features = self.block1(features[-4]) + concat_features = torch.cat([aspp_features, high_res_features], dim=1) + fused_features = self.block2(concat_features) + return fused_features + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + super().__init__( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + +class 
ASPPSeparableConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + super().__init__( + SeparableConv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super().__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x): + size = x.shape[-2:] + for mod in self: + x = mod(x) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels, out_channels, atrous_rates, separable=False): + super(ASPP, self).__init__() + modules = [] + modules.append( + nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv + + modules.append(ASPPConvModule(in_channels, out_channels, rate1)) + modules.append(ASPPConvModule(in_channels, out_channels, rate2)) + modules.append(ASPPConvModule(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + +class SeparableConv2d(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + dephtwise_conv = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False, + ) + pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias,) + super().__init__(dephtwise_conv, pointwise_conv) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea5970f7986290051e3a7cd102587267609cb27 --- /dev/null +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -0,0 +1,179 @@ +from torch import nn +from typing import Optional + +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from segmentation_models_pytorch.encoders import get_encoder +from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder + + +class DeepLabV3(SegmentationModel): + """DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation" + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 
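The ASPP block defined just above runs five parallel branches (a 1x1 projection, three atrous 3x3 convolutions, and a global-pooling branch), concatenates them to 5 * out_channels, and projects back to out_channels at the original spatial size. A minimal shape check, assuming the vendored segmentation_models_pytorch package added by this diff is importable from the repository root:

import torch
from segmentation_models_pytorch.decoders.deeplabv3.decoder import ASPP

# .eval(): the pooling branch produces a 1x1 map, which BatchNorm rejects in training mode at batch size 1
aspp = ASPP(in_channels=512, out_channels=256, atrous_rates=(12, 24, 36)).eval()
x = torch.randn(1, 512, 32, 32)        # encoder feature map (N, C, H, W)
with torch.no_grad():
    y = aspp(x)
# five branches -> cat to 5 * 256 channels -> 1x1 projection back to 256
assert y.shape == (1, 256, 32, 32)     # channels reduced, spatial size preserved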
+ Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: A number of convolution filters in ASPP module. Default is 256 + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + Returns: + ``torch.nn.Module``: **DeepLabV3** + + .. _DeeplabV3: + https://arxiv.org/abs/1706.05587 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_channels: int = 256, + in_channels: int = 3, + classes: int = 1, + activation: Optional[str] = None, + upsampling: int = 8, + aux_params: Optional[dict] = None, + ): + super().__init__() + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + output_stride=8, + ) + + self.decoder = DeepLabV3Decoder( + in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + activation=activation, + kernel_size=1, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + +class DeepLabV3Plus(SegmentationModel): + """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable + Convolution for Semantic Image Segmentation" + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) + decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) + decoder_channels: A number of convolution filters in ASPP module. 
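For reference, a minimal sketch of constructing the DeepLabV3 model defined above. The import path follows the package layout added in this diff; encoder_weights=None keeps the encoder randomly initialized so nothing is downloaded, and the chosen input size is an assumption for illustration:

import torch
from segmentation_models_pytorch.decoders.deeplabv3 import DeepLabV3

# .eval() so BatchNorm uses running statistics (the ASPP pooling branch sees a 1x1 map)
model = DeepLabV3(encoder_name="resnet34", encoder_weights=None, classes=3).eval()
x = torch.randn(1, 3, 256, 256)        # H and W must be divisible by the encoder output stride (8 here)
with torch.no_grad():
    logits = model(x)                  # activation=None by default, so these are raw logits
assert logits.shape == (1, 3, 256, 256)  # upsampling=8 restores the input resolution

# With aux_params, forward() additionally returns image-level class logits:
#   model = DeepLabV3("resnet34", encoder_weights=None, classes=3, aux_params=dict(classes=4))
#   masks, labels = model(x)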
Default is 256 + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + Returns: + ``torch.nn.Module``: **DeepLabV3Plus** + + Reference: + https://arxiv.org/abs/1802.02611v3 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + encoder_output_stride: int = 16, + decoder_channels: int = 256, + decoder_atrous_rates: tuple = (12, 24, 36), + in_channels: int = 3, + classes: int = 1, + activation: Optional[str] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + ): + super().__init__() + + if encoder_output_stride not in [8, 16]: + raise ValueError( + "Encoder output stride should be 8 or 16, got {}".format( + encoder_output_stride + ) + ) + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + output_stride=encoder_output_stride, + ) + + self.decoder = DeepLabV3PlusDecoder( + encoder_channels=self.encoder.out_channels, + out_channels=decoder_channels, + atrous_rates=decoder_atrous_rates, + output_stride=encoder_output_stride, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + activation=activation, + kernel_size=1, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None diff --git a/segmentation_models_pytorch/decoders/fpn/__init__.py b/segmentation_models_pytorch/decoders/fpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd72b0731afe38fdcd5fea3ffe9ca18b2aba86d --- /dev/null +++ b/segmentation_models_pytorch/decoders/fpn/__init__.py @@ -0,0 +1 @@ +from .model import FPN diff --git a/segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6840db7461085f402381aca6ce21712c464de1e3 Binary files /dev/null and b/segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0c4dc83084c077da09c4a763a36aacdccd203f Binary files /dev/null and b/segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d9320ce2c813497ffe66318a3b46301cb1c6c19 Binary files /dev/null and b/segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50cb8194901a533558d2656892202c649f26a284 Binary files /dev/null and b/segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/fpn/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/fpn/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7617aa606eaf59f862375c407536d2b4d591fda9 Binary files /dev/null and b/segmentation_models_pytorch/decoders/fpn/__pycache__/model.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/fpn/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/fpn/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f728d9224c70172cdc8a98c3bd45fcb4f3cf51e7 Binary files /dev/null and b/segmentation_models_pytorch/decoders/fpn/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/fpn/decoder.py b/segmentation_models_pytorch/decoders/fpn/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..766190f45c32c6f9e4eb6d03632585c339fb81fa --- /dev/null +++ b/segmentation_models_pytorch/decoders/fpn/decoder.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Conv3x3GNReLU(nn.Module): + def __init__(self, in_channels, out_channels, upsample=False): + super().__init__() + self.upsample = upsample + self.block = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False + ), + nn.GroupNorm(32, out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.block(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + return x + + +class FPNBlock(nn.Module): + def __init__(self, pyramid_channels, skip_channels): + super().__init__() + self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + skip = self.skip_conv(skip) + x = x + skip + return x + + +class SegmentationBlock(nn.Module): + def __init__(self, in_channels, out_channels, n_upsamples=0): + super().__init__() + + blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] + + if n_upsamples > 1: + for _ in range(1, n_upsamples): + blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) + + self.block = nn.Sequential(*blocks) + + def forward(self, x): + return self.block(x) + + +class MergeBlock(nn.Module): + def __init__(self, policy): + super().__init__() + if policy not in ["add", "cat"]: + raise ValueError( + "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy) + ) + self.policy = policy + + def forward(self, x): + if self.policy == "add": + return sum(x) + elif self.policy == "cat": + return torch.cat(x, dim=1) + else: + raise 
ValueError( + "`merge_policy` must be one of: ['add', 'cat'], got {}".format( + self.policy + ) + ) + + +class FPNDecoder(nn.Module): + def __init__( + self, + encoder_channels, + encoder_depth=5, + pyramid_channels=256, + segmentation_channels=128, + dropout=0.2, + merge_policy="add", + ): + super().__init__() + + self.out_channels = ( + segmentation_channels + if merge_policy == "add" + else segmentation_channels * 4 + ) + if encoder_depth < 3: + raise ValueError( + "Encoder depth for FPN decoder cannot be less than 3, got {}.".format( + encoder_depth + ) + ) + + encoder_channels = encoder_channels[::-1] + encoder_channels = encoder_channels[: encoder_depth + 1] + + self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) + self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) + self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) + self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) + + self.seg_blocks = nn.ModuleList( + [ + SegmentationBlock( + pyramid_channels, segmentation_channels, n_upsamples=n_upsamples + ) + for n_upsamples in [3, 2, 1, 0] + ] + ) + + self.merge = MergeBlock(merge_policy) + self.dropout = nn.Dropout2d(p=dropout, inplace=True) + + def forward(self, *features): + c2, c3, c4, c5 = features[-4:] + + p5 = self.p5(c5) + p4 = self.p4(p5, c4) + p3 = self.p3(p4, c3) + p2 = self.p2(p3, c2) + + feature_pyramid = [ + seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2]) + ] + x = self.merge(feature_pyramid) + x = self.dropout(x) + + return x diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e6f67718851d216591d1781e14b0189210ae94 --- /dev/null +++ b/segmentation_models_pytorch/decoders/fpn/model.py @@ -0,0 +1,107 @@ +from typing import Optional, Union + +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from segmentation_models_pytorch.encoders import get_encoder +from .decoder import FPNDecoder + + +class FPN(SegmentationModel): + """FPN_ is a fully convolution neural network for image semantic segmentation. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_ + decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_ + decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** + and **cat** + decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. 
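A usage sketch for the FPN model whose decoder is defined above: with decoder_merge_policy="add" the four pyramid levels are summed, so the decoder keeps segmentation_channels outputs, while "cat" concatenates them to 4 * segmentation_channels. Assuming the vendored package from this diff is importable:

import torch
from segmentation_models_pytorch.decoders.fpn import FPN

model = FPN(
    encoder_name="resnet34",
    encoder_weights=None,              # random init, nothing downloaded
    decoder_merge_policy="add",        # "cat" would widen the decoder to 4 * segmentation_channels
    classes=1,
).eval()
x = torch.randn(2, 3, 224, 224)        # resnet34 has output stride 32, so H and W must be divisible by 32
with torch.no_grad():
    mask = model(x)
assert mask.shape == (2, 1, 224, 224)  # upsampling=4 brings the 1/4-resolution pyramid back to input size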
+ Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **FPN** + + .. _FPN: + http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_pyramid_channels: int = 256, + decoder_segmentation_channels: int = 128, + decoder_merge_policy: str = "add", + decoder_dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + activation: Optional[str] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + ): + super().__init__() + + # validate input params + if encoder_name.startswith("mit_b") and encoder_depth != 5: + raise ValueError( + "Encoder {} support only encoder_depth=5".format(encoder_name) + ) + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = FPNDecoder( + encoder_channels=self.encoder.out_channels, + encoder_depth=encoder_depth, + pyramid_channels=decoder_pyramid_channels, + segmentation_channels=decoder_segmentation_channels, + dropout=decoder_dropout, + merge_policy=decoder_merge_policy, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + activation=activation, + kernel_size=1, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "fpn-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/decoders/linknet/__init__.py b/segmentation_models_pytorch/decoders/linknet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69d2ff663a394d917a324c8e2c17d899f06e416a --- /dev/null +++ b/segmentation_models_pytorch/decoders/linknet/__init__.py @@ -0,0 +1 @@ +from .model import Linknet diff --git a/segmentation_models_pytorch/decoders/linknet/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/linknet/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0287b3f4feb07c85c52c2741165d3b1a1f33f2e9 Binary files /dev/null and b/segmentation_models_pytorch/decoders/linknet/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/linknet/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/linknet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690b7e807e0b3429e586ec77751088acb9796bd7 Binary files /dev/null and b/segmentation_models_pytorch/decoders/linknet/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/segmentation_models_pytorch/decoders/linknet/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/linknet/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28fa83245bf2ff98c695a8c586620c7696541e2c Binary files /dev/null and b/segmentation_models_pytorch/decoders/linknet/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/linknet/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/linknet/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03901c43955cdbb4aab46c11d3dcdb6d164a5cef Binary files /dev/null and b/segmentation_models_pytorch/decoders/linknet/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/linknet/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/linknet/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47424d3fde9931310a6962fa476b0d7a6489660d Binary files /dev/null and b/segmentation_models_pytorch/decoders/linknet/__pycache__/model.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/linknet/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/linknet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ca1c8d2f9c5e7e9b1eead81a3364bf907cbe38 Binary files /dev/null and b/segmentation_models_pytorch/decoders/linknet/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..70bc4371bb720af0a970696d011a35ff3afc54ce --- /dev/null +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -0,0 +1,82 @@ +import torch.nn as nn + +from segmentation_models_pytorch.base import modules + + +class TransposeX2(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + super().__init__() + layers = [ + nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=4, stride=2, padding=1 + ), + nn.ReLU(inplace=True), + ] + + if use_batchnorm: + layers.insert(1, nn.BatchNorm2d(out_channels)) + + super().__init__(*layers) + + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + super().__init__() + + self.block = nn.Sequential( + modules.Conv2dReLU( + in_channels, + in_channels // 4, + kernel_size=1, + use_batchnorm=use_batchnorm, + ), + TransposeX2( + in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm + ), + modules.Conv2dReLU( + in_channels // 4, + out_channels, + kernel_size=1, + use_batchnorm=use_batchnorm, + ), + ) + + def forward(self, x, skip=None): + x = self.block(x) + if skip is not None: + x = x + skip + return x + + +class LinknetDecoder(nn.Module): + def __init__( + self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True, + ): + super().__init__() + + # remove first skip + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + channels = list(encoder_channels) + [prefinal_channels] + + self.blocks = nn.ModuleList( + [ + DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm) + for i in range(n_blocks) + ] + ) + + def forward(self, *features): + features = features[1:] # remove first skip + features = 
features[::-1] # reverse channels to start from head of encoder + + x = features[0] + skips = features[1:] + + for i, decoder_block in enumerate(self.blocks): + skip = skips[i] if i < len(skips) else None + x = decoder_block(x, skip) + + return x diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c3139fdc4db0d5dddfbf292b76c0cc8fccb873 --- /dev/null +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -0,0 +1,98 @@ +from typing import Optional, Union + +from segmentation_models_pytorch.base import ( + SegmentationHead, + SegmentationModel, + ClassificationHead, +) +from segmentation_models_pytorch.encoders import get_encoder +from .decoder import LinknetDecoder + + +class Linknet(SegmentationModel): + """Linknet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* + and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial + resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *sum* + for fusing decoder blocks with skip connections. + + Note: + This implementation by default has 4 skip connections (original - 3). + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **Linknet** + + .. 
_Linknet: + https://arxiv.org/abs/1707.03718 + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_batchnorm: bool = True, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + ): + super().__init__() + + if encoder_name.startswith("mit_b"): + raise ValueError( + "Encoder `{}` is not supported for Linknet".format(encoder_name) + ) + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = LinknetDecoder( + encoder_channels=self.encoder.out_channels, + n_blocks=encoder_depth, + prefinal_channels=32, + use_batchnorm=decoder_use_batchnorm, + ) + + self.segmentation_head = SegmentationHead( + in_channels=32, out_channels=classes, activation=activation, kernel_size=1 + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "link-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/decoders/manet/__init__.py b/segmentation_models_pytorch/decoders/manet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3bdc788d300d6aa95b3894f2bba78214fd437e3 --- /dev/null +++ b/segmentation_models_pytorch/decoders/manet/__init__.py @@ -0,0 +1 @@ +from .model import MAnet diff --git a/segmentation_models_pytorch/decoders/manet/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/manet/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ff336a64a5ccb6652e7be6123f0e561dc9b30fa Binary files /dev/null and b/segmentation_models_pytorch/decoders/manet/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/manet/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/manet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b968e05c4201956a06772af702d8cb2198c50f66 Binary files /dev/null and b/segmentation_models_pytorch/decoders/manet/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/manet/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/manet/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e38fabc56640e5ad51311581081cf8b0ff81e48b Binary files /dev/null and b/segmentation_models_pytorch/decoders/manet/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/manet/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/manet/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3cb7de92c92c26bb9cea3b9a9470e6d38a71cd8 Binary files /dev/null and b/segmentation_models_pytorch/decoders/manet/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/manet/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/manet/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..032f000472f6b72c86278aed8e5463b7dfb77076 Binary files /dev/null and b/segmentation_models_pytorch/decoders/manet/__pycache__/model.cpython-37.pyc differ diff --git 
a/segmentation_models_pytorch/decoders/manet/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/manet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a3759064c35b7d2fc253fcd83a0f95985c14e05 Binary files /dev/null and b/segmentation_models_pytorch/decoders/manet/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1227abd708b854d37c003d94234715de03d164b2 --- /dev/null +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -0,0 +1,188 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from segmentation_models_pytorch.base import modules as md + + +class PAB(nn.Module): + def __init__(self, in_channels, out_channels, pab_channels=64): + super(PAB, self).__init__() + # Series of 1x1 conv to generate attention feature maps + self.pab_channels = pab_channels + self.in_channels = in_channels + self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1) + self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1) + self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.map_softmax = nn.Softmax(dim=1) + self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x): + bsize = x.size()[0] + h = x.size()[2] + w = x.size()[3] + x_top = self.top_conv(x) + x_center = self.center_conv(x) + x_bottom = self.bottom_conv(x) + + x_top = x_top.flatten(2) + x_center = x_center.flatten(2).transpose(1, 2) + x_bottom = x_bottom.flatten(2).transpose(1, 2) + + sp_map = torch.matmul(x_center, x_top) + sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w, h * w) + sp_map = torch.matmul(sp_map, x_bottom) + sp_map = sp_map.reshape(bsize, self.in_channels, h, w) + x = x + sp_map + x = self.out_conv(x) + return x + + +class MFAB(nn.Module): + def __init__( + self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16 + ): + # MFAB is just a modified version of SE-blocks, one for skip, one for input + super(MFAB, self).__init__() + self.hl_conv = nn.Sequential( + md.Conv2dReLU( + in_channels, + in_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ), + md.Conv2dReLU( + in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm, + ), + ) + reduced_channels = max(1, skip_channels // reduction) + self.SE_ll = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(skip_channels, reduced_channels, 1), + nn.ReLU(inplace=True), + nn.Conv2d(reduced_channels, skip_channels, 1), + nn.Sigmoid(), + ) + self.SE_hl = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(skip_channels, reduced_channels, 1), + nn.ReLU(inplace=True), + nn.Conv2d(reduced_channels, skip_channels, 1), + nn.Sigmoid(), + ) + self.conv1 = md.Conv2dReLU( + skip_channels + + skip_channels, # we transform C-prime form high level to C from skip connection + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + def forward(self, x, skip=None): + x = self.hl_conv(x) + x = F.interpolate(x, scale_factor=2, mode="nearest") + attention_hl = self.SE_hl(x) + if skip is not None: + attention_ll = self.SE_ll(skip) + attention_hl = attention_hl + attention_ll + x = x * attention_hl 
+ x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True): + super().__init__() + self.conv1 = md.Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class MAnetDecoder(nn.Module): + def __init__( + self, + encoder_channels, + decoder_channels, + n_blocks=5, + reduction=16, + use_batchnorm=True, + pab_channels=64, + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( + n_blocks, len(decoder_channels) + ) + ) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + in_channels = [head_channels] + list(decoder_channels[:-1]) + skip_channels = list(encoder_channels[1:]) + [0] + out_channels = decoder_channels + + self.center = PAB(head_channels, head_channels, pab_channels=pab_channels) + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here + blocks = [ + MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) + if skip_ch > 0 + else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) + for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) + ] + # for the last we dont have skip connection -> use simple decoder block + self.blocks = nn.ModuleList(blocks) + + def forward(self, *features): + + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + head = features[0] + skips = features[1:] + + x = self.center(head) + for i, decoder_block in enumerate(self.blocks): + skip = skips[i] if i < len(skips) else None + x = decoder_block(x, skip) + + return x diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..08e64a2aba8eb7d18eb89e92541ea5f14c678f55 --- /dev/null +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -0,0 +1,102 @@ +from typing import Optional, Union, List + +from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import MAnetDecoder + + +class MAnet(SegmentationModel): + """MAnet_ : Multi-scale Attention Net. 
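A brief sketch of the MAnet model built from the decoder above, assuming the vendored package is importable and using an illustrative 256x256 input. Note that PAB forms an (H*W) x (H*W) affinity matrix on the deepest feature map, so its memory cost grows quadratically with the bottleneck resolution:

import torch
from segmentation_models_pytorch.decoders.manet import MAnet

model = MAnet(encoder_name="resnet34", encoder_weights=None, classes=3).eval()
x = torch.randn(1, 3, 256, 256)        # must be divisible by the encoder output stride (32)
with torch.no_grad():
    masks = model(x)
assert masks.shape == (1, 3, 256, 256)
# At 256x256 input the bottleneck is 8x8, so the PAB affinity map has (8*8)^2 = 4096 entries per image;
# doubling the input side quadruples H*W and multiplies that map by 16.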
The MA-Net can capture rich contextual dependencies based on + the attention mechanism, using two blocks: + - Position-wise Attention Block (PAB), which captures the spatial dependencies between pixels in a global view + - Multi-scale Fusion Attention Block (MFAB), which captures the channel dependencies between any feature map by + multi-scale semantic feature fusion + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. + Length of the list should be the same as **encoder_depth** + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + decoder_pab_channels: A number of channels for PAB module in decoder. + Default is 64. + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **MAnet** + + .. 
_MAnet: + https://ieeexplore.ieee.org/abstract/document/9201310 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_pab_channels: int = 64, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + ): + super().__init__() + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = MAnetDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_batchnorm=decoder_use_batchnorm, + pab_channels=decoder_pab_channels, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "manet-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/decoders/pan/__init__.py b/segmentation_models_pytorch/decoders/pan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46327c35a041683dd24b6522ff75a4ac9559d60b --- /dev/null +++ b/segmentation_models_pytorch/decoders/pan/__init__.py @@ -0,0 +1 @@ +from .model import PAN diff --git a/segmentation_models_pytorch/decoders/pan/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/pan/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5f69a13b95d0db6d0b235017eac0e3dde1fdfd3 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pan/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/pan/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/pan/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86cb4f15a85d798f5fb73e84ed5a038825bd0031 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pan/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/pan/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/pan/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d1962802c587a19417d7c80e8b1b51cca9db048 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pan/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/pan/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/pan/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82dbd29ab5bdce658342627a0cd2c51397ea81ba Binary files /dev/null and b/segmentation_models_pytorch/decoders/pan/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/pan/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/pan/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44130b2d20d0ce191356982c32262606514e1510 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pan/__pycache__/model.cpython-37.pyc differ diff --git 
a/segmentation_models_pytorch/decoders/pan/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/pan/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d11ea8ad89a5680981b22b14f9c662337850bf4e Binary files /dev/null and b/segmentation_models_pytorch/decoders/pan/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8f8675a1b5dc6ed3447aef63caeb7d96331529 --- /dev/null +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvBnRelu(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + add_relu: bool = True, + interpolate: bool = False, + ): + super(ConvBnRelu, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=groups, + ) + self.add_relu = add_relu + self.interpolate = interpolate + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.add_relu: + x = self.activation(x) + if self.interpolate: + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + return x + + +class FPABlock(nn.Module): + def __init__(self, in_channels, out_channels, upscale_mode="bilinear"): + super(FPABlock, self).__init__() + + self.upscale_mode = upscale_mode + if self.upscale_mode == "bilinear": + self.align_corners = True + else: + self.align_corners = False + + # global pooling branch + self.branch1 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ) + + # midddle branch + self.mid = nn.Sequential( + ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + ) + self.down1 = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + ConvBnRelu( + in_channels=in_channels, + out_channels=1, + kernel_size=7, + stride=1, + padding=3, + ), + ) + self.down2 = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + ConvBnRelu( + in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2 + ), + ) + self.down3 = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + ConvBnRelu( + in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1 + ), + ConvBnRelu( + in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1 + ), + ) + self.conv2 = ConvBnRelu( + in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2 + ) + self.conv1 = ConvBnRelu( + in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3 + ) + + def forward(self, x): + h, w = x.size(2), x.size(3) + b1 = self.branch1(x) + upscale_parameters = dict( + mode=self.upscale_mode, align_corners=self.align_corners + ) + b1 = F.interpolate(b1, size=(h, w), **upscale_parameters) + + mid = self.mid(x) + x1 = self.down1(x) + x2 = self.down2(x1) + x3 = self.down3(x2) + x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters) + + x2 = self.conv2(x2) + x = x2 + x3 + x = F.interpolate(x, size=(h // 2, w // 2), 
**upscale_parameters) + + x1 = self.conv1(x1) + x = x + x1 + x = F.interpolate(x, size=(h, w), **upscale_parameters) + + x = torch.mul(x, mid) + x = x + b1 + return x + + +class GAUBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear" + ): + super(GAUBlock, self).__init__() + + self.upscale_mode = upscale_mode + self.align_corners = True if upscale_mode == "bilinear" else None + + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + ConvBnRelu( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + add_relu=False, + ), + nn.Sigmoid(), + ) + self.conv2 = ConvBnRelu( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + def forward(self, x, y): + """ + Args: + x: low level feature + y: high level feature + """ + h, w = x.size(2), x.size(3) + y_up = F.interpolate( + y, size=(h, w), mode=self.upscale_mode, align_corners=self.align_corners + ) + x = self.conv2(x) + y = self.conv1(y) + z = torch.mul(x, y) + return y_up + z + + +class PANDecoder(nn.Module): + def __init__( + self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear" + ): + super().__init__() + + self.fpa = FPABlock( + in_channels=encoder_channels[-1], out_channels=decoder_channels + ) + self.gau3 = GAUBlock( + in_channels=encoder_channels[-2], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + ) + self.gau2 = GAUBlock( + in_channels=encoder_channels[-3], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + ) + self.gau1 = GAUBlock( + in_channels=encoder_channels[-4], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + ) + + def forward(self, *features): + bottleneck = features[-1] + x5 = self.fpa(bottleneck) # 1/32 + x4 = self.gau3(features[-2], x5) # 1/16 + x3 = self.gau2(features[-3], x4) # 1/8 + x2 = self.gau1(features[-4], x3) # 1/4 + + return x2 diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8086d024ea9768832cbcd81addd1196cfb34c3cf --- /dev/null +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -0,0 +1,100 @@ +from typing import Optional, Union + +from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import PANDecoder + + +class PAN(SegmentationModel): + """Implementation of PAN_ (Pyramid Attention Network). + + Note: + Currently works with shape of input tensor >= [B x C x 128 x 128] for pytorch <= 1.1.0 + and with shape of input tensor >= [B x C x 256 x 256] for pytorch == 1.3.1 + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer. + Doesn't work with ***ception***, **vgg***, **densenet*`** backbones.Default is 16. 
+ decoder_channels: A number of convolution layer filters in decoder blocks + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **PAN** + + .. _PAN: + https://arxiv.org/abs/1805.10180 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_weights: Optional[str] = "imagenet", + encoder_output_stride: int = 16, + decoder_channels: int = 32, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + ): + super().__init__() + + if encoder_output_stride not in [16, 32]: + raise ValueError( + "PAN support output stride 16 or 32, got {}".format( + encoder_output_stride + ) + ) + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=5, + weights=encoder_weights, + output_stride=encoder_output_stride, + ) + + self.decoder = PANDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels, + out_channels=classes, + activation=activation, + kernel_size=3, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "pan-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/decoders/pspnet/__init__.py b/segmentation_models_pytorch/decoders/pspnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7eacc398ffb131b50a1907a448de43cc684def --- /dev/null +++ b/segmentation_models_pytorch/decoders/pspnet/__init__.py @@ -0,0 +1 @@ +from .model import PSPNet diff --git a/segmentation_models_pytorch/decoders/pspnet/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/pspnet/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00b7b745a8ffdf49d956fc40cf45c9c9d1a6f711 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pspnet/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/pspnet/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/pspnet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e94ac7a988288dc1a0cb84b9ea66ffb2f4c39a3 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pspnet/__pycache__/__init__.cpython-39.pyc differ diff --git 
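A hedged construction sketch for the PAN model above. The encoder name and image size are placeholders; output stride 32 is chosen because stride 16 additionally requires an encoder that implements make_dilated().

import torch
from segmentation_models_pytorch.decoders.pan.model import PAN

model = PAN(
    encoder_name="mit_b0",       # placeholder: must be a name present in the encoder registry
    encoder_weights=None,
    encoder_output_stride=32,    # only 16 or 32 pass the check in __init__
    decoder_channels=32,
    classes=2,
)
model.eval()
with torch.no_grad():
    masks = model(torch.randn(1, 3, 256, 256))   # per the Note above, inputs should be at least ~128x128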
a/segmentation_models_pytorch/decoders/pspnet/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/pspnet/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbb397287c57c9aada7de98490a0ebd9857f856b Binary files /dev/null and b/segmentation_models_pytorch/decoders/pspnet/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/pspnet/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/pspnet/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0768f0c9ecb4763b255b894e725e720c77849ad4 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pspnet/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/pspnet/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/pspnet/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b53e87dbfabad646323e63313f81bade88151a9 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pspnet/__pycache__/model.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/pspnet/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/pspnet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5298f01319aa391904bf9b820d274489f89d4298 Binary files /dev/null and b/segmentation_models_pytorch/decoders/pspnet/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfd52e72404fd447cd58b365a592246306a3263 --- /dev/null +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from segmentation_models_pytorch.base import modules + + +class PSPBlock(nn.Module): + def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True): + super().__init__() + if pool_size == 1: + use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape + self.pool = nn.Sequential( + nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), + modules.Conv2dReLU( + in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm + ), + ) + + def forward(self, x): + h, w = x.size(2), x.size(3) + x = self.pool(x) + x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True) + return x + + +class PSPModule(nn.Module): + def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): + super().__init__() + + self.blocks = nn.ModuleList( + [ + PSPBlock( + in_channels, + in_channels // len(sizes), + size, + use_bathcnorm=use_bathcnorm, + ) + for size in sizes + ] + ) + + def forward(self, x): + xs = [block(x) for block in self.blocks] + [x] + x = torch.cat(xs, dim=1) + return x + + +class PSPDecoder(nn.Module): + def __init__( + self, encoder_channels, use_batchnorm=True, out_channels=512, dropout=0.2, + ): + super().__init__() + + self.psp = PSPModule( + in_channels=encoder_channels[-1], + sizes=(1, 2, 3, 6), + use_bathcnorm=use_batchnorm, + ) + + self.conv = modules.Conv2dReLU( + in_channels=encoder_channels[-1] * 2, + out_channels=out_channels, + kernel_size=1, + use_batchnorm=use_batchnorm, + ) + + self.dropout = nn.Dropout2d(p=dropout) + + def forward(self, *features): + x = features[-1] + x = self.psp(x) + x = 
self.conv(x) + x = self.dropout(x) + + return x diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9997f82bd77e4e8ac44e7550daa53739f1f828 --- /dev/null +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -0,0 +1,101 @@ +from typing import Optional, Union + +from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import PSPDecoder + + +class PSPNet(SegmentationModel): + """PSPNet_ is a fully convolution neural network for image semantic segmentation. Consist of + *encoder* and *Spatial Pyramid* (decoder). Spatial Pyramid build on top of encoder and does not + use "fine-features" (features of high spatial resolution). PSPNet can be used for multiclass segmentation + of high resolution images, however it is not good for detecting small objects and producing accurate, + pixel-level mask. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + psp_out_channels: A number of filters in Spatial Pyramid + psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **PSPNet** + + .. 
_PSPNet: + https://arxiv.org/abs/1612.01105 + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_weights: Optional[str] = "imagenet", + encoder_depth: int = 3, + psp_out_channels: int = 512, + psp_use_batchnorm: bool = True, + psp_dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + upsampling: int = 8, + aux_params: Optional[dict] = None, + ): + super().__init__() + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = PSPDecoder( + encoder_channels=self.encoder.out_channels, + use_batchnorm=psp_use_batchnorm, + out_channels=psp_out_channels, + dropout=psp_dropout, + ) + + self.segmentation_head = SegmentationHead( + in_channels=psp_out_channels, + out_channels=classes, + kernel_size=3, + activation=activation, + upsampling=upsampling, + ) + + if aux_params: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "psp-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/decoders/unet/__init__.py b/segmentation_models_pytorch/decoders/unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9a367cd7338999a742961fbc1a93289a6380da --- /dev/null +++ b/segmentation_models_pytorch/decoders/unet/__init__.py @@ -0,0 +1 @@ +from .model import Unet diff --git a/segmentation_models_pytorch/decoders/unet/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/unet/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b4f1fb221f12c75505a1c3736a605690ce5b08 Binary files /dev/null and b/segmentation_models_pytorch/decoders/unet/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/unet/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/unet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03e2442976919fa863243ad176db68978ecef118 Binary files /dev/null and b/segmentation_models_pytorch/decoders/unet/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/unet/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/unet/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a3330bc41cfcda2082b784a0decfe5c682d82cf Binary files /dev/null and b/segmentation_models_pytorch/decoders/unet/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/unet/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/unet/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e98b654571ebc155050068df2b577920e0aad721 Binary files /dev/null and b/segmentation_models_pytorch/decoders/unet/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/unet/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/unet/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f78fddfae0582e0c305f7365b92503cf330baee Binary files /dev/null and b/segmentation_models_pytorch/decoders/unet/__pycache__/model.cpython-37.pyc differ diff --git 
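A sketch of the PSPNet defaults documented above. The encoder name is a placeholder (this trimmed copy registers only the mix_transformer encoders), and the comments spell out why encoder_depth=3 pairs with upsampling=8.

import torch
from segmentation_models_pytorch.decoders.pspnet.model import PSPNet

model = PSPNet(
    encoder_name="resnet34",   # placeholder: assumes the full upstream encoder registry
    encoder_weights=None,
    encoder_depth=3,           # deepest feature map used is at 1/8 of the input resolution
    psp_out_channels=512,
    psp_dropout=0.2,
    upsampling=8,              # 8 x (1/8) restores the input height and width
    classes=3,
)
model.eval()
with torch.no_grad():
    mask = model(torch.randn(1, 3, 256, 256))   # expected: (1, 3, 256, 256)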
a/segmentation_models_pytorch/decoders/unet/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/unet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e676b2b05b4223d5534e9fbf861f59df89bb12b Binary files /dev/null and b/segmentation_models_pytorch/decoders/unet/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6cbcfcc2df7f6dbd4e2f5c3525dc803af43c7ed9 --- /dev/null +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from segmentation_models_pytorch.base import modules as md + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + use_batchnorm=True, + attention_type=None, + ): + super().__init__() + self.conv1 = md.Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention1 = md.Attention( + attention_type, in_channels=in_channels + skip_channels + ) + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention2 = md.Attention(attention_type, in_channels=out_channels) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + return x + + +class CenterBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + conv1 = md.Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + super().__init__(conv1, conv2) + + +class UnetDecoder(nn.Module): + def __init__( + self, + encoder_channels, + decoder_channels, + n_blocks=5, + use_batchnorm=True, + attention_type=None, + center=False, + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( + n_blocks, len(decoder_channels) + ) + ) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + in_channels = [head_channels] + list(decoder_channels[:-1]) + skip_channels = list(encoder_channels[1:]) + [0] + out_channels = decoder_channels + + if center: + self.center = CenterBlock( + head_channels, head_channels, use_batchnorm=use_batchnorm + ) + else: + self.center = nn.Identity() + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) + blocks = [ + DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) + for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, *features): + + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + head = features[0] + skips 
= features[1:] + + x = self.center(head) + for i, decoder_block in enumerate(self.blocks): + skip = skips[i] if i < len(skips) else None + x = decoder_block(x, skip) + + return x diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..46528c5a9b376bd739c45d1bce3ba823b1a869fa --- /dev/null +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -0,0 +1,102 @@ +from typing import Optional, Union, List + +from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import UnetDecoder + + +class Unet(SegmentationModel): + """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* + and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial + resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* + for fusing decoder blocks with skip connections. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. + Length of the list should be the same as **encoder_depth** + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + decoder_attention_type: Attention module used in decoder of the model. Available options are + **None** and **scse** (https://arxiv.org/abs/1808.08127). + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: Unet + + .. 
_Unet: + https://arxiv.org/abs/1505.04597 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + ): + super().__init__() + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = UnetDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_batchnorm=decoder_use_batchnorm, + center=True if encoder_name.startswith("vgg") else False, + attention_type=decoder_attention_type, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "u-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__init__.py b/segmentation_models_pytorch/decoders/unetplusplus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bda62b70a30d92622616f7279d153bc4b68f6b54 --- /dev/null +++ b/segmentation_models_pytorch/decoders/unetplusplus/__init__.py @@ -0,0 +1 @@ +from .model import UnetPlusPlus diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f626fc228a4633401b36f59868f5870e935995c1 Binary files /dev/null and b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..020af1a88452d51dab1cb488552c6067df879f3f Binary files /dev/null and b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/decoder.cpython-37.pyc b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..137cf108a70d8db9ebefb2bd9e3303f38f79bd4a Binary files /dev/null and b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/decoder.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/decoder.cpython-39.pyc b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62a82ade285835dc2af7e0d2f0c809b914998759 Binary files /dev/null and b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/decoder.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/model.cpython-37.pyc b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/model.cpython-37.pyc new file mode 100644 
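A sketch combining two of the options documented above: scse decoder attention and an auxiliary classification head. The encoder name and class counts are placeholders; with aux_params set, the upstream SegmentationModel base class returns a (mask, label) pair.

import torch
from segmentation_models_pytorch.decoders.unet.model import Unet

model = Unet(
    encoder_name="resnet34",               # placeholder: assumes the upstream encoder registry
    encoder_weights=None,
    decoder_attention_type="scse",         # SCSE attention inside every decoder block
    classes=1,
    aux_params={"classes": 5, "pooling": "avg", "dropout": 0.2},
)
model.eval()
with torch.no_grad():
    mask, label = model(torch.randn(1, 3, 224, 224))
# mask: (1, 1, 224, 224) pixel logits; label: (1, 5) image-level logits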
index 0000000000000000000000000000000000000000..c0d3210afcb8bf3183c5cf7ee44ff801808d5c1a Binary files /dev/null and b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/model.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/model.cpython-39.pyc b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edcab30b0a61f9aab2653747e6624a9955121fcf Binary files /dev/null and b/segmentation_models_pytorch/decoders/unetplusplus/__pycache__/model.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3d41cb1306db13ae52cd290ae769907f550592e6 --- /dev/null +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from segmentation_models_pytorch.base import modules as md + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + use_batchnorm=True, + attention_type=None, + ): + super().__init__() + self.conv1 = md.Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention1 = md.Attention( + attention_type, in_channels=in_channels + skip_channels + ) + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention2 = md.Attention(attention_type, in_channels=out_channels) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + return x + + +class CenterBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + conv1 = md.Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + super().__init__(conv1, conv2) + + +class UnetPlusPlusDecoder(nn.Module): + def __init__( + self, + encoder_channels, + decoder_channels, + n_blocks=5, + use_batchnorm=True, + attention_type=None, + center=False, + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( + n_blocks, len(decoder_channels) + ) + ) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + self.in_channels = [head_channels] + list(decoder_channels[:-1]) + self.skip_channels = list(encoder_channels[1:]) + [0] + self.out_channels = decoder_channels + if center: + self.center = CenterBlock( + head_channels, head_channels, use_batchnorm=use_batchnorm + ) + else: + self.center = nn.Identity() + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) + + blocks = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(layer_idx + 1): + 
if depth_idx == 0: + in_ch = self.in_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) + out_ch = self.out_channels[layer_idx] + else: + out_ch = self.skip_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * ( + layer_idx + 1 - depth_idx + ) + in_ch = self.skip_channels[layer_idx - 1] + blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( + in_ch, skip_ch, out_ch, **kwargs + ) + blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock( + self.in_channels[-1], 0, self.out_channels[-1], **kwargs + ) + self.blocks = nn.ModuleDict(blocks) + self.depth = len(self.in_channels) - 1 + + def forward(self, *features): + + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + # start building dense connections + dense_x = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(self.depth - layer_idx): + if layer_idx == 0: + output = self.blocks[f"x_{depth_idx}_{depth_idx}"]( + features[depth_idx], features[depth_idx + 1] + ) + dense_x[f"x_{depth_idx}_{depth_idx}"] = output + else: + dense_l_i = depth_idx + layer_idx + cat_features = [ + dense_x[f"x_{idx}_{dense_l_i}"] + for idx in range(depth_idx + 1, dense_l_i + 1) + ] + cat_features = torch.cat( + cat_features + [features[dense_l_i + 1]], dim=1 + ) + dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ + f"x_{depth_idx}_{dense_l_i}" + ](dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features) + dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( + dense_x[f"x_{0}_{self.depth-1}"] + ) + return dense_x[f"x_{0}_{self.depth}"] diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py new file mode 100644 index 0000000000000000000000000000000000000000..60d591f0ddb41b84324d796fec9dbe95364fdbfa --- /dev/null +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -0,0 +1,107 @@ +from typing import Optional, Union, List + +from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import UnetPlusPlusDecoder + + +class UnetPlusPlus(SegmentationModel): + """Unet++ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* + and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial + resolution (skip connections) which are used by decoder to define accurate segmentation mask. Decoder of + Unet++ is more complex than in usual Unet. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 
+ Length of the list should be the same as **encoder_depth** + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + decoder_attention_type: Attention module used in decoder of the model. + Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **Unet++** + + Reference: + https://arxiv.org/abs/1807.10165 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + ): + super().__init__() + + if encoder_name.startswith("mit_b"): + raise ValueError( + "UnetPlusPlus is not support encoder_name={}".format(encoder_name) + ) + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = UnetPlusPlusDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_batchnorm=decoder_use_batchnorm, + center=True if encoder_name.startswith("vgg") else False, + attention_type=decoder_attention_type, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "unetplusplus-{}".format(encoder_name) + self.initialize() diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37988514a9d617ce4fc26b976a051d6d75d9073c --- /dev/null +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -0,0 +1,126 @@ +# import timm +import functools +import torch.utils.model_zoo as model_zoo + +# from .resnet import resnet_encoders +# from .dpn import dpn_encoders +# from .vgg import vgg_encoders +# from .senet import senet_encoders +# from .densenet import densenet_encoders +# from .inceptionresnetv2 import inceptionresnetv2_encoders +# from .inceptionv4 import inceptionv4_encoders +# from .efficientnet import 
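One behavioural detail of the UnetPlusPlus constructor above is worth calling out: transformer ("mit_b*") encoders are rejected up front, presumably because the dense decoder needs real skip features at every resolution. A small sketch of that guard follows; it runs even in this trimmed copy, since the check happens before the encoder is built.

from segmentation_models_pytorch.decoders.unetplusplus.model import UnetPlusPlus

try:
    UnetPlusPlus(encoder_name="mit_b5", encoder_weights=None)
except ValueError as err:
    print(err)   # -> UnetPlusPlus is not support encoder_name=mit_b5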
efficient_net_encoders +# from .mobilenet import mobilenet_encoders +# from .xception import xception_encoders +# from .timm_efficientnet import timm_efficientnet_encoders +# from .timm_resnest import timm_resnest_encoders +# from .timm_res2net import timm_res2net_encoders +# from .timm_regnet import timm_regnet_encoders +# from .timm_sknet import timm_sknet_encoders +# from .timm_mobilenetv3 import timm_mobilenetv3_encoders +# from .timm_gernet import timm_gernet_encoders +from .mix_transformer import mix_transformer_encoders + +# from .timm_universal import TimmUniversalEncoder + +# from ._preprocessing import preprocess_input + +encoders = {} +# encoders.update(resnet_encoders) +# encoders.update(dpn_encoders) +# encoders.update(vgg_encoders) +# encoders.update(senet_encoders) +# encoders.update(densenet_encoders) +# encoders.update(inceptionresnetv2_encoders) +# encoders.update(inceptionv4_encoders) +# encoders.update(efficient_net_encoders) +# encoders.update(mobilenet_encoders) +# encoders.update(xception_encoders) +# encoders.update(timm_efficientnet_encoders) +# encoders.update(timm_resnest_encoders) +# encoders.update(timm_res2net_encoders) +# encoders.update(timm_regnet_encoders) +# encoders.update(timm_sknet_encoders) +# encoders.update(timm_mobilenetv3_encoders) +# encoders.update(timm_gernet_encoders) +encoders.update(mix_transformer_encoders) + + +def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): + + if name.startswith("tu-"): + name = name[3:] + encoder = TimmUniversalEncoder( + name=name, + in_channels=in_channels, + depth=depth, + output_stride=output_stride, + pretrained=weights is not None, + **kwargs, + ) + return encoder + + try: + Encoder = encoders[name]["encoder"] + except KeyError: + raise KeyError( + "Wrong encoder name `{}`, supported encoders: {}".format( + name, list(encoders.keys()) + ) + ) + + params = encoders[name]["params"] + params.update(depth=depth) + encoder = Encoder(**params) + + if weights is not None: + try: + settings = encoders[name]["pretrained_settings"][weights] + except KeyError: + raise KeyError( + "Wrong pretrained weights `{}` for encoder `{}`. 
Available options are: {}".format( + weights, name, list(encoders[name]["pretrained_settings"].keys()), + ) + ) + encoder.load_state_dict(model_zoo.load_url(settings["url"])) + + encoder.set_in_channels(in_channels, pretrained=weights is not None) + if output_stride != 32: + encoder.make_dilated(output_stride) + + return encoder + + +def get_encoder_names(): + return list(encoders.keys()) + + +def get_preprocessing_params(encoder_name, pretrained="imagenet"): + + if encoder_name.startswith("tu-"): + encoder_name = encoder_name[3:] + if encoder_name not in timm.models.registry._model_has_pretrained: + raise ValueError( + f"{encoder_name} does not have pretrained weights and preprocessing parameters" + ) + settings = timm.models.registry._model_default_cfgs[encoder_name] + else: + all_settings = encoders[encoder_name]["pretrained_settings"] + if pretrained not in all_settings.keys(): + raise ValueError( + "Available pretrained options {}".format(all_settings.keys()) + ) + settings = all_settings[pretrained] + + formatted_settings = {} + formatted_settings["input_space"] = settings.get("input_space", "RGB") + formatted_settings["input_range"] = list(settings.get("input_range", [0, 1])) + formatted_settings["mean"] = list(settings.get("mean")) + formatted_settings["std"] = list(settings.get("std")) + + return formatted_settings + + +def get_preprocessing_fn(encoder_name, pretrained="imagenet"): + params = get_preprocessing_params(encoder_name, pretrained=pretrained) + return functools.partial(preprocess_input, **params) diff --git a/segmentation_models_pytorch/encoders/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/encoders/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4a3bc84c653d2ea2a7a8d90f478f14435ba950 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/encoders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..680abec9e6034d33848065f47cdfbe2f779dc373 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/_base.cpython-37.pyc b/segmentation_models_pytorch/encoders/__pycache__/_base.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6617404b08db8db6ef208e00a1ea7784425c2918 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/_base.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/_base.cpython-39.pyc b/segmentation_models_pytorch/encoders/__pycache__/_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5804825c4ecb5a87eef58e4dd59ed7a3f69b9792 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/_base.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/_preprocessing.cpython-39.pyc b/segmentation_models_pytorch/encoders/__pycache__/_preprocessing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52ea8a7e7507b41206ba06b4c813b5effe4ef087 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/_preprocessing.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/_utils.cpython-37.pyc 
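A sketch of calling get_encoder directly, as defined above. Only the mix_transformer family is registered here (the other encoder modules are commented out), and the "tu-" branch references TimmUniversalEncoder whose import is also commented out, so that path would fail in this copy. Names and sizes below are assumptions.

import torch
from segmentation_models_pytorch.encoders import get_encoder, get_encoder_names

print(get_encoder_names())                    # expected: the mix_transformer names, e.g. "mit_b0" ... "mit_b5"

encoder = get_encoder("mit_b0", in_channels=3, depth=5, weights=None)
print(encoder.out_channels)                   # channel count for each feature map returned by forward()

features = encoder(torch.randn(1, 3, 256, 256))
for f in features:
    print(tuple(f.shape))                     # one feature per stage, spatial size halving stage by stage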
b/segmentation_models_pytorch/encoders/__pycache__/_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc137d20c3cdc266613870c5f4e5df6ae1274615 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/_utils.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/timm_sknet.cpython-39.pyc
b/segmentation_models_pytorch/encoders/__pycache__/timm_sknet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa57845ae2bf36a474d1d744fece820c7a16f5a Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/timm_sknet.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/timm_universal.cpython-39.pyc b/segmentation_models_pytorch/encoders/__pycache__/timm_universal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1ec227185cb6e365125bd5a4002b5f9187530d Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/timm_universal.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/vgg.cpython-39.pyc b/segmentation_models_pytorch/encoders/__pycache__/vgg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aca89f79ea184a253f3a0b31116deb18ca90f955 Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/vgg.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/__pycache__/xception.cpython-39.pyc b/segmentation_models_pytorch/encoders/__pycache__/xception.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..463a315e0301be3eea901e43d5c123c417715a6e Binary files /dev/null and b/segmentation_models_pytorch/encoders/__pycache__/xception.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c58386310843ac7952c05530381edeb7b5a8424b --- /dev/null +++ b/segmentation_models_pytorch/encoders/_base.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from typing import List +from collections import OrderedDict + +from . 
import _utils as utils + + +class EncoderMixin: + """Add encoder functionality such as: + - output channels specification of feature tensors (produced by encoder) + - patching first convolution for arbitrary input channels + """ + + _output_stride = 32 + + @property + def out_channels(self): + """Return channels dimensions for each tensor of forward output of encoder""" + return self._out_channels[: self._depth + 1] + + @property + def output_stride(self): + return min(self._output_stride, 2 ** self._depth) + + def set_in_channels(self, in_channels, pretrained=True): + """Change first convolution channels""" + if in_channels == 3: + return + + self._in_channels = in_channels + if self._out_channels[0] == 3: + self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) + + utils.patch_first_conv( + model=self, new_in_channels=in_channels, pretrained=pretrained + ) + + def get_stages(self): + """Override it in your implementation""" + raise NotImplementedError + + def make_dilated(self, output_stride): + + if output_stride == 16: + stage_list = [ + 5, + ] + dilation_list = [ + 2, + ] + + elif output_stride == 8: + stage_list = [4, 5] + dilation_list = [2, 4] + + else: + raise ValueError( + "Output stride should be 16 or 8, got {}.".format(output_stride) + ) + + self._output_stride = output_stride + + stages = self.get_stages() + for stage_indx, dilation_rate in zip(stage_list, dilation_list): + utils.replace_strides_with_dilation( + module=stages[stage_indx], dilation_rate=dilation_rate, + ) diff --git a/segmentation_models_pytorch/encoders/_preprocessing.py b/segmentation_models_pytorch/encoders/_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..ec19d542f1fd8033525ef056adf252041db26e15 --- /dev/null +++ b/segmentation_models_pytorch/encoders/_preprocessing.py @@ -0,0 +1,23 @@ +import numpy as np + + +def preprocess_input( + x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs +): + + if input_space == "BGR": + x = x[..., ::-1].copy() + + if input_range is not None: + if x.max() > 1 and input_range[1] == 1: + x = x / 255.0 + + if mean is not None: + mean = np.array(mean) + x = x - mean + + if std is not None: + std = np.array(std) + x = x / std + + return x diff --git a/segmentation_models_pytorch/encoders/_utils.py b/segmentation_models_pytorch/encoders/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2ab014d434ad0815a0f7fab96b03c85ff0163e --- /dev/null +++ b/segmentation_models_pytorch/encoders/_utils.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn + + +def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): + """Change first convolution layer input channels. 
+ In case: + in_channels == 1 or in_channels == 2 -> reuse original weights + in_channels > 3 -> make random kaiming normal initialization + """ + + # get first conv + for module in model.modules(): + if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: + break + + weight = module.weight.detach() + module.in_channels = new_in_channels + + if not pretrained: + module.weight = nn.parameter.Parameter( + torch.Tensor( + module.out_channels, + new_in_channels // module.groups, + *module.kernel_size + ) + ) + module.reset_parameters() + + elif new_in_channels == 1: + new_weight = weight.sum(1, keepdim=True) + module.weight = nn.parameter.Parameter(new_weight) + + else: + new_weight = torch.Tensor( + module.out_channels, new_in_channels // module.groups, *module.kernel_size + ) + + for i in range(new_in_channels): + new_weight[:, i] = weight[:, i % default_in_channels] + + new_weight = new_weight * (default_in_channels / new_in_channels) + module.weight = nn.parameter.Parameter(new_weight) + + +def replace_strides_with_dilation(module, dilation_rate): + """Patch Conv2d modules replacing strides with dilation""" + for mod in module.modules(): + if isinstance(mod, nn.Conv2d): + mod.stride = (1, 1) + mod.dilation = (dilation_rate, dilation_rate) + kh, kw = mod.kernel_size + mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) + + # Kostyl for EfficientNet + if hasattr(mod, "static_padding"): + mod.static_padding = nn.Identity() diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..7e2c580ba6ec9544a1b7b9f116fa69a195abd0d2 --- /dev/null +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -0,0 +1,156 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" + +import re +import torch.nn as nn + +from pretrainedmodels.models.torchvision_models import pretrained_settings +from torchvision.models.densenet import DenseNet + +from ._base import EncoderMixin + + +class TransitionWithSkip(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, x): + for module in self.module: + x = module(x) + if isinstance(module, nn.ReLU): + skip = x + return x, skip + + +class DenseNetEncoder(DenseNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + del self.classifier + + def make_dilated(self, *args, **kwargs): + raise ValueError( + "DenseNet encoders do not support dilated mode " + "due to pooling operation for downsampling!" + ) + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential( + self.features.conv0, self.features.norm0, self.features.relu0 + ), + nn.Sequential( + self.features.pool0, + self.features.denseblock1, + TransitionWithSkip(self.features.transition1), + ), + nn.Sequential( + self.features.denseblock2, TransitionWithSkip(self.features.transition2) + ), + nn.Sequential( + self.features.denseblock3, TransitionWithSkip(self.features.transition3) + ), + nn.Sequential(self.features.denseblock4, self.features.norm5), + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + if isinstance(x, (list, tuple)): + x, skip = x + features.append(skip) + else: + features.append(x) + + return features + + def load_state_dict(self, state_dict): + pattern = re.compile( + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + # remove linear + state_dict.pop("classifier.bias", None) + state_dict.pop("classifier.weight", None) + + super().load_state_dict(state_dict) + + +densenet_encoders = { + "densenet121": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet121"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 1024), + "num_init_features": 64, + "growth_rate": 32, + "block_config": (6, 12, 24, 16), + }, + }, + "densenet169": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet169"], + "params": { + "out_channels": (3, 64, 256, 512, 1280, 1664), + "num_init_features": 64, + "growth_rate": 32, + "block_config": (6, 12, 32, 32), + }, + }, + "densenet201": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet201"], + "params": { + "out_channels": (3, 64, 256, 512, 1792, 1920), + "num_init_features": 64, + "growth_rate": 32, + "block_config": (6, 12, 48, 32), + }, + }, + "densenet161": { + "encoder": DenseNetEncoder, + "pretrained_settings": pretrained_settings["densenet161"], + "params": { + "out_channels": (3, 96, 384, 768, 2112, 2208), + "num_init_features": 96, + "growth_rate": 48, + "block_config": (6, 12, 36, 24), + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py new file mode 100644 index 0000000000000000000000000000000000000000..08de735f90391b191ac5ab74d35bf4375a9c965d --- /dev/null +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -0,0 +1,174 @@ +"""Each encoder should have 
following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pretrainedmodels.models.dpn import DPN +from pretrainedmodels.models.dpn import pretrained_settings + +from ._base import EncoderMixin + + +class DPNEncoder(DPN, EncoderMixin): + def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._stage_idxs = stage_idxs + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.last_linear + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential( + self.features[0].conv, self.features[0].bn, self.features[0].act + ), + nn.Sequential( + self.features[0].pool, self.features[1 : self._stage_idxs[0]] + ), + self.features[self._stage_idxs[0] : self._stage_idxs[1]], + self.features[self._stage_idxs[1] : self._stage_idxs[2]], + self.features[self._stage_idxs[2] : self._stage_idxs[3]], + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + if isinstance(x, (list, tuple)): + features.append(F.relu(torch.cat(x, dim=1), inplace=True)) + else: + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +dpn_encoders = { + "dpn68": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn68"], + "params": { + "stage_idxs": (4, 8, 20, 24), + "out_channels": (3, 10, 144, 320, 704, 832), + "groups": 32, + "inc_sec": (16, 32, 32, 64), + "k_r": 128, + "k_sec": (3, 4, 12, 3), + "num_classes": 1000, + "num_init_features": 10, + "small": True, + "test_time_pool": True, + }, + }, + "dpn68b": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn68b"], + "params": { + "stage_idxs": (4, 8, 20, 24), + "out_channels": (3, 10, 144, 320, 704, 832), + "b": True, + "groups": 32, + "inc_sec": (16, 32, 32, 64), + "k_r": 128, + "k_sec": (3, 4, 12, 3), + "num_classes": 1000, + "num_init_features": 10, + "small": True, + "test_time_pool": True, + }, + }, + "dpn92": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn92"], + "params": { + "stage_idxs": (4, 8, 28, 32), + "out_channels": (3, 64, 336, 704, 1552, 2688), + 
"groups": 32, + "inc_sec": (16, 32, 24, 128), + "k_r": 96, + "k_sec": (3, 4, 20, 3), + "num_classes": 1000, + "num_init_features": 64, + "test_time_pool": True, + }, + }, + "dpn98": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn98"], + "params": { + "stage_idxs": (4, 10, 30, 34), + "out_channels": (3, 96, 336, 768, 1728, 2688), + "groups": 40, + "inc_sec": (16, 32, 32, 128), + "k_r": 160, + "k_sec": (3, 6, 20, 3), + "num_classes": 1000, + "num_init_features": 96, + "test_time_pool": True, + }, + }, + "dpn107": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn107"], + "params": { + "stage_idxs": (5, 13, 33, 37), + "out_channels": (3, 128, 376, 1152, 2432, 2688), + "groups": 50, + "inc_sec": (20, 64, 64, 128), + "k_r": 200, + "k_sec": (4, 8, 20, 3), + "num_classes": 1000, + "num_init_features": 128, + "test_time_pool": True, + }, + }, + "dpn131": { + "encoder": DPNEncoder, + "pretrained_settings": pretrained_settings["dpn131"], + "params": { + "stage_idxs": (5, 13, 41, 45), + "out_channels": (3, 128, 352, 832, 1984, 2688), + "groups": 40, + "inc_sec": (16, 32, 32, 128), + "k_r": 160, + "k_sec": (4, 8, 28, 3), + "num_classes": 1000, + "num_init_features": 128, + "test_time_pool": True, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0216e9010b9fcc1d8c94d6d8ef150bc95d2746e8 --- /dev/null +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -0,0 +1,178 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" +import torch.nn as nn +from efficientnet_pytorch import EfficientNet +from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params + +from ._base import EncoderMixin + + +class EfficientNetEncoder(EfficientNet, EncoderMixin): + def __init__(self, stage_idxs, out_channels, model_name, depth=5): + + blocks_args, global_params = get_model_params(model_name, override_params=None) + super().__init__(blocks_args, global_params) + + self._stage_idxs = stage_idxs + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + del self._fc + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self._conv_stem, self._bn0, self._swish), + self._blocks[: self._stage_idxs[0]], + self._blocks[self._stage_idxs[0] : self._stage_idxs[1]], + self._blocks[self._stage_idxs[1] : self._stage_idxs[2]], + self._blocks[self._stage_idxs[2] :], + ] + + def forward(self, x): + stages = self.get_stages() + + block_number = 0.0 + drop_connect_rate = self._global_params.drop_connect_rate + + features = [] + for i in range(self._depth + 1): + + # Identity and Sequential stages + if i < 2: + x = stages[i](x) + + # Block stages need drop_connect rate + else: + for module in stages[i]: + drop_connect = drop_connect_rate * block_number / len(self._blocks) + block_number += 1.0 + x = module(x, drop_connect) + + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("_fc.bias", None) + state_dict.pop("_fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +def _get_pretrained_settings(encoder): + pretrained_settings = { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": url_map[encoder], + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": url_map_advprop[encoder], + "input_space": "RGB", + "input_range": [0, 1], + }, + } + return pretrained_settings + + +efficient_net_encoders = { + "efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (3, 5, 9, 16), + "model_name": "efficientnet-b0", + }, + }, + "efficientnet-b1": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (5, 8, 16, 23), + "model_name": "efficientnet-b1", + }, + }, + "efficientnet-b2": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), + "params": { + "out_channels": (3, 32, 24, 48, 120, 352), + "stage_idxs": (5, 8, 16, 23), + "model_name": "efficientnet-b2", + }, + }, + "efficientnet-b3": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), + "params": { + "out_channels": (3, 40, 32, 48, 136, 384), + "stage_idxs": (5, 8, 18, 26), + "model_name": "efficientnet-b3", + }, + }, + "efficientnet-b4": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), + "params": { + "out_channels": (3, 48, 32, 56, 160, 448), + "stage_idxs": (6, 10, 22, 32), + "model_name": "efficientnet-b4", + }, + }, + "efficientnet-b5": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), + "params": { + "out_channels": (3, 48, 40, 64, 176, 512), + "stage_idxs": (8, 13, 27, 
39), + "model_name": "efficientnet-b5", + }, + }, + "efficientnet-b6": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), + "params": { + "out_channels": (3, 56, 40, 72, 200, 576), + "stage_idxs": (9, 15, 31, 45), + "model_name": "efficientnet-b6", + }, + }, + "efficientnet-b7": { + "encoder": EfficientNetEncoder, + "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), + "params": { + "out_channels": (3, 64, 48, 80, 224, 640), + "stage_idxs": (11, 18, 38, 55), + "model_name": "efficientnet-b7", + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..310853ffcd0cce7ef101bde6e00f6836aeed9fdd --- /dev/null +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -0,0 +1,92 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch.nn as nn +from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 +from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings + +from ._base import EncoderMixin + + +class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + # correct paddings + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.kernel_size == (3, 3): + m.padding = (1, 1) + if isinstance(m, nn.MaxPool2d): + m.padding = (1, 1) + + # remove linear layers + del self.avgpool_1a + del self.last_linear + + def make_dilated(self, *args, **kwargs): + raise ValueError( + "InceptionResnetV2 encoder does not support dilated mode " + "due to pooling operation for downsampling!" 
+ ) + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), + nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), + nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), + nn.Sequential(self.mixed_6a, self.repeat_1), + nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +inceptionresnetv2_encoders = { + "inceptionresnetv2": { + "encoder": InceptionResNetV2Encoder, + "pretrained_settings": pretrained_settings["inceptionresnetv2"], + "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, + } +} diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py new file mode 100644 index 0000000000000000000000000000000000000000..96b0c8b17de1ea2345600fc6760c656927a7403e --- /dev/null +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -0,0 +1,95 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch.nn as nn +from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d +from pretrainedmodels.models.inceptionv4 import pretrained_settings + +from ._base import EncoderMixin + + +class InceptionV4Encoder(InceptionV4, EncoderMixin): + def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._stage_idxs = stage_idxs + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + # correct paddings + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.kernel_size == (3, 3): + m.padding = (1, 1) + if isinstance(m, nn.MaxPool2d): + m.padding = (1, 1) + + # remove linear layers + del self.last_linear + + def make_dilated(self, stage_list, dilation_list): + raise ValueError( + "InceptionV4 encoder does not support dilated mode " + "due to pooling operation for downsampling!" 
+ ) + + def get_stages(self): + return [ + nn.Identity(), + self.features[: self._stage_idxs[0]], + self.features[self._stage_idxs[0] : self._stage_idxs[1]], + self.features[self._stage_idxs[1] : self._stage_idxs[2]], + self.features[self._stage_idxs[2] : self._stage_idxs[3]], + self.features[self._stage_idxs[3] :], + ] + + def forward(self, x): + + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +inceptionv4_encoders = { + "inceptionv4": { + "encoder": InceptionV4Encoder, + "pretrained_settings": pretrained_settings["inceptionv4"], + "params": { + "stage_idxs": (3, 5, 9, 15), + "out_channels": (3, 64, 192, 384, 1024, 1536), + "num_classes": 1001, + }, + } +} diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..62269b39a4717c8db851d9ad8a4e45b38c79d57d --- /dev/null +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -0,0 +1,664 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
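        # When sr_ratio > 1, the keys and values are computed from a spatially
        # reduced copy of the token map: a Conv2d with kernel_size = stride =
        # sr_ratio followed by a LayerNorm (see forward below). Queries keep the
        # full resolution, so attention cost drops from O(N^2) to roughly
        # O(N^2 / sr_ratio^2) for N = H * W tokens.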
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = ( + self.kv(x_) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + else: + kv = ( + self.kv(x) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = 
to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed( + img_size=img_size, + patch_size=7, + stride=4, + in_chans=in_chans, + embed_dim=embed_dims[0], + ) + self.patch_embed2 = OverlapPatchEmbed( + img_size=img_size // 4, + patch_size=3, + stride=2, + in_chans=embed_dims[0], + embed_dim=embed_dims[1], + ) + self.patch_embed3 = OverlapPatchEmbed( + img_size=img_size // 8, + patch_size=3, + stride=2, + in_chans=embed_dims[1], + embed_dim=embed_dims[2], + ) + self.patch_embed4 = OverlapPatchEmbed( + img_size=img_size // 16, + patch_size=3, + stride=2, + in_chans=embed_dims[2], + embed_dim=embed_dims[3], + ) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList( + [ + Block( + dim=embed_dims[0], + num_heads=num_heads[0], + mlp_ratio=mlp_ratios[0], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[0], + ) + for i in range(depths[0]) + ] + ) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList( + [ + Block( + dim=embed_dims[1], + num_heads=num_heads[1], + mlp_ratio=mlp_ratios[1], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[1], + ) + for i in range(depths[1]) + ] + ) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList( + [ + Block( + dim=embed_dims[2], + num_heads=num_heads[2], + mlp_ratio=mlp_ratios[2], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[2], + ) + for i in range(depths[2]) + ] + ) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList( + [ + Block( + dim=embed_dims[3], + num_heads=num_heads[3], + 
mlp_ratio=mlp_ratios[3], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[3], + ) + for i in range(depths[3]) + ] + ) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def init_weights(self, pretrained=None): + pass + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return { + "pos_embed1", + "pos_embed2", + "pos_embed3", + "pos_embed4", + "cls_token", + } # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def forward_features(self, x): + B = x.shape[0] + outs = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) + x = self.norm1(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + + +# --------------------------------------------------------------- +# End of NVIDIA code +# --------------------------------------------------------------- + +from ._base import EncoderMixin # noqa E402 
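The wrapper that follows has to bridge a mismatch: `forward_features` above yields only four maps, at strides 4, 8, 16 and 32, while the decoders in this package expect six features starting at full input resolution. `MixVisionTransformerEncoder` therefore prepends the raw input and a zero-channel dummy at stride 2, which is why the `out_channels` entries further down look like `(3, 0, 32, 64, 160, 256)`. A minimal shape check, as a sketch assuming the mit_b0 configuration listed below:

import torch

# mit_b0-sized backbone, using the hyper-parameters from the settings dict below
backbone = MixVisionTransformer(
    embed_dims=[32, 64, 160, 256],
    num_heads=[1, 2, 5, 8],
    mlp_ratios=[4, 4, 4, 4],
    qkv_bias=True,
    depths=[2, 2, 2, 2],
    sr_ratios=[8, 4, 2, 1],
)
feats = backbone(torch.rand(1, 3, 64, 64))
print([tuple(f.shape) for f in feats])
# [(1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2)]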
+ + +class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + def make_dilated(self, *args, **kwargs): + raise ValueError("MixVisionTransformer encoder does not support dilated mode") + + def set_in_channels(self, in_channels, *args, **kwargs): + if in_channels != 3: + raise ValueError( + "MixVisionTransformer encoder does not support in_channels setting other than 3" + ) + + def forward(self, x): + + # create dummy output for the first block + B, C, H, W = x.shape + dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) + + return [x, dummy] + self.forward_features(x)[: self._depth - 1] + + def load_state_dict(self, state_dict): + state_dict.pop("head.weight", None) + state_dict.pop("head.bias", None) + return super().load_state_dict(state_dict) + + +def get_pretrained_cfg(name): + return { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/{}.pth".format( + name + ), + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + + +mix_transformer_encoders = { + "mit_b0": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b0"),}, + "params": dict( + out_channels=(3, 0, 32, 64, 160, 256), + patch_size=4, + embed_dims=[32, 64, 160, 256], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b1": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b1"),}, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b2": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b2"),}, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b3": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b3"),}, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b4": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b4"),}, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + 
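        # The mit_b0-b5 presets in this dict share patch_size, num_heads,
        # mlp_ratios, sr_ratios, qkv_bias and drop_path_rate; they differ only
        # in embed_dims (32-based for b0, 64-based for b1-b5) and in the
        # per-stage depths.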
}, + "mit_b5": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b5"),}, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 6, 40, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, +} diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..878732ef0ec36b6bd8a0d7c651a5e81e54731d5e --- /dev/null +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -0,0 +1,80 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" + +import torchvision +import torch.nn as nn + +from ._base import EncoderMixin + + +class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + del self.classifier + + def get_stages(self): + return [ + nn.Identity(), + self.features[:2], + self.features[2:4], + self.features[4:7], + self.features[7:14], + self.features[14:], + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("classifier.1.bias", None) + state_dict.pop("classifier.1.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +mobilenet_encoders = { + "mobilenet_v2": { + "encoder": MobileNetV2Encoder, + "pretrained_settings": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "params": {"out_channels": (3, 16, 24, 32, 96, 1280),}, + }, +} diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0cfc532b12f8e8a6f9c9656ee6fa411e8d5c1294 --- /dev/null +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -0,0 +1,238 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
+""" +from copy import deepcopy + +import torch.nn as nn + +from torchvision.models.resnet import ResNet +from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import Bottleneck +from pretrainedmodels.models.torchvision_models import pretrained_settings + +from ._base import EncoderMixin + + +class ResNetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.avgpool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.relu), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +new_settings = { + "resnet18": { + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", # noqa + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth", # noqa + }, + "resnet50": { + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", # noqa + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth", # noqa + }, + "resnext50_32x4d": { + "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", # noqa + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", # noqa + }, + "resnext101_32x4d": { + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", # noqa + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth", # noqa + }, + "resnext101_32x8d": { + "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", # noqa + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", # noqa + }, + "resnext101_32x16d": { + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", + "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", # noqa + "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", # noqa + }, + "resnext101_32x32d": { + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth", + }, + "resnext101_32x48d": { + "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth", + }, +} + +pretrained_settings = deepcopy(pretrained_settings) +for model_name, sources in new_settings.items(): + if model_name not in pretrained_settings: + 
pretrained_settings[model_name] = {} + + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + + +resnet_encoders = { + "resnet18": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet18"], + "params": { + "out_channels": (3, 64, 64, 128, 256, 512), + "block": BasicBlock, + "layers": [2, 2, 2, 2], + }, + }, + "resnet34": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet34"], + "params": { + "out_channels": (3, 64, 64, 128, 256, 512), + "block": BasicBlock, + "layers": [3, 4, 6, 3], + }, + }, + "resnet50": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet50"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 6, 3], + }, + }, + "resnet101": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet101"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + }, + }, + "resnet152": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnet152"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 8, 36, 3], + }, + }, + "resnext50_32x4d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext50_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 6, 3], + "groups": 32, + "width_per_group": 4, + }, + }, + "resnext101_32x4d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 4, + }, + }, + "resnext101_32x8d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x8d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 8, + }, + }, + "resnext101_32x16d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x16d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 16, + }, + }, + "resnext101_32x32d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x32d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 32, + }, + }, + "resnext101_32x48d": { + "encoder": ResNetEncoder, + "pretrained_settings": pretrained_settings["resnext101_32x48d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "groups": 32, + "width_per_group": 48, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py new file mode 100644 index 0000000000000000000000000000000000000000..8e0f6fd851f07a0eed93146d63c39dad250de151 --- /dev/null +++ b/segmentation_models_pytorch/encoders/senet.py @@ -0,0 +1,174 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 
+ +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch.nn as nn + +from pretrainedmodels.models.senet import ( + SENet, + SEBottleneck, + SEResNetBottleneck, + SEResNeXtBottleneck, + pretrained_settings, +) +from ._base import EncoderMixin + + +class SENetEncoder(SENet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + del self.last_linear + del self.avg_pool + + def get_stages(self): + return [ + nn.Identity(), + self.layer0[:-1], + nn.Sequential(self.layer0[-1], self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +senet_encoders = { + "senet154": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["senet154"], + "params": { + "out_channels": (3, 128, 256, 512, 1024, 2048), + "block": SEBottleneck, + "dropout_p": 0.2, + "groups": 64, + "layers": [3, 8, 36, 3], + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnet50": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnet50"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNetBottleneck, + "layers": [3, 4, 6, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnet101": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnet101"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNetBottleneck, + "layers": [3, 4, 23, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnet152": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnet152"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNetBottleneck, + "layers": [3, 8, 36, 3], + "downsample_kernel_size": 1, + 
"downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnext50_32x4d": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNeXtBottleneck, + "layers": [3, 4, 6, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 32, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, + "se_resnext101_32x4d": { + "encoder": SENetEncoder, + "pretrained_settings": pretrained_settings["se_resnext101_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SEResNeXtBottleneck, + "layers": [3, 4, 23, 3], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 32, + "inplanes": 64, + "input_3x3": False, + "num_classes": 1000, + "reduction": 16, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b0cfd27088c84d851eea1ddc26b54e881fce8a --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -0,0 +1,390 @@ +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.efficientnet import EfficientNet +from timm.models.efficientnet import decode_arch_def, round_channels, default_cfgs +from timm.models.layers.activations import Swish + +from ._base import EncoderMixin + + +def get_efficientnet_kwargs( + channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2 +): + """Create EfficientNet model. + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ["ds_r1_k3_s1_e1_c16_se0.25"], + ["ir_r2_k3_s2_e6_c24_se0.25"], + ["ir_r2_k5_s2_e6_c40_se0.25"], + ["ir_r3_k3_s2_e6_c80_se0.25"], + ["ir_r3_k5_s1_e6_c112_se0.25"], + ["ir_r4_k5_s2_e6_c192_se0.25"], + ["ir_r1_k3_s1_e6_c320_se0.25"], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=Swish, + drop_rate=drop_rate, + drop_path_rate=0.2, + ) + return model_kwargs + + +def gen_efficientnet_lite_kwargs( + channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2 +): + """EfficientNet-Lite model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ["ds_r1_k3_s1_e1_c16"], + ["ir_r2_k3_s2_e6_c24"], + ["ir_r2_k5_s2_e6_c40"], + ["ir_r3_k3_s2_e6_c80"], + ["ir_r3_k5_s1_e6_c112"], + ["ir_r4_k5_s2_e6_c192"], + ["ir_r1_k3_s1_e6_c320"], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=nn.ReLU6, + drop_rate=drop_rate, + drop_path_rate=0.2, + ) + return model_kwargs + + +class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): + def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + + self._stage_idxs = stage_idxs + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + del self.classifier + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv_stem, self.bn1, self.act1), + self.blocks[: self._stage_idxs[0]], + self.blocks[self._stage_idxs[0] : self._stage_idxs[1]], + self.blocks[self._stage_idxs[1] : self._stage_idxs[2]], + self.blocks[self._stage_idxs[2] :], + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("classifier.bias", None) + state_dict.pop("classifier.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +class EfficientNetEncoder(EfficientNetBaseEncoder): + def __init__( + self, + stage_idxs, + out_channels, + depth=5, + channel_multiplier=1.0, + depth_multiplier=1.0, + drop_rate=0.2, + ): + kwargs = get_efficientnet_kwargs( + channel_multiplier, depth_multiplier, drop_rate + ) + super().__init__(stage_idxs, out_channels, depth, **kwargs) + + +class EfficientNetLiteEncoder(EfficientNetBaseEncoder): + def __init__( + self, + stage_idxs, + out_channels, + depth=5, + channel_multiplier=1.0, + depth_multiplier=1.0, + drop_rate=0.2, + ): + kwargs = gen_efficientnet_lite_kwargs( + channel_multiplier, depth_multiplier, drop_rate + ) + super().__init__(stage_idxs, out_channels, depth, **kwargs) + + +def prepare_settings(settings): + return { + "mean": settings["mean"], + "std": settings["std"], + "url": settings["url"], + "input_range": (0, 1), + "input_space": "RGB", + } + + +timm_efficientnet_encoders = { + "timm-efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b0"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b0_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b0_ns"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, + "timm-efficientnet-b1": { + "encoder": EfficientNetEncoder, + 
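+        # stage_idxs = (2, 3, 5) in the params below are the split points into
+        # self.blocks used by EfficientNetBaseEncoder.get_stages(), yielding six
+        # feature maps at strides 1, 2, 4, 8, 16 and 32 with the widths listed
+        # in out_channels.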
"pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b1"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b1_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b1_ns"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.1, + "drop_rate": 0.2, + }, + }, + "timm-efficientnet-b2": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b2"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b2_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b2_ns"]), + }, + "params": { + "out_channels": (3, 32, 24, 48, 120, 352), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.1, + "depth_multiplier": 1.2, + "drop_rate": 0.3, + }, + }, + "timm-efficientnet-b3": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b3"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b3_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b3_ns"]), + }, + "params": { + "out_channels": (3, 40, 32, 48, 136, 384), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.2, + "depth_multiplier": 1.4, + "drop_rate": 0.3, + }, + }, + "timm-efficientnet-b4": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b4"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b4_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b4_ns"]), + }, + "params": { + "out_channels": (3, 48, 32, 56, 160, 448), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.4, + "depth_multiplier": 1.8, + "drop_rate": 0.4, + }, + }, + "timm-efficientnet-b5": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b5"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b5_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b5_ns"]), + }, + "params": { + "out_channels": (3, 48, 40, 64, 176, 512), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.6, + "depth_multiplier": 2.2, + "drop_rate": 0.4, + }, + }, + "timm-efficientnet-b6": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b6"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b6_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b6_ns"]), + }, + "params": { + "out_channels": (3, 56, 40, 72, 200, 576), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.8, + "depth_multiplier": 2.6, + "drop_rate": 0.5, + }, + }, + "timm-efficientnet-b7": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b7"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b7_ap"]), + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_b7_ns"]), + }, + "params": { + "out_channels": (3, 64, 48, 80, 224, 640), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 2.0, + "depth_multiplier": 3.1, + "drop_rate": 0.5, + }, + }, + "timm-efficientnet-b8": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_b8"]), + "advprop": prepare_settings(default_cfgs["tf_efficientnet_b8_ap"]), + }, + "params": { + 
"out_channels": (3, 72, 56, 88, 248, 704), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 2.2, + "depth_multiplier": 3.6, + "drop_rate": 0.5, + }, + }, + "timm-efficientnet-l2": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "noisy-student": prepare_settings(default_cfgs["tf_efficientnet_l2_ns"]), + }, + "params": { + "out_channels": (3, 136, 104, 176, 480, 1376), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 4.3, + "depth_multiplier": 5.3, + "drop_rate": 0.5, + }, + }, + "timm-tf_efficientnet_lite0": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite0"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, + "timm-tf_efficientnet_lite1": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite1"]), + }, + "params": { + "out_channels": (3, 32, 24, 40, 112, 320), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.0, + "depth_multiplier": 1.1, + "drop_rate": 0.2, + }, + }, + "timm-tf_efficientnet_lite2": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite2"]), + }, + "params": { + "out_channels": (3, 32, 24, 48, 120, 352), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.1, + "depth_multiplier": 1.2, + "drop_rate": 0.3, + }, + }, + "timm-tf_efficientnet_lite3": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite3"]), + }, + "params": { + "out_channels": (3, 32, 32, 48, 136, 384), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.2, + "depth_multiplier": 1.4, + "drop_rate": 0.3, + }, + }, + "timm-tf_efficientnet_lite4": { + "encoder": EfficientNetLiteEncoder, + "pretrained_settings": { + "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite4"]), + }, + "params": { + "out_channels": (3, 32, 32, 56, 160, 448), + "stage_idxs": (2, 3, 5), + "channel_multiplier": 1.4, + "depth_multiplier": 1.8, + "drop_rate": 0.4, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b94c9593498657d1088942ea7614037888130df --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -0,0 +1,124 @@ +from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet + +from ._base import EncoderMixin +import torch.nn as nn + + +class GERNetEncoder(ByobNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.stages[0], + self.stages[1], + self.stages[2], + nn.Sequential(self.stages[3], self.stages[4], self.final_conv), + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight", None) + state_dict.pop("head.fc.bias", None) + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + "timm-gernet_s": { + "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth", # noqa + }, + "timm-gernet_m": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth", # noqa + }, + "timm-gernet_l": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth", # noqa + }, +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + +timm_gernet_encoders = { + "timm-gernet_s": { + "encoder": GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_s"], + "params": { + "out_channels": (3, 13, 48, 48, 384, 1920), + "cfg": ByoModelCfg( + blocks=( + ByoBlockCfg(type="basic", d=1, c=48, s=2, gs=0, br=1.0), + ByoBlockCfg(type="basic", d=3, c=48, s=2, gs=0, br=1.0), + ByoBlockCfg(type="bottle", d=7, c=384, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type="bottle", d=2, c=560, s=2, gs=1, br=3.0), + ByoBlockCfg(type="bottle", d=1, c=256, s=1, gs=1, br=3.0), + ), + stem_chs=13, + stem_pool=None, + num_features=1920, + ), + }, + }, + "timm-gernet_m": { + "encoder": GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_m"], + "params": { + "out_channels": (3, 32, 128, 192, 640, 2560), + "cfg": ByoModelCfg( + blocks=( + ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), + ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type="bottle", d=4, c=640, s=2, gs=1, br=3.0), + ByoBlockCfg(type="bottle", d=1, c=640, s=1, gs=1, br=3.0), + ), + stem_chs=32, + stem_pool=None, + num_features=2560, + ), + }, + }, + "timm-gernet_l": { + "encoder": GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_l"], + "params": { + "out_channels": (3, 32, 128, 192, 640, 2560), + "cfg": ByoModelCfg( + blocks=( + ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), + ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type="bottle", d=5, c=640, s=2, gs=1, br=3.0), + ByoBlockCfg(type="bottle", d=4, c=640, s=1, gs=1, br=3.0), + ), + stem_chs=32, + stem_pool=None, + num_features=2560, + ), + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d96948d679a64efbbb095cafd0e352c807af06 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py @@ -0,0 +1,151 @@ +import timm +import numpy as np +import torch.nn as nn + +from ._base import EncoderMixin + + +def _make_divisible(x, divisible_by=8): + return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) + + +class MobileNetV3Encoder(nn.Module, EncoderMixin): + def __init__(self, model_name, width_mult, depth=5, **kwargs): + super().__init__() + if "large" not in model_name and "small" not in model_name: + raise ValueError("MobileNetV3 wrong model name {}".format(model_name)) + + self._mode = "small" if "small" in model_name else "large" + self._depth = depth + self._out_channels = self._get_channels(self._mode, width_mult) + 
self._in_channels = 3 + + # minimal models replace hardswish with relu + self.model = timm.create_model( + model_name=model_name, + scriptable=True, # torch.jit scriptable + exportable=True, # onnx export + features_only=True, + ) + + def _get_channels(self, mode, width_mult): + if mode == "small": + channels = [16, 16, 24, 48, 576] + else: + channels = [16, 24, 40, 112, 960] + channels = [3,] + [_make_divisible(x * width_mult) for x in channels] + return tuple(channels) + + def get_stages(self): + if self._mode == "small": + return [ + nn.Identity(), + nn.Sequential(self.model.conv_stem, self.model.bn1, self.model.act1,), + self.model.blocks[0], + self.model.blocks[1], + self.model.blocks[2:4], + self.model.blocks[4:], + ] + elif self._mode == "large": + return [ + nn.Identity(), + nn.Sequential( + self.model.conv_stem, + self.model.bn1, + self.model.act1, + self.model.blocks[0], + ), + self.model.blocks[1], + self.model.blocks[2], + self.model.blocks[3:5], + self.model.blocks[5:], + ] + else: + ValueError( + "MobileNetV3 mode should be small or large, got {}".format(self._mode) + ) + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("conv_head.weight", None) + state_dict.pop("conv_head.bias", None) + state_dict.pop("classifier.weight", None) + state_dict.pop("classifier.bias", None) + self.model.load_state_dict(state_dict, **kwargs) + + +mobilenetv3_weights = { + "tf_mobilenetv3_large_075": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth" # noqa + }, + "tf_mobilenetv3_large_100": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth" # noqa + }, + "tf_mobilenetv3_large_minimal_100": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth" # noqa + }, + "tf_mobilenetv3_small_075": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth" # noqa + }, + "tf_mobilenetv3_small_100": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth" # noqa + }, + "tf_mobilenetv3_small_minimal_100": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth" # noqa + }, +} + +pretrained_settings = {} +for model_name, sources in mobilenetv3_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "input_space": "RGB", + } + + +timm_mobilenetv3_encoders = { + "timm-mobilenetv3_large_075": { + "encoder": MobileNetV3Encoder, + "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_075"], + "params": {"model_name": "tf_mobilenetv3_large_075", "width_mult": 0.75}, + }, + "timm-mobilenetv3_large_100": { + "encoder": MobileNetV3Encoder, + "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_100"], + "params": {"model_name": "tf_mobilenetv3_large_100", "width_mult": 1.0}, + }, + 
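+    # Rough usage sketch for the entries in this mapping (assumes `import torch`;
+    # pretrained weights are not loaded here):
+    #   entry = timm_mobilenetv3_encoders["timm-mobilenetv3_large_100"]
+    #   encoder = entry["encoder"](**entry["params"])  # depth defaults to 5
+    #   features = encoder(torch.rand(1, 3, 64, 64))   # 6 maps at strides 1..32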
"timm-mobilenetv3_large_minimal_100": { + "encoder": MobileNetV3Encoder, + "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_minimal_100"], + "params": {"model_name": "tf_mobilenetv3_large_minimal_100", "width_mult": 1.0}, + }, + "timm-mobilenetv3_small_075": { + "encoder": MobileNetV3Encoder, + "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_075"], + "params": {"model_name": "tf_mobilenetv3_small_075", "width_mult": 0.75}, + }, + "timm-mobilenetv3_small_100": { + "encoder": MobileNetV3Encoder, + "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_100"], + "params": {"model_name": "tf_mobilenetv3_small_100", "width_mult": 1.0}, + }, + "timm-mobilenetv3_small_minimal_100": { + "encoder": MobileNetV3Encoder, + "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_minimal_100"], + "params": {"model_name": "tf_mobilenetv3_small_minimal_100", "width_mult": 1.0}, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_regnet.py b/segmentation_models_pytorch/encoders/timm_regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c87a3e85304e2c061515e9d7835665b66e2504 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_regnet.py @@ -0,0 +1,342 @@ +from ._base import EncoderMixin +from timm.models.regnet import RegNet +import torch.nn as nn + + +class RegNetEncoder(RegNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.s1, + self.s2, + self.s3, + self.s4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight", None) + state_dict.pop("head.fc.bias", None) + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + "timm-regnetx_002": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth", # noqa + }, + "timm-regnetx_004": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth", # noqa + }, + "timm-regnetx_006": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth", # noqa + }, + "timm-regnetx_008": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth", # noqa + }, + "timm-regnetx_016": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth", # noqa + }, + "timm-regnetx_032": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth", # noqa + }, + "timm-regnetx_040": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth", # noqa + }, + "timm-regnetx_064": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth", # noqa + }, + "timm-regnetx_080": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth", # noqa + }, + "timm-regnetx_120": { + "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth", # noqa + }, + "timm-regnetx_160": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth", # noqa + }, + "timm-regnetx_320": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth", # noqa + }, + "timm-regnety_002": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth", # noqa + }, + "timm-regnety_004": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth", # noqa + }, + "timm-regnety_006": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth", # noqa + }, + "timm-regnety_008": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth", # noqa + }, + "timm-regnety_016": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth", # noqa + }, + "timm-regnety_032": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth", # noqa + }, + "timm-regnety_040": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth", # noqa + }, + "timm-regnety_064": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth", # noqa + }, + "timm-regnety_080": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth", # noqa + }, + "timm-regnety_120": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth", # noqa + }, + "timm-regnety_160": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth", # noqa + }, + "timm-regnety_320": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth", # noqa + }, +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + +# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo + + +def _mcfg(**kwargs): + cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32) + cfg.update(**kwargs) + return cfg + + +timm_regnet_encoders = { + "timm-regnetx_002": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_002"], + "params": { + "out_channels": (3, 32, 24, 56, 152, 368), + "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13), + }, + }, + "timm-regnetx_004": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_004"], + "params": { + "out_channels": (3, 32, 32, 64, 160, 384), + "cfg": _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22), + }, + }, + "timm-regnetx_006": { + "encoder": RegNetEncoder, + 
"pretrained_settings": pretrained_settings["timm-regnetx_006"], + "params": { + "out_channels": (3, 32, 48, 96, 240, 528), + "cfg": _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16), + }, + }, + "timm-regnetx_008": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_008"], + "params": { + "out_channels": (3, 32, 64, 128, 288, 672), + "cfg": _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16), + }, + }, + "timm-regnetx_016": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_016"], + "params": { + "out_channels": (3, 32, 72, 168, 408, 912), + "cfg": _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18), + }, + }, + "timm-regnetx_032": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_032"], + "params": { + "out_channels": (3, 32, 96, 192, 432, 1008), + "cfg": _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25), + }, + }, + "timm-regnetx_040": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_040"], + "params": { + "out_channels": (3, 32, 80, 240, 560, 1360), + "cfg": _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23), + }, + }, + "timm-regnetx_064": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_064"], + "params": { + "out_channels": (3, 32, 168, 392, 784, 1624), + "cfg": _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17), + }, + }, + "timm-regnetx_080": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_080"], + "params": { + "out_channels": (3, 32, 80, 240, 720, 1920), + "cfg": _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23), + }, + }, + "timm-regnetx_120": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_120"], + "params": { + "out_channels": (3, 32, 224, 448, 896, 2240), + "cfg": _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19), + }, + }, + "timm-regnetx_160": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_160"], + "params": { + "out_channels": (3, 32, 256, 512, 896, 2048), + "cfg": _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22), + }, + }, + "timm-regnetx_320": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_320"], + "params": { + "out_channels": (3, 32, 336, 672, 1344, 2520), + "cfg": _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23), + }, + }, + # regnety + "timm-regnety_002": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_002"], + "params": { + "out_channels": (3, 32, 24, 56, 152, 368), + "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25), + }, + }, + "timm-regnety_004": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_004"], + "params": { + "out_channels": (3, 32, 48, 104, 208, 440), + "cfg": _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25), + }, + }, + "timm-regnety_006": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_006"], + "params": { + "out_channels": (3, 32, 48, 112, 256, 608), + "cfg": _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25), + }, + }, + "timm-regnety_008": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_008"], + "params": { + "out_channels": (3, 32, 64, 128, 320, 768), + "cfg": _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25), + }, 
+ }, + "timm-regnety_016": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_016"], + "params": { + "out_channels": (3, 32, 48, 120, 336, 888), + "cfg": _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25), + }, + }, + "timm-regnety_032": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_032"], + "params": { + "out_channels": (3, 32, 72, 216, 576, 1512), + "cfg": _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25), + }, + }, + "timm-regnety_040": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_040"], + "params": { + "out_channels": (3, 32, 128, 192, 512, 1088), + "cfg": _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25), + }, + }, + "timm-regnety_064": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_064"], + "params": { + "out_channels": (3, 32, 144, 288, 576, 1296), + "cfg": _mcfg( + w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25 + ), + }, + }, + "timm-regnety_080": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_080"], + "params": { + "out_channels": (3, 32, 168, 448, 896, 2016), + "cfg": _mcfg( + w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25 + ), + }, + }, + "timm-regnety_120": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_120"], + "params": { + "out_channels": (3, 32, 224, 448, 896, 2240), + "cfg": _mcfg( + w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25 + ), + }, + }, + "timm-regnety_160": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_160"], + "params": { + "out_channels": (3, 32, 224, 448, 1232, 3024), + "cfg": _mcfg( + w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25 + ), + }, + }, + "timm-regnety_320": { + "encoder": RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_320"], + "params": { + "out_channels": (3, 32, 232, 696, 1392, 3712), + "cfg": _mcfg( + w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25 + ), + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_res2net.py b/segmentation_models_pytorch/encoders/timm_res2net.py new file mode 100644 index 0000000000000000000000000000000000000000..0fc1948849e9d5c4ca92262ca682b181e1e9fcab --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_res2net.py @@ -0,0 +1,163 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.res2net import Bottle2neck +import torch.nn as nn + + +class Res2NetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def make_dilated(self, *args, **kwargs): + raise ValueError("Res2Net encoders do not support dilated mode") + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, 
**kwargs) + + +res2net_weights = { + "timm-res2net50_26w_4s": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth", # noqa + }, + "timm-res2net50_48w_2s": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth", # noqa + }, + "timm-res2net50_14w_8s": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth", # noqa + }, + "timm-res2net50_26w_6s": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth", # noqa + }, + "timm-res2net50_26w_8s": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth", # noqa + }, + "timm-res2net101_26w_4s": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth", # noqa + }, + "timm-res2next50": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth", # noqa + }, +} + +pretrained_settings = {} +for model_name, sources in res2net_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + + +timm_res2net_encoders = { + "timm-res2net50_26w_4s": { + "encoder": Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 6, 3], + "base_width": 26, + "block_args": {"scale": 4}, + }, + }, + "timm-res2net101_26w_4s": { + "encoder": Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 23, 3], + "base_width": 26, + "block_args": {"scale": 4}, + }, + }, + "timm-res2net50_26w_6s": { + "encoder": Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 6, 3], + "base_width": 26, + "block_args": {"scale": 6}, + }, + }, + "timm-res2net50_26w_8s": { + "encoder": Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 6, 3], + "base_width": 26, + "block_args": {"scale": 8}, + }, + }, + "timm-res2net50_48w_2s": { + "encoder": Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 6, 3], + "base_width": 48, + "block_args": {"scale": 2}, + }, + }, + "timm-res2net50_14w_8s": { + "encoder": Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 6, 3], + "base_width": 14, + "block_args": {"scale": 8}, + }, + }, + "timm-res2next50": { + "encoder": Res2NetEncoder, + "pretrained_settings": 
pretrained_settings["timm-res2next50"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": Bottle2neck, + "layers": [3, 4, 6, 3], + "base_width": 4, + "cardinality": 8, + "block_args": {"scale": 4}, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_resnest.py b/segmentation_models_pytorch/encoders/timm_resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..a04fcf195d4736dc323a6774a19efdbbc96a5a77 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_resnest.py @@ -0,0 +1,208 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.resnest import ResNestBottleneck +import torch.nn as nn + + +class ResNestEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def make_dilated(self, *args, **kwargs): + raise ValueError("ResNest encoders do not support dilated mode") + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +resnest_weights = { + "timm-resnest14d": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth", # noqa + }, + "timm-resnest26d": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth", # noqa + }, + "timm-resnest50d": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth", # noqa + }, + "timm-resnest101e": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth", # noqa + }, + "timm-resnest200e": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth", # noqa + }, + "timm-resnest269e": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth", # noqa + }, + "timm-resnest50d_4s2x40d": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth", # noqa + }, + "timm-resnest50d_1s4x24d": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth", # noqa + }, +} + +pretrained_settings = {} +for model_name, sources in resnest_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + + +timm_resnest_encoders = { + "timm-resnest14d": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest14d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": 
ResNestBottleneck, + "layers": [1, 1, 1, 1], + "stem_type": "deep", + "stem_width": 32, + "avg_down": True, + "base_width": 64, + "cardinality": 1, + "block_args": {"radix": 2, "avd": True, "avd_first": False}, + }, + }, + "timm-resnest26d": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest26d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [2, 2, 2, 2], + "stem_type": "deep", + "stem_width": 32, + "avg_down": True, + "base_width": 64, + "cardinality": 1, + "block_args": {"radix": 2, "avd": True, "avd_first": False}, + }, + }, + "timm-resnest50d": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [3, 4, 6, 3], + "stem_type": "deep", + "stem_width": 32, + "avg_down": True, + "base_width": 64, + "cardinality": 1, + "block_args": {"radix": 2, "avd": True, "avd_first": False}, + }, + }, + "timm-resnest101e": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest101e"], + "params": { + "out_channels": (3, 128, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [3, 4, 23, 3], + "stem_type": "deep", + "stem_width": 64, + "avg_down": True, + "base_width": 64, + "cardinality": 1, + "block_args": {"radix": 2, "avd": True, "avd_first": False}, + }, + }, + "timm-resnest200e": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest200e"], + "params": { + "out_channels": (3, 128, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [3, 24, 36, 3], + "stem_type": "deep", + "stem_width": 64, + "avg_down": True, + "base_width": 64, + "cardinality": 1, + "block_args": {"radix": 2, "avd": True, "avd_first": False}, + }, + }, + "timm-resnest269e": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest269e"], + "params": { + "out_channels": (3, 128, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [3, 30, 48, 8], + "stem_type": "deep", + "stem_width": 64, + "avg_down": True, + "base_width": 64, + "cardinality": 1, + "block_args": {"radix": 2, "avd": True, "avd_first": False}, + }, + }, + "timm-resnest50d_4s2x40d": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [3, 4, 6, 3], + "stem_type": "deep", + "stem_width": 32, + "avg_down": True, + "base_width": 40, + "cardinality": 2, + "block_args": {"radix": 4, "avd": True, "avd_first": True}, + }, + }, + "timm-resnest50d_1s4x24d": { + "encoder": ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": ResNestBottleneck, + "layers": [3, 4, 6, 3], + "stem_type": "deep", + "stem_width": 32, + "avg_down": True, + "base_width": 24, + "cardinality": 4, + "block_args": {"radix": 1, "avd": True, "avd_first": True}, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py new file mode 100644 index 0000000000000000000000000000000000000000..9969c90a8eb2fb50861077ce449840d23d1d2760 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -0,0 +1,103 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from 
timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic +import torch.nn as nn + + +class SkNetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + super().load_state_dict(state_dict, **kwargs) + + +sknet_weights = { + "timm-skresnet18": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth", # noqa + }, + "timm-skresnet34": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth", # noqa + }, + "timm-skresnext50_32x4d": { + "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth", # noqa + }, +} + +pretrained_settings = {} +for model_name, sources in sknet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + +timm_sknet_encoders = { + "timm-skresnet18": { + "encoder": SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnet18"], + "params": { + "out_channels": (3, 64, 64, 128, 256, 512), + "block": SelectiveKernelBasic, + "layers": [2, 2, 2, 2], + "zero_init_last_bn": False, + "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, + }, + }, + "timm-skresnet34": { + "encoder": SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnet34"], + "params": { + "out_channels": (3, 64, 64, 128, 256, 512), + "block": SelectiveKernelBasic, + "layers": [3, 4, 6, 3], + "zero_init_last_bn": False, + "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, + }, + }, + "timm-skresnext50_32x4d": { + "encoder": SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], + "params": { + "out_channels": (3, 64, 256, 512, 1024, 2048), + "block": SelectiveKernelBottleneck, + "layers": [3, 4, 6, 3], + "zero_init_last_bn": False, + "cardinality": 32, + "base_width": 4, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4f53b7f84b8365c7b7ba9ddac192a9a8ed4d8f --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -0,0 +1,38 @@ +import timm +import torch.nn as nn + + +class TimmUniversalEncoder(nn.Module): + def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32): + super().__init__() + kwargs = dict( + in_chans=in_channels, + features_only=True, + output_stride=output_stride, + pretrained=pretrained, + out_indices=tuple(range(depth)), + ) + + # not all models support output 
stride argument, drop it by default + if output_stride == 32: + kwargs.pop("output_stride") + + self.model = timm.create_model(name, **kwargs) + + self._in_channels = in_channels + self._out_channels = [in_channels,] + self.model.feature_info.channels() + self._depth = depth + self._output_stride = output_stride + + def forward(self, x): + features = self.model(x) + features = [x,] + features + return features + + @property + def out_channels(self): + return self._out_channels + + @property + def output_stride(self): + return min(self._output_stride, 2 ** self._depth) diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc602c8e4ebbbed362893042e54843a692aabb3 --- /dev/null +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -0,0 +1,159 @@ +"""Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` + +Attributes: + + _out_channels (list of int): specify number of channels for each encoder feature tensor + _depth (int): specify number of stages in decoder (in other words number of downsampling operations) + _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) + +Methods: + + forward(self, x: torch.Tensor) + produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of + shape NCHW (features should be sorted in descending order according to spatial resolution, starting + with resolution same as input `x` tensor). + + Input: `x` with shape (1, 3, 64, 64) + Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes + [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), + (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) + + also should support number of features according to specified depth, e.g. if depth = 5, + number of feature tensors = 6 (one with same resolution as input and 5 downsampled), + depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). +""" + +import torch.nn as nn +from torchvision.models.vgg import VGG +from torchvision.models.vgg import make_layers +from pretrainedmodels.models.torchvision_models import pretrained_settings + +from ._base import EncoderMixin + +# fmt: off +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} +# fmt: on + + +class VGGEncoder(VGG, EncoderMixin): + def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): + super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + del self.classifier + + def make_dilated(self, *args, **kwargs): + raise ValueError( + "'VGG' models do not support dilated mode due to Max Pooling" + " operations for downsampling!" 
+ ) + + def get_stages(self): + stages = [] + stage_modules = [] + for module in self.features: + if isinstance(module, nn.MaxPool2d): + stages.append(nn.Sequential(*stage_modules)) + stage_modules = [] + stage_modules.append(module) + stages.append(nn.Sequential(*stage_modules)) + return stages + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + keys = list(state_dict.keys()) + for k in keys: + if k.startswith("classifier"): + state_dict.pop(k, None) + super().load_state_dict(state_dict, **kwargs) + + +vgg_encoders = { + "vgg11": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg11"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["A"], + "batch_norm": False, + }, + }, + "vgg11_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg11_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["A"], + "batch_norm": True, + }, + }, + "vgg13": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg13"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["B"], + "batch_norm": False, + }, + }, + "vgg13_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg13_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["B"], + "batch_norm": True, + }, + }, + "vgg16": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg16"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["D"], + "batch_norm": False, + }, + }, + "vgg16_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg16_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["D"], + "batch_norm": True, + }, + }, + "vgg19": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg19"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["E"], + "batch_norm": False, + }, + }, + "vgg19_bn": { + "encoder": VGGEncoder, + "pretrained_settings": pretrained_settings["vgg19_bn"], + "params": { + "out_channels": (64, 128, 256, 512, 512, 512), + "config": cfg["E"], + "batch_norm": True, + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py new file mode 100644 index 0000000000000000000000000000000000000000..9453bd08351f78ff78450b6ff17e2d216385ba6b --- /dev/null +++ b/segmentation_models_pytorch/encoders/xception.py @@ -0,0 +1,78 @@ +import re +import torch.nn as nn + +from pretrainedmodels.models.xception import pretrained_settings +from pretrainedmodels.models.xception import Xception + +from ._base import EncoderMixin + + +class XceptionEncoder(Xception, EncoderMixin): + def __init__(self, out_channels, *args, depth=5, **kwargs): + super().__init__(*args, **kwargs) + + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + # modify padding to maintain output shape + self.conv1.padding = (1, 1) + self.conv2.padding = (1, 1) + + del self.fc + + def make_dilated(self, *args, **kwargs): + raise ValueError( + "Xception encoder does not support dilated mode " + "due to pooling operation for downsampling!" 
+ ) + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential( + self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu + ), + self.block1, + self.block2, + nn.Sequential( + self.block3, + self.block4, + self.block5, + self.block6, + self.block7, + self.block8, + self.block9, + self.block10, + self.block11, + ), + nn.Sequential( + self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4 + ), + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict): + # remove linear + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) + + super().load_state_dict(state_dict) + + +xception_encoders = { + "xception": { + "encoder": XceptionEncoder, + "pretrained_settings": pretrained_settings["xception"], + "params": {"out_channels": (3, 64, 128, 256, 728, 2048),}, + }, +} diff --git a/segmentation_models_pytorch/losses/__init__.py b/segmentation_models_pytorch/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59d99cde5a32d9fe5561f88bdb16c334d946abfc --- /dev/null +++ b/segmentation_models_pytorch/losses/__init__.py @@ -0,0 +1,10 @@ +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE + +from .jaccard import JaccardLoss +from .dice import DiceLoss +from .focal import FocalLoss +from .lovasz import LovaszLoss +from .soft_bce import SoftBCEWithLogitsLoss +from .soft_ce import SoftCrossEntropyLoss +from .tversky import TverskyLoss +from .mcc import MCCLoss diff --git a/segmentation_models_pytorch/losses/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af26087ce8a66dbaaf49b1f0a7b4e0ad56650f99 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e6b6158a22009924e23f1910d5c359210d397c1 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/_functional.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/_functional.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be7c51f4044f5f7fda5042453ee875f027a1aa5b Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/_functional.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/_functional.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/_functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca1f6bffef5f658ed338d7441c8b85b5f6a2a667 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/_functional.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/constants.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/constants.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4857fe452bf38588f150236c0be119c35b62ab85 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/constants.cpython-37.pyc differ diff --git 
a/segmentation_models_pytorch/losses/__pycache__/constants.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb77a3fff8556be30bedb9f68d4c8786ba0ab27b Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/constants.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/dice.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/dice.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..612b7b1c1fd6bfcc7013c77797cec1a38b4d4684 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/dice.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/dice.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/dice.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0181864d2be62b5a3aabc456a28fe13541358663 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/dice.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/focal.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/focal.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3fae8eaa788fd918e27d442d6fde802c3ba57eb Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/focal.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/focal.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/focal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb47919f5c42bb14395d08475c87ca454dfd785 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/focal.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/jaccard.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/jaccard.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f563a25d756abb1d1b8dbc36770565122676e0c Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/jaccard.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/jaccard.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/jaccard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b52d339139236d10d39b669fbf31805cc2fc055 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/jaccard.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/lovasz.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/lovasz.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f73b0ba428ab14ede676475b6124c9e0c0483d80 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/lovasz.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/lovasz.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/lovasz.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4eda8146461fe631063f54a683cea305cc9473b Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/lovasz.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/mcc.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/mcc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3e1de50818c52dd78527c891555579b94ab3539 
Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/mcc.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/mcc.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/mcc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa97fb31b613798deac67a52364d876391df3115 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/mcc.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/soft_bce.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/soft_bce.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..734a2ba66d718db7091e2c0b2cec8bf4611c3ef5 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/soft_bce.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/soft_bce.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/soft_bce.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5977fbd1d94f87c65449420e8ad6699d1f57a11 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/soft_bce.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/soft_ce.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/soft_ce.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06f15d380d2271c567cb78a28ae16b0651315939 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/soft_ce.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/soft_ce.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/soft_ce.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e6463d9d382a9c7da779f71afe662c82b4b3c9 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/soft_ce.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/tversky.cpython-37.pyc b/segmentation_models_pytorch/losses/__pycache__/tversky.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..215a8fe2eda4d3a0bee20fcdbb7c148eccb76bb3 Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/tversky.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/losses/__pycache__/tversky.cpython-39.pyc b/segmentation_models_pytorch/losses/__pycache__/tversky.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c06eb49c7cdd57f10a69c31ab2847d99b0f6403b Binary files /dev/null and b/segmentation_models_pytorch/losses/__pycache__/tversky.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..74301e6d701a884ff0af8300816afc73f6814486 --- /dev/null +++ b/segmentation_models_pytorch/losses/_functional.py @@ -0,0 +1,290 @@ +import math +import numpy as np + +from typing import Optional + +import torch +import torch.nn.functional as F + +__all__ = [ + "focal_loss_with_logits", + "softmax_focal_loss_with_logits", + "soft_jaccard_score", + "soft_dice_score", + "wing_loss", +] + + +def to_tensor(x, dtype=None) -> torch.Tensor: + if isinstance(x, torch.Tensor): + if dtype is not None: + x = x.type(dtype) + return x + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if dtype is not None: + x = x.type(dtype) + return x + if isinstance(x, (list, tuple)): + x = 
np.array(x) + x = torch.from_numpy(x) + if dtype is not None: + x = x.type(dtype) + return x + + +def focal_loss_with_logits( + output: torch.Tensor, + target: torch.Tensor, + gamma: float = 2.0, + alpha: Optional[float] = 0.25, + reduction: str = "mean", + normalized: bool = False, + reduced_threshold: Optional[float] = None, + eps: float = 1e-6, +) -> torch.Tensor: + """Compute binary focal loss between target and output logits. + See :class:`~pytorch_toolbelt.losses.FocalLoss` for details. + + Args: + output: Tensor of arbitrary shape (predictions of the model) + target: Tensor of the same shape as input + gamma: Focal loss power factor + alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range, + high values will give more weight to positive class. + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied, + 'mean': the sum of the output will be divided by the number of + elements in the output, 'sum': the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. + 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' + normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). + reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). + + References: + https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py + """ + target = target.type(output.type()) + + logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none") + pt = torch.exp(-logpt) + + # compute the loss + if reduced_threshold is None: + focal_term = (1.0 - pt).pow(gamma) + else: + focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) + focal_term[pt < reduced_threshold] = 1 + + loss = focal_term * logpt + + if alpha is not None: + loss *= alpha * target + (1 - alpha) * (1 - target) + + if normalized: + norm_factor = focal_term.sum().clamp_min(eps) + loss /= norm_factor + + if reduction == "mean": + loss = loss.mean() + if reduction == "sum": + loss = loss.sum() + if reduction == "batchwise_mean": + loss = loss.sum(0) + + return loss + + +def softmax_focal_loss_with_logits( + output: torch.Tensor, + target: torch.Tensor, + gamma: float = 2.0, + reduction="mean", + normalized=False, + reduced_threshold: Optional[float] = None, + eps: float = 1e-6, +) -> torch.Tensor: + """Softmax version of focal loss between target and output logits. + See :class:`~pytorch_toolbelt.losses.FocalLoss` for details. + + Args: + output: Tensor of shape [B, C, *] (Similar to nn.CrossEntropyLoss) + target: Tensor of shape [B, *] (Similar to nn.CrossEntropyLoss) + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied, + 'mean': the sum of the output will be divided by the number of + elements in the output, 'sum': the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. + 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' + normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). 
+ reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). + """ + log_softmax = F.log_softmax(output, dim=1) + + loss = F.nll_loss(log_softmax, target, reduction="none") + pt = torch.exp(-loss) + + # compute the loss + if reduced_threshold is None: + focal_term = (1.0 - pt).pow(gamma) + else: + focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) + focal_term[pt < reduced_threshold] = 1 + + loss = focal_term * loss + + if normalized: + norm_factor = focal_term.sum().clamp_min(eps) + loss = loss / norm_factor + + if reduction == "mean": + loss = loss.mean() + if reduction == "sum": + loss = loss.sum() + if reduction == "batchwise_mean": + loss = loss.sum(0) + + return loss + + +def soft_jaccard_score( + output: torch.Tensor, + target: torch.Tensor, + smooth: float = 0.0, + eps: float = 1e-7, + dims=None, +) -> torch.Tensor: + assert output.size() == target.size() + if dims is not None: + intersection = torch.sum(output * target, dim=dims) + cardinality = torch.sum(output + target, dim=dims) + else: + intersection = torch.sum(output * target) + cardinality = torch.sum(output + target) + + union = cardinality - intersection + jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps) + return jaccard_score + + +def soft_dice_score( + output: torch.Tensor, + target: torch.Tensor, + smooth: float = 0.0, + eps: float = 1e-7, + dims=None, +) -> torch.Tensor: + assert output.size() == target.size() + if dims is not None: + intersection = torch.sum(output * target, dim=dims) + cardinality = torch.sum(output + target, dim=dims) + else: + intersection = torch.sum(output * target) + cardinality = torch.sum(output + target) + dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) + return dice_score + + +def soft_tversky_score( + output: torch.Tensor, + target: torch.Tensor, + alpha: float, + beta: float, + smooth: float = 0.0, + eps: float = 1e-7, + dims=None, +) -> torch.Tensor: + assert output.size() == target.size() + if dims is not None: + intersection = torch.sum(output * target, dim=dims) # TP + fp = torch.sum(output * (1.0 - target), dim=dims) + fn = torch.sum((1 - output) * target, dim=dims) + else: + intersection = torch.sum(output * target) # TP + fp = torch.sum(output * (1.0 - target)) + fn = torch.sum((1 - output) * target) + + tversky_score = (intersection + smooth) / ( + intersection + alpha * fp + beta * fn + smooth + ).clamp_min(eps) + return tversky_score + + +def wing_loss( + output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean" +): + """Wing loss + + References: + https://arxiv.org/pdf/1711.06753.pdf + + """ + diff_abs = (target - output).abs() + loss = diff_abs.clone() + + idx_smaller = diff_abs < width + idx_bigger = diff_abs >= width + + loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature) + + C = width - width * math.log(1 + width / curvature) + loss[idx_bigger] = loss[idx_bigger] - C + + if reduction == "sum": + loss = loss.sum() + + if reduction == "mean": + loss = loss.mean() + + return loss + + +def label_smoothed_nll_loss( + lprobs: torch.Tensor, + target: torch.Tensor, + epsilon: float, + ignore_index=None, + reduction="mean", + dim=-1, +) -> torch.Tensor: + """NLL loss with label smoothing + + References: + https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py + + Args: + lprobs (torch.Tensor): Log-probabilities of predictions (e.g after log_softmax) + + """ + if target.dim() == 
lprobs.dim() - 1: + target = target.unsqueeze(dim) + + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + target = target.masked_fill(pad_mask, 0) + nll_loss = -lprobs.gather(dim=dim, index=target) + smooth_loss = -lprobs.sum(dim=dim, keepdim=True) + + # nll_loss.masked_fill_(pad_mask, 0.0) + # smooth_loss.masked_fill_(pad_mask, 0.0) + nll_loss = nll_loss.masked_fill(pad_mask, 0.0) + smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0) + else: + nll_loss = -lprobs.gather(dim=dim, index=target) + smooth_loss = -lprobs.sum(dim=dim, keepdim=True) + + nll_loss = nll_loss.squeeze(dim) + smooth_loss = smooth_loss.squeeze(dim) + + if reduction == "sum": + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + if reduction == "mean": + nll_loss = nll_loss.mean() + smooth_loss = smooth_loss.mean() + + eps_i = epsilon / lprobs.size(dim) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss diff --git a/segmentation_models_pytorch/losses/constants.py b/segmentation_models_pytorch/losses/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1399ff3545918f9ee2dd1960690766cb392666 --- /dev/null +++ b/segmentation_models_pytorch/losses/constants.py @@ -0,0 +1,18 @@ +#: Loss binary mode assumes you are solving a binary segmentation task. +#: That means you have only one class whose pixels are labeled as **1**, +#: the remaining pixels are background and labeled as **0**. +#: Target mask shape - (N, H, W), model output mask shape (N, 1, H, W). +BINARY_MODE: str = "binary" + +#: Loss multiclass mode assumes you are solving a multi-**class** segmentation task. +#: That means you have *C = 1..N* classes with unique label values, +#: classes are mutually exclusive and all pixels are labeled with these values. +#: Target mask shape - (N, H, W), model output mask shape (N, C, H, W). +MULTICLASS_MODE: str = "multiclass" + +#: Loss multilabel mode assumes you are solving a multi-**label** segmentation task. +#: That means you have *C = 1..N* classes whose pixels are labeled as **1**, +#: classes are not mutually exclusive and each class has its own *channel*, +#: pixels in each channel that do not belong to the class are labeled as **0**. +#: Target mask shape - (N, C, H, W), model output mask shape (N, C, H, W). +MULTILABEL_MODE: str = "multilabel" diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d3dc816a71c146dc34602f330471d52bf094b7 --- /dev/null +++ b/segmentation_models_pytorch/losses/dice.py @@ -0,0 +1,138 @@ +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +from ._functional import soft_dice_score, to_tensor +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE + +__all__ = ["DiceLoss"] + + +class DiceLoss(_Loss): + def __init__( + self, + mode: str, + classes: Optional[List[int]] = None, + log_loss: bool = False, + from_logits: bool = True, + smooth: float = 0.0, + ignore_index: Optional[int] = None, + eps: float = 1e-7, + ): + """Dice loss for image segmentation task. + It supports binary, multiclass and multilabel cases + + Args: + mode: Loss mode 'binary', 'multiclass' or 'multilabel' + classes: List of classes that contribute in loss computation. By default, all channels are included.
+ log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff` + from_logits: If True, assumes input is raw logits + smooth: Smoothness constant for dice coefficient (a) + ignore_index: Label that indicates ignored pixels (does not contribute to loss) + eps: A small epsilon for numerical stability to avoid zero division error + (denominator will be always greater or equal to eps) + + Shape + - **y_pred** - torch.Tensor of shape (N, C, H, W) + - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) + + Reference + https://github.com/BloodAxe/pytorch-toolbelt + """ + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super(DiceLoss, self).__init__() + self.mode = mode + if classes is not None: + assert ( + mode != BINARY_MODE + ), "Masking classes is not supported with mode=binary" + classes = to_tensor(classes, dtype=torch.long) + + self.classes = classes + self.from_logits = from_logits + self.smooth = smooth + self.eps = eps + self.log_loss = log_loss + self.ignore_index = ignore_index + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + + assert y_true.size(0) == y_pred.size(0) + + if self.from_logits: + # Apply activations to get [0..1] class probabilities + # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on + # extreme values 0 and 1 + if self.mode == MULTICLASS_MODE: + y_pred = y_pred.log_softmax(dim=1).exp() + else: + y_pred = F.logsigmoid(y_pred).exp() + + bs = y_true.size(0) + num_classes = y_pred.size(1) + dims = (0, 2) + + if self.mode == BINARY_MODE: + y_true = y_true.view(bs, 1, -1) + y_pred = y_pred.view(bs, 1, -1) + + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask + y_true = y_true * mask + + if self.mode == MULTICLASS_MODE: + y_true = y_true.view(bs, -1) + y_pred = y_pred.view(bs, num_classes, -1) + + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask.unsqueeze(1) + + y_true = F.one_hot( + (y_true * mask).to(torch.long), num_classes + ) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W + else: + y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) # N, C, H*W + + if self.mode == MULTILABEL_MODE: + y_true = y_true.view(bs, num_classes, -1) + y_pred = y_pred.view(bs, num_classes, -1) + + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask + y_true = y_true * mask + + scores = self.compute_score( + y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims + ) + + if self.log_loss: + loss = -torch.log(scores.clamp_min(self.eps)) + else: + loss = 1.0 - scores + + # Dice loss is undefined for non-empty classes + # So we zero contribution of channel that does not have true pixels + # NOTE: A better workaround would be to use loss term `mean(y_pred)` + # for this case, however it will be a modified jaccard loss + + mask = y_true.sum(dims) > 0 + loss *= mask.to(loss.dtype) + + if self.classes is not None: + loss = loss[self.classes] + + return self.aggregate_loss(loss) + + def aggregate_loss(self, loss): + return loss.mean() + + def compute_score( + self, output, target, smooth=0.0, eps=1e-7, dims=None + ) -> torch.Tensor: + return soft_dice_score(output, target, smooth, eps, dims) diff --git a/segmentation_models_pytorch/losses/focal.py b/segmentation_models_pytorch/losses/focal.py new file mode 100644 index 
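# Minimal usage sketch for the DiceLoss defined above; the tensor shapes are illustrative
# assumptions and the import path mirrors the vendored file location shown in this diff.
import torch
from segmentation_models_pytorch.losses.dice import DiceLoss

criterion = DiceLoss(mode="multiclass", from_logits=True)
logits = torch.randn(2, 3, 64, 64)            # (N, C, H, W) raw model outputs
labels = torch.randint(0, 3, (2, 64, 64))     # (N, H, W) integer class map
loss = criterion(logits, labels)              # scalar: mean over classes of 1 - dice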
0000000000000000000000000000000000000000..38f2d907ee9019e37a792e58db19157dcf683387 --- /dev/null +++ b/segmentation_models_pytorch/losses/focal.py @@ -0,0 +1,90 @@ +from typing import Optional +from functools import partial + +import torch +from torch.nn.modules.loss import _Loss +from ._functional import focal_loss_with_logits +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE + +__all__ = ["FocalLoss"] + + +class FocalLoss(_Loss): + def __init__( + self, + mode: str, + alpha: Optional[float] = None, + gamma: Optional[float] = 2.0, + ignore_index: Optional[int] = None, + reduction: Optional[str] = "mean", + normalized: bool = False, + reduced_threshold: Optional[float] = None, + ): + """Compute Focal loss + + Args: + mode: Loss mode 'binary', 'multiclass' or 'multilabel' + alpha: Prior probability of having positive value in target. + gamma: Power factor for dampening weight (focal strength). + ignore_index: If not None, targets may contain values to be ignored. + Target values equal to ignore_index will be ignored from loss computation. + normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). + reduced_threshold: Switch to reduced focal loss. Note, when using this mode you + should use `reduction="sum"`. + + Shape + - **y_pred** - torch.Tensor of shape (N, C, H, W) + - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) + + Reference + https://github.com/BloodAxe/pytorch-toolbelt + + """ + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super().__init__() + + self.mode = mode + self.ignore_index = ignore_index + self.focal_loss_fn = partial( + focal_loss_with_logits, + alpha=alpha, + gamma=gamma, + reduced_threshold=reduced_threshold, + reduction=reduction, + normalized=normalized, + ) + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + + if self.mode in {BINARY_MODE, MULTILABEL_MODE}: + y_true = y_true.view(-1) + y_pred = y_pred.view(-1) + + if self.ignore_index is not None: + # Filter predictions with ignore label from loss computation + not_ignored = y_true != self.ignore_index + y_pred = y_pred[not_ignored] + y_true = y_true[not_ignored] + + loss = self.focal_loss_fn(y_pred, y_true) + + elif self.mode == MULTICLASS_MODE: + + num_classes = y_pred.size(1) + loss = 0 + + # Filter anchors with -1 label from loss computation + if self.ignore_index is not None: + not_ignored = y_true != self.ignore_index + + for cls in range(num_classes): + cls_y_true = (y_true == cls).long() + cls_y_pred = y_pred[:, cls, ...] 
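# Clarifying note on the loop above: in multiclass mode the focal loss is accumulated one class
# at a time as a one-vs-rest binary problem: `cls_y_true` is the 0/1 mask for class `cls`, and
# `cls_y_pred` is that class's logit channel. The ignore_index filtering just below drops the
# ignored positions from both tensors (flattening them) before this class's binary focal term
# is added to the running total.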
+ + if self.ignore_index is not None: + cls_y_true = cls_y_true[not_ignored] + cls_y_pred = cls_y_pred[not_ignored] + + loss += self.focal_loss_fn(cls_y_pred, cls_y_true) + + return loss diff --git a/segmentation_models_pytorch/losses/jaccard.py b/segmentation_models_pytorch/losses/jaccard.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f5d6ffab6bd7814f1bbd2a08b3ffd4b2284188 --- /dev/null +++ b/segmentation_models_pytorch/losses/jaccard.py @@ -0,0 +1,113 @@ +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +from ._functional import soft_jaccard_score, to_tensor +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE + +__all__ = ["JaccardLoss"] + + +class JaccardLoss(_Loss): + def __init__( + self, + mode: str, + classes: Optional[List[int]] = None, + log_loss: bool = False, + from_logits: bool = True, + smooth: float = 0.0, + eps: float = 1e-7, + ): + """Jaccard loss for image segmentation task. + It supports binary, multiclass and multilabel cases + + Args: + mode: Loss mode 'binary', 'multiclass' or 'multilabel' + classes: List of classes that contribute in loss computation. By default, all channels are included. + log_loss: If True, loss computed as `- log(jaccard_coeff)`, otherwise `1 - jaccard_coeff` + from_logits: If True, assumes input is raw logits + smooth: Smoothness constant for dice coefficient + eps: A small epsilon for numerical stability to avoid zero division error + (denominator will be always greater or equal to eps) + + Shape + - **y_pred** - torch.Tensor of shape (N, C, H, W) + - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) + + Reference + https://github.com/BloodAxe/pytorch-toolbelt + """ + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super(JaccardLoss, self).__init__() + + self.mode = mode + if classes is not None: + assert ( + mode != BINARY_MODE + ), "Masking classes is not supported with mode=binary" + classes = to_tensor(classes, dtype=torch.long) + + self.classes = classes + self.from_logits = from_logits + self.smooth = smooth + self.eps = eps + self.log_loss = log_loss + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + + assert y_true.size(0) == y_pred.size(0) + + if self.from_logits: + # Apply activations to get [0..1] class probabilities + # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on + # extreme values 0 and 1 + if self.mode == MULTICLASS_MODE: + y_pred = y_pred.log_softmax(dim=1).exp() + else: + y_pred = F.logsigmoid(y_pred).exp() + + bs = y_true.size(0) + num_classes = y_pred.size(1) + dims = (0, 2) + + if self.mode == BINARY_MODE: + y_true = y_true.view(bs, 1, -1) + y_pred = y_pred.view(bs, 1, -1) + + if self.mode == MULTICLASS_MODE: + y_true = y_true.view(bs, -1) + y_pred = y_pred.view(bs, num_classes, -1) + + y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) # H, C, H*W + + if self.mode == MULTILABEL_MODE: + y_true = y_true.view(bs, num_classes, -1) + y_pred = y_pred.view(bs, num_classes, -1) + + scores = soft_jaccard_score( + y_pred, + y_true.type(y_pred.dtype), + smooth=self.smooth, + eps=self.eps, + dims=dims, + ) + + if self.log_loss: + loss = -torch.log(scores.clamp_min(self.eps)) + else: + loss = 1.0 - scores + + # IoU loss is defined for non-empty classes + # So we zero contribution of channel that does not have true pixels + # NOTE: A better workaround would be to 
use loss term `mean(y_pred)` + # for this case, however it will be a modified jaccard loss + + mask = y_true.sum(dims) > 0 + loss *= mask.float() + + if self.classes is not None: + loss = loss[self.classes] + + return loss.mean() diff --git a/segmentation_models_pytorch/losses/lovasz.py b/segmentation_models_pytorch/losses/lovasz.py new file mode 100644 index 0000000000000000000000000000000000000000..aca6f8664f1b09ad6668977bec89105624faa23f --- /dev/null +++ b/segmentation_models_pytorch/losses/lovasz.py @@ -0,0 +1,236 @@ +""" +Lovasz-Softmax and Jaccard hinge loss in PyTorch +Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) +""" + +from __future__ import print_function, division +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE + +try: + from itertools import ifilterfalse +except ImportError: # py3k + from itertools import filterfalse as ifilterfalse + +__all__ = ["LovaszLoss"] + + +def _lovasz_grad(gt_sorted): + """Compute gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1.0 - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def _lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Logits at each pixel (between -infinity and +infinity) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean( + _lovasz_hinge_flat( + *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore) + ) + for log, lab in zip(logits, labels) + ) + else: + loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore)) + return loss + + +def _lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss + Args: + logits: [P] Logits at each prediction (between -infinity and +infinity) + labels: [P] Tensor, binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0.0 + signs = 2.0 * labels.float() - 1.0 + errors = 1.0 - logits * signs + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = _lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def _flatten_binary_scores(scores, labels, ignore=None): + """Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = labels != ignore + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +# --------------------------- MULTICLASS LOSSES --------------------------- + + +def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore=None): + """Multi-class Lovasz-Softmax loss + Args: + @param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 
+ @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + @param per_image: compute the loss per image instead of per batch + @param ignore: void class labels + """ + if per_image: + loss = mean( + _lovasz_softmax_flat( + *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), + classes=classes + ) + for prob, lab in zip(probas, labels) + ) + else: + loss = _lovasz_softmax_flat( + *_flatten_probas(probas, labels, ignore), classes=classes + ) + return loss + + +def _lovasz_softmax_flat(probas, labels, classes="present"): + """Multi-class Lovasz-Softmax loss + Args: + @param probas: [P, C] Class probabilities at each prediction (between 0 and 1) + @param labels: [P] Tensor, ground truth labels (between 0 and C - 1) + @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0.0 + C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ["all", "present"] else classes + for c in class_to_sum: + fg = (labels == c).type_as(probas) # foreground for class c + if classes == "present" and fg.sum() == 0: + continue + if C == 1: + if len(classes) > 1: + raise ValueError("Sigmoid output possible only with 1 class") + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) + return mean(losses) + + +def _flatten_probas(probas, labels, ignore=None): + """Flattens predictions in the batch""" + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + + C = probas.size(1) + probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C] + probas = probas.contiguous().view(-1, C) # [P, C] + + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = labels != ignore + vprobas = probas[valid] + vlabels = labels[valid] + return vprobas, vlabels + + +# --------------------------- HELPER FUNCTIONS --------------------------- +def isnan(x): + return x != x + + +def mean(values, ignore_nan=False, empty=0): + """Nanmean compatible with generators.""" + values = iter(values) + if ignore_nan: + values = ifilterfalse(isnan, values) + try: + n = 1 + acc = next(values) + except StopIteration: + if empty == "raise": + raise ValueError("Empty mean") + return empty + for n, v in enumerate(values, 2): + acc += v + if n == 1: + return acc + return acc / n + + +class LovaszLoss(_Loss): + def __init__( + self, + mode: str, + per_image: bool = False, + ignore_index: Optional[int] = None, + from_logits: bool = True, + ): + """Lovasz loss for image segmentation task. 
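# Illustrative usage sketch for the LovaszLoss class whose docstring begins above; tensor shapes
# are assumptions, not values from this repository. In multiclass mode the forward pass applies
# softmax and then the _lovasz_softmax routine defined earlier in this file.
import torch
from segmentation_models_pytorch.losses.lovasz import LovaszLoss

criterion = LovaszLoss(mode="multiclass", per_image=False)
logits = torch.randn(2, 4, 32, 32)             # (N, C, H, W) raw logits
labels = torch.randint(0, 4, (2, 32, 32))      # (N, H, W) integer labels
loss = criterion(logits, labels)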
+ It supports binary, multiclass and multilabel cases + + Args: + mode: Loss mode 'binary', 'multiclass' or 'multilabel' + ignore_index: Label that indicates ignored pixels (does not contribute to loss) + per_image: If True loss computed per each image and then averaged, else computed per whole batch + + Shape + - **y_pred** - torch.Tensor of shape (N, C, H, W) + - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) + + Reference + https://github.com/BloodAxe/pytorch-toolbelt + """ + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super().__init__() + + self.mode = mode + self.ignore_index = ignore_index + self.per_image = per_image + + def forward(self, y_pred, y_true): + + if self.mode in {BINARY_MODE, MULTILABEL_MODE}: + loss = _lovasz_hinge( + y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index + ) + elif self.mode == MULTICLASS_MODE: + y_pred = y_pred.softmax(dim=1) + loss = _lovasz_softmax( + y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index + ) + else: + raise ValueError("Wrong mode {}.".format(self.mode)) + return loss diff --git a/segmentation_models_pytorch/losses/mcc.py b/segmentation_models_pytorch/losses/mcc.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd7d6694d0d1a6dcc166a6d70eb6ff00714e970 --- /dev/null +++ b/segmentation_models_pytorch/losses/mcc.py @@ -0,0 +1,51 @@ +import torch +from torch.nn.modules.loss import _Loss + + +class MCCLoss(_Loss): + def __init__(self, eps: float = 1e-5): + """Compute Matthews Correlation Coefficient Loss for image segmentation task. + It only supports binary mode. + + Args: + eps (float): Small epsilon to handle situations where all the samples in the dataset belong to one class + + Reference: + https://github.com/kakumarabhishek/MCC-Loss + """ + super().__init__() + self.eps = eps + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """Compute MCC loss + + Args: + y_pred (torch.Tensor): model prediction of shape (N, H, W) or (N, 1, H, W) + y_true (torch.Tensor): ground truth labels of shape (N, H, W) or (N, 1, H, W) + + Returns: + torch.Tensor: loss value (1 - mcc) + """ + + bs = y_true.shape[0] + + y_true = y_true.view(bs, 1, -1) + y_pred = y_pred.view(bs, 1, -1) + + tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps + tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps + fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps + fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps + + numerator = torch.mul(tp, tn) - torch.mul(fp, fn) + denominator = torch.sqrt( + torch.add(tp, fp) + * torch.add(tp, fn) + * torch.add(tn, fp) + * torch.add(tn, fn) + ) + + mcc = torch.div(numerator.sum(), denominator.sum()) + loss = 1.0 - mcc + + return loss diff --git a/segmentation_models_pytorch/losses/soft_bce.py b/segmentation_models_pytorch/losses/soft_bce.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c5b6fcf8aa9997fa7cb86a29a14068c3f411d2 --- /dev/null +++ b/segmentation_models_pytorch/losses/soft_bce.py @@ -0,0 +1,84 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +__all__ = ["SoftBCEWithLogitsLoss"] + + +class SoftBCEWithLogitsLoss(nn.Module): + + __constants__ = [ + "weight", + "pos_weight", + "reduction", + "ignore_index", + "smooth_factor", + ] + + def __init__( + self, + weight: Optional[torch.Tensor] = None, + ignore_index: Optional[int] = -100, + reduction: str = "mean", + smooth_factor: Optional[float] = None, + 
pos_weight: Optional[torch.Tensor] = None, + ): + """Drop-in replacement for torch.nn.BCEWithLogitsLoss with few additions: ignore_index and label_smoothing + + Args: + ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. + smooth_factor: Factor to smooth target (e.g. if smooth_factor=0.1 then [1, 0, 1] -> [0.9, 0.1, 0.9]) + + Shape + - **y_pred** - torch.Tensor of shape NxCxHxW + - **y_true** - torch.Tensor of shape NxHxW or Nx1xHxW + + Reference + https://github.com/BloodAxe/pytorch-toolbelt + + """ + super().__init__() + self.ignore_index = ignore_index + self.reduction = reduction + self.smooth_factor = smooth_factor + self.register_buffer("weight", weight) + self.register_buffer("pos_weight", pos_weight) + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """ + Args: + y_pred: torch.Tensor of shape (N, C, H, W) + y_true: torch.Tensor of shape (N, H, W) or (N, 1, H, W) + + Returns: + loss: torch.Tensor + """ + + if self.smooth_factor is not None: + soft_targets = (1 - y_true) * self.smooth_factor + y_true * ( + 1 - self.smooth_factor + ) + else: + soft_targets = y_true + + loss = F.binary_cross_entropy_with_logits( + y_pred, + soft_targets, + self.weight, + pos_weight=self.pos_weight, + reduction="none", + ) + + if self.ignore_index is not None: + not_ignored_mask = y_true != self.ignore_index + loss *= not_ignored_mask.type_as(loss) + + if self.reduction == "mean": + loss = loss.mean() + + if self.reduction == "sum": + loss = loss.sum() + + return loss diff --git a/segmentation_models_pytorch/losses/soft_ce.py b/segmentation_models_pytorch/losses/soft_ce.py new file mode 100644 index 0000000000000000000000000000000000000000..960ef18f765b898d77eab0d9b0c91ab43e4072e9 --- /dev/null +++ b/segmentation_models_pytorch/losses/soft_ce.py @@ -0,0 +1,48 @@ +from typing import Optional +from torch import nn, Tensor +import torch +import torch.nn.functional as F +from ._functional import label_smoothed_nll_loss + +__all__ = ["SoftCrossEntropyLoss"] + + +class SoftCrossEntropyLoss(nn.Module): + + __constants__ = ["reduction", "ignore_index", "smooth_factor"] + + def __init__( + self, + reduction: str = "mean", + smooth_factor: Optional[float] = None, + ignore_index: Optional[int] = -100, + dim: int = 1, + ): + """Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing + + Args: + smooth_factor: Factor to smooth target (e.g. 
if smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05]) + + Shape + - **y_pred** - torch.Tensor of shape (N, C, H, W) + - **y_true** - torch.Tensor of shape (N, H, W) + + Reference + https://github.com/BloodAxe/pytorch-toolbelt + """ + super().__init__() + self.smooth_factor = smooth_factor + self.ignore_index = ignore_index + self.reduction = reduction + self.dim = dim + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + log_prob = F.log_softmax(y_pred, dim=self.dim) + return label_smoothed_nll_loss( + log_prob, + y_true, + epsilon=self.smooth_factor, + ignore_index=self.ignore_index, + reduction=self.reduction, + dim=self.dim, + ) diff --git a/segmentation_models_pytorch/losses/tversky.py b/segmentation_models_pytorch/losses/tversky.py new file mode 100644 index 0000000000000000000000000000000000000000..28ab8516346b40522c74fdf02213bbd17771814e --- /dev/null +++ b/segmentation_models_pytorch/losses/tversky.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +import torch +from ._functional import soft_tversky_score +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE +from .dice import DiceLoss + +__all__ = ["TverskyLoss"] + + +class TverskyLoss(DiceLoss): + """Tversky loss for image segmentation task. + Where FP and FN is weighted by alpha and beta params. + With alpha == beta == 0.5, this loss becomes equal DiceLoss. + It supports binary, multiclass and multilabel cases + + Args: + mode: Metric mode {'binary', 'multiclass', 'multilabel'} + classes: Optional list of classes that contribute in loss computation; + By default, all channels are included. + log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky`` + from_logits: If True assumes input is raw logits + smooth: + ignore_index: Label that indicates ignored pixels (does not contribute to loss) + eps: Small epsilon for numerical stability + alpha: Weight constant that penalize model for FPs (False Positives) + beta: Weight constant that penalize model for FNs (False Negatives) + gamma: Constant that squares the error function. 
Defaults to ``1.0`` + + Return: + loss: torch.Tensor + + """ + + def __init__( + self, + mode: str, + classes: List[int] = None, + log_loss: bool = False, + from_logits: bool = True, + smooth: float = 0.0, + ignore_index: Optional[int] = None, + eps: float = 1e-7, + alpha: float = 0.5, + beta: float = 0.5, + gamma: float = 1.0, + ): + + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super().__init__( + mode, classes, log_loss, from_logits, smooth, ignore_index, eps + ) + self.alpha = alpha + self.beta = beta + self.gamma = gamma + + def aggregate_loss(self, loss): + return loss.mean() ** self.gamma + + def compute_score( + self, output, target, smooth=0.0, eps=1e-7, dims=None + ) -> torch.Tensor: + return soft_tversky_score( + output, target, self.alpha, self.beta, smooth, eps, dims + ) diff --git a/segmentation_models_pytorch/metrics/__init__.py b/segmentation_models_pytorch/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f2544ed1e8c59279df4d2751850b781ae38ee6 --- /dev/null +++ b/segmentation_models_pytorch/metrics/__init__.py @@ -0,0 +1,20 @@ +from .functional import ( + get_stats, + fbeta_score, + f1_score, + iou_score, + accuracy, + precision, + recall, + sensitivity, + specificity, + balanced_accuracy, + positive_predictive_value, + negative_predictive_value, + false_negative_rate, + false_positive_rate, + false_discovery_rate, + false_omission_rate, + positive_likelihood_ratio, + negative_likelihood_ratio, +) diff --git a/segmentation_models_pytorch/metrics/__pycache__/__init__.cpython-37.pyc b/segmentation_models_pytorch/metrics/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bccd59e2d36890eb43d3163866115724cdb8cfb0 Binary files /dev/null and b/segmentation_models_pytorch/metrics/__pycache__/__init__.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/metrics/__pycache__/__init__.cpython-39.pyc b/segmentation_models_pytorch/metrics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43a979eed5e9fb3ab52321fad274115441719fd6 Binary files /dev/null and b/segmentation_models_pytorch/metrics/__pycache__/__init__.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/metrics/__pycache__/functional.cpython-37.pyc b/segmentation_models_pytorch/metrics/__pycache__/functional.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23536051299e093a8273803f274c7547e049d1e1 Binary files /dev/null and b/segmentation_models_pytorch/metrics/__pycache__/functional.cpython-37.pyc differ diff --git a/segmentation_models_pytorch/metrics/__pycache__/functional.cpython-39.pyc b/segmentation_models_pytorch/metrics/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0390cb0296a2314ec72a1dce24b7b241506be392 Binary files /dev/null and b/segmentation_models_pytorch/metrics/__pycache__/functional.cpython-39.pyc differ diff --git a/segmentation_models_pytorch/metrics/functional.py b/segmentation_models_pytorch/metrics/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..b829c7dbdc179b502df2b0b0325e6b93c739d7ad --- /dev/null +++ b/segmentation_models_pytorch/metrics/functional.py @@ -0,0 +1,792 @@ +"""Various metrics based on Type I and Type II errors. + +References: + https://en.wikipedia.org/wiki/Confusion_matrix + + +Example: + + .. 
code-block:: python + + import segmentation_models_pytorch as smp + + # lets assume we have multilabel prediction for 3 classes + output = torch.rand([10, 3, 256, 256]) + target = torch.rand([10, 3, 256, 256]).round().long() + + # first compute statistics for true positives, false positives, false negative and + # true negative "pixels" + tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5) + + # then compute metrics with required reduction (see metric docs) + iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro") + f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro") + f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro") + accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro") + recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise") + +""" +import torch +import warnings +from typing import Optional, List, Tuple, Union + + +__all__ = [ + "get_stats", + "fbeta_score", + "f1_score", + "iou_score", + "accuracy", + "precision", + "recall", + "sensitivity", + "specificity", + "balanced_accuracy", + "positive_predictive_value", + "negative_predictive_value", + "false_negative_rate", + "false_positive_rate", + "false_discovery_rate", + "false_omission_rate", + "positive_likelihood_ratio", + "negative_likelihood_ratio", +] + + +################################################################################################### +# Statistics computation (true positives, false positives, false negatives, false positives) +################################################################################################### + + +def get_stats( + output: Union[torch.LongTensor, torch.FloatTensor], + target: torch.LongTensor, + mode: str, + ignore_index: Optional[int] = None, + threshold: Optional[Union[float, List[float]]] = None, + num_classes: Optional[int] = None, +) -> Tuple[torch.LongTensor]: + """Compute true positive, false positive, false negative, true negative 'pixels' + for each image and each class. + + Args: + output (Union[torch.LongTensor, torch.FloatTensor]): Model output with following + shapes and types depending on the specified ``mode``: + + 'binary' + shape (N, 1, ...) and ``torch.LongTensor`` or ``torch.FloatTensor`` + + 'multilabel' + shape (N, C, ...) and ``torch.LongTensor`` or ``torch.FloatTensor`` + + 'multiclass' + shape (N, ...) and ``torch.LongTensor`` + + target (torch.LongTensor): Targets with following shapes depending on the specified ``mode``: + + 'binary' + shape (N, 1, ...) + + 'multilabel' + shape (N, C, ...) + + 'multiclass' + shape (N, ...) + + mode (str): One of ``'binary'`` | ``'multilabel'`` | ``'multiclass'`` + ignore_index (Optional[int]): Label to ignore on for metric computation. + **Not** supproted for ``'binary'`` and ``'multilabel'`` modes. Defaults to None. + threshold (Optional[float, List[float]]): Binarization threshold for + ``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None. + num_classes (Optional[int]): Number of classes, necessary attribute + only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1). + If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or + ``255``. + + Raises: + ValueError: in case of misconfiguration. + + Returns: + Tuple[torch.LongTensor]: true_positive, false_positive, false_negative, + true_negative tensors (N, C) shape each. 
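# A small multiclass sketch of the shapes documented above; the sizes are illustrative assumptions.
import torch
import segmentation_models_pytorch as smp

pred = torch.randint(0, 4, (2, 32, 32))        # (N, ...) predicted class index per pixel
gt = torch.randint(0, 4, (2, 32, 32))          # (N, ...) ground-truth class index per pixel
tp, fp, fn, tn = smp.metrics.get_stats(pred, gt, mode="multiclass", num_classes=4)
# each returned tensor has shape (N, num_classes), here (2, 4)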
+ + """ + + if torch.is_floating_point(target): + raise ValueError( + f"Target should be one of the integer types, got {target.dtype}." + ) + + if torch.is_floating_point(output) and threshold is None: + raise ValueError( + f"Output should be one of the integer types if ``threshold`` is not None, got {output.dtype}." + ) + + if torch.is_floating_point(output) and mode == "multiclass": + raise ValueError( + f"For ``multiclass`` mode ``target`` should be one of the integer types, got {output.dtype}." + ) + + if mode not in {"binary", "multiclass", "multilabel"}: + raise ValueError( + f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}." + ) + + if mode == "multiclass" and threshold is not None: + raise ValueError( + "``threshold`` parameter does not supported for this 'multiclass' mode" + ) + + if output.shape != target.shape: + raise ValueError( + "Dimensions should match, but ``output`` shape is not equal to ``target`` " + + f"shape, {output.shape} != {target.shape}" + ) + + if mode != "multiclass" and ignore_index is not None: + raise ValueError( + f"``ignore_index`` parameter is not supproted for '{mode}' mode" + ) + + if mode == "multiclass" and num_classes is None: + raise ValueError( + "``num_classes`` attribute should be not ``None`` for 'multiclass' mode." + ) + + if ignore_index is not None and 0 <= ignore_index <= num_classes - 1: + raise ValueError( + f"``ignore_index`` should be outside the class values range, but got class values in range " + f"0..{num_classes - 1} and ``ignore_index={ignore_index}``. Hint: if you have ``ignore_index = 0``" + f"consirder subtracting ``1`` from your target and model output to make ``ignore_index = -1``" + f"and relevant class values started from ``0``." + ) + + if mode == "multiclass": + tp, fp, fn, tn = _get_stats_multiclass( + output, target, num_classes, ignore_index + ) + else: + if threshold is not None: + output = torch.where(output >= threshold, 1, 0) + target = torch.where(target >= threshold, 1, 0) + tp, fp, fn, tn = _get_stats_multilabel(output, target) + + return tp, fp, fn, tn + + +@torch.no_grad() +def _get_stats_multiclass( + output: torch.LongTensor, + target: torch.LongTensor, + num_classes: int, + ignore_index: Optional[int], +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: + + batch_size, *dims = output.shape + num_elements = torch.prod(torch.tensor(dims)).long() + + if ignore_index is not None: + ignore = target == ignore_index + output = torch.where(ignore, -1, output) + target = torch.where(ignore, -1, target) + ignore_per_sample = ignore.view(batch_size, -1).sum(1) + + tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long) + fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long) + fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long) + tn_count = torch.zeros(batch_size, num_classes, dtype=torch.long) + + for i in range(batch_size): + target_i = target[i] + output_i = output[i] + mask = output_i == target_i + matched = torch.where(mask, target_i, -1) + tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1) + fp = ( + torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) + - tp + ) + fn = ( + torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) + - tp + ) + tn = num_elements - tp - fp - fn + if ignore_index is not None: + tn = tn - ignore_per_sample[i] + tp_count[i] = tp.long() + fp_count[i] = fp.long() + fn_count[i] = fn.long() + tn_count[i] = tn.long() + + 
return tp_count, fp_count, fn_count, tn_count + + +@torch.no_grad() +def _get_stats_multilabel( + output: torch.LongTensor, target: torch.LongTensor, +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: + + batch_size, num_classes, *dims = target.shape + output = output.view(batch_size, num_classes, -1) + target = target.view(batch_size, num_classes, -1) + + tp = (output * target).sum(2) + fp = output.sum(2) - tp + fn = target.sum(2) - tp + tn = torch.prod(torch.tensor(dims)) - (tp + fp + fn) + + return tp, fp, fn, tn + + +################################################################################################### +# Metrics computation +################################################################################################### + + +def _handle_zero_division(x, zero_division): + nans = torch.isnan(x) + if torch.any(nans) and zero_division == "warn": + warnings.warn("Zero division in metric calculation!") + value = zero_division if zero_division != "warn" else 0 + value = torch.tensor(value, dtype=x.dtype).to(x.device) + x = torch.where(nans, value, x) + return x + + +def _compute_metric( + metric_fn, + tp, + fp, + fn, + tn, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division="warn", + **metric_kwargs, +) -> float: + + if class_weights is None and reduction is not None and "weighted" in reduction: + raise ValueError( + f"Class weights should be provided for `{reduction}` reduction" + ) + + class_weights = class_weights if class_weights is not None else 1.0 + class_weights = torch.tensor(class_weights).to(tp.device) + class_weights = class_weights / class_weights.sum() + + if reduction == "micro": + tp = tp.sum() + fp = fp.sum() + fn = fn.sum() + tn = tn.sum() + score = metric_fn(tp, fp, fn, tn, **metric_kwargs) + + elif reduction == "macro" or reduction == "weighted": + tp = tp.sum(0) + fp = fp.sum(0) + fn = fn.sum(0) + tn = tn.sum(0) + score = metric_fn(tp, fp, fn, tn, **metric_kwargs) + score = _handle_zero_division(score, zero_division) + score = (score * class_weights).mean() + + elif reduction == "micro-imagewise": + tp = tp.sum(1) + fp = fp.sum(1) + fn = fn.sum(1) + tn = tn.sum(1) + score = metric_fn(tp, fp, fn, tn, **metric_kwargs) + score = _handle_zero_division(score, zero_division) + score = score.mean() + + elif reduction == "macro-imagewise" or reduction == "weighted-imagewise": + score = metric_fn(tp, fp, fn, tn, **metric_kwargs) + score = _handle_zero_division(score, zero_division) + score = (score.mean(0) * class_weights).mean() + + elif reduction == "none" or reduction is None: + score = metric_fn(tp, fp, fn, tn, **metric_kwargs) + score = _handle_zero_division(score, zero_division) + + else: + raise ValueError( + "`reduction` should be in [micro, macro, weighted, micro-imagewise," + + "macro-imagesize, weighted-imagewise, none, None]" + ) + + return score + + +# Logic for metric computation, all metrics are with the same interface + + +def _fbeta_score(tp, fp, fn, tn, beta=1): + beta_tp = (1 + beta ** 2) * tp + beta_fn = (beta ** 2) * fn + score = beta_tp / (beta_tp + beta_fn + fp) + return score + + +def _iou_score(tp, fp, fn, tn): + return tp / (tp + fp + fn) + + +def _accuracy(tp, fp, fn, tn): + return (tp + tn) / (tp + fp + fn + tn) + + +def _sensitivity(tp, fp, fn, tn): + return tp / (tp + fn) + + +def _specificity(tp, fp, fn, tn): + return tn / (tn + fp) + + +def _balanced_accuracy(tp, fp, fn, tn): + return (_sensitivity(tp, fp, fn, tn) + _specificity(tp, fp, fn, tn)) / 2 + + 
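# Toy sketch of the two most common reductions handled by _compute_metric above; the counts are
# made-up numbers chosen only to show the difference: "micro" pools counts before scoring, while
# "macro" scores each class from counts summed over images and then averages the class scores.
import torch

tp = torch.tensor([[2, 0], [1, 3]])   # shape (N=2 images, C=2 classes)
fp = torch.tensor([[1, 0], [0, 1]])
fn = torch.tensor([[0, 2], [1, 0]])

micro_iou = tp.sum() / (tp.sum() + fp.sum() + fn.sum())                 # single pooled IoU
macro_iou = (tp.sum(0) / (tp.sum(0) + fp.sum(0) + fn.sum(0))).mean()    # mean of per-class IoU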
+def _positive_predictive_value(tp, fp, fn, tn): + return tp / (tp + fp) + + +def _negative_predictive_value(tp, fp, fn, tn): + return tn / (tn + fn) + + +def _false_negative_rate(tp, fp, fn, tn): + return fn / (fn + tp) + + +def _false_positive_rate(tp, fp, fn, tn): + return fp / (fp + tn) + + +def _false_discovery_rate(tp, fp, fn, tn): + return 1 - _positive_predictive_value(tp, fp, fn, tn) + + +def _false_omission_rate(tp, fp, fn, tn): + return 1 - _negative_predictive_value(tp, fp, fn, tn) + + +def _positive_likelihood_ratio(tp, fp, fn, tn): + return _sensitivity(tp, fp, fn, tn) / _false_positive_rate(tp, fp, fn, tn) + + +def _negative_likelihood_ratio(tp, fp, fn, tn): + return _false_negative_rate(tp, fp, fn, tn) / _specificity(tp, fp, fn, tn) + + +def fbeta_score( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + beta: float = 1.0, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """F beta score""" + return _compute_metric( + _fbeta_score, + tp, + fp, + fn, + tn, + beta=beta, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def f1_score( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """F1 score""" + return _compute_metric( + _fbeta_score, + tp, + fp, + fn, + tn, + beta=1.0, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def iou_score( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """IoU score or Jaccard index""" # noqa + return _compute_metric( + _iou_score, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def accuracy( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Accuracy""" + return _compute_metric( + _accuracy, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def sensitivity( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Sensitivity, recall, hit rate, or true positive rate (TPR)""" + return _compute_metric( + _sensitivity, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def specificity( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Specificity, selectivity or true negative rate (TNR)""" + return _compute_metric( + _specificity, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def balanced_accuracy( + tp: 
torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Balanced accuracy""" + return _compute_metric( + _balanced_accuracy, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def positive_predictive_value( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Precision or positive predictive value (PPV)""" + return _compute_metric( + _positive_predictive_value, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def negative_predictive_value( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Negative predictive value (NPV)""" + return _compute_metric( + _negative_predictive_value, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def false_negative_rate( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Miss rate or false negative rate (FNR)""" + return _compute_metric( + _false_negative_rate, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def false_positive_rate( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Fall-out or false positive rate (FPR)""" + return _compute_metric( + _false_positive_rate, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def false_discovery_rate( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """False discovery rate (FDR)""" # noqa + return _compute_metric( + _false_discovery_rate, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def false_omission_rate( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """False omission rate (FOR)""" # noqa + return _compute_metric( + _false_omission_rate, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def positive_likelihood_ratio( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: 
Union[str, float] = 1.0, +) -> torch.Tensor: + """Positive likelihood ratio (LR+)""" + return _compute_metric( + _positive_likelihood_ratio, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +def negative_likelihood_ratio( + tp: torch.LongTensor, + fp: torch.LongTensor, + fn: torch.LongTensor, + tn: torch.LongTensor, + reduction: Optional[str] = None, + class_weights: Optional[List[float]] = None, + zero_division: Union[str, float] = 1.0, +) -> torch.Tensor: + """Negative likelihood ratio (LR-)""" + return _compute_metric( + _negative_likelihood_ratio, + tp, + fp, + fn, + tn, + reduction=reduction, + class_weights=class_weights, + zero_division=zero_division, + ) + + +_doc = """ + + Args: + tp (torch.LongTensor): tensor of shape (N, C), true positive cases + fp (torch.LongTensor): tensor of shape (N, C), false positive cases + fn (torch.LongTensor): tensor of shape (N, C), false negative cases + tn (torch.LongTensor): tensor of shape (N, C), true negative cases + reduction (Optional[str]): Define how to aggregate metric between classes and images: + + - 'micro' + Sum true positive, false positive, false negative and true negative pixels over + all images and all classes and then compute score. + + - 'macro' + Sum true positive, false positive, false negative and true negative pixels over + all images for each label, then compute the score for each label separately and average the label scores. + This does not take label imbalance into account. + + - 'weighted' + Sum true positive, false positive, false negative and true negative pixels over + all images for each label, then compute the score for each label separately and take a + weighted average of the label scores (using ``class_weights``). + + - 'micro-imagewise' + Sum true positive, false positive, false negative and true negative pixels for **each image**, + then compute score for **each image** and average scores over dataset. All images contribute equally + to the final score, while class imbalance within each image is taken into account. + + - 'macro-imagewise' + Compute score for each image and for each class on that image separately, then compute average score + on each image over labels and average image scores over dataset. Does not take into account label + imbalance on each image. + + - 'weighted-imagewise' + Compute score for each image and for each class on that image separately, then compute weighted average + score on each image over labels and average image scores over dataset. + + - 'none' or ``None`` + Same as ``'macro-imagewise'``, but without any reduction. + + For ``'binary'`` case ``'micro' = 'macro' = 'weighted'`` and + ``'micro-imagewise' = 'macro-imagewise' = 'weighted-imagewise'``. + + Prefixes ``'micro'``, ``'macro'`` and ``'weighted'`` define how the scores for classes will be aggregated, + while postfix ``'imagewise'`` defines how scores between the images will be aggregated. + + class_weights (Optional[List[float]]): list of class weights for metric + aggregation, used when a `weighted*` reduction is chosen. Defaults to None. + zero_division (Union[str, float]): Sets the value to return when there is a zero division, + i.e. when all predictions and labels are negative. If set to "warn", this acts as 0, + but warnings are also raised. Defaults to 1.
+ + Returns: + torch.Tensor: if ``'reduction'`` is not ``None`` or ``'none'`` returns scalar metric, + else returns tensor of shape (N, C) + + References: + https://en.wikipedia.org/wiki/Confusion_matrix +""" + +fbeta_score.__doc__ += _doc +f1_score.__doc__ += _doc +iou_score.__doc__ += _doc +accuracy.__doc__ += _doc +sensitivity.__doc__ += _doc +specificity.__doc__ += _doc +balanced_accuracy.__doc__ += _doc +positive_predictive_value.__doc__ += _doc +negative_predictive_value.__doc__ += _doc +false_negative_rate.__doc__ += _doc +false_positive_rate.__doc__ += _doc +false_discovery_rate.__doc__ += _doc +false_omission_rate.__doc__ += _doc +positive_likelihood_ratio.__doc__ += _doc +negative_likelihood_ratio.__doc__ += _doc + +precision = positive_predictive_value +recall = sensitivity diff --git a/segmentation_models_pytorch/utils/__init__.py b/segmentation_models_pytorch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d32952997e45feaa06fa407908000e6e1a9b7b9c --- /dev/null +++ b/segmentation_models_pytorch/utils/__init__.py @@ -0,0 +1,10 @@ +import warnings + +from . import train +from . import losses +from . import metrics + +warnings.warn( + "`smp.utils` module is deprecated and will be removed in future releases.", + DeprecationWarning, +) diff --git a/segmentation_models_pytorch/utils/base.py b/segmentation_models_pytorch/utils/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d5933654e173328aa5e7abec317b4cc04aaf864d --- /dev/null +++ b/segmentation_models_pytorch/utils/base.py @@ -0,0 +1,68 @@ +import re +import torch.nn as nn + + +class BaseObject(nn.Module): + def __init__(self, name=None): + super().__init__() + self._name = name + + @property + def __name__(self): + if self._name is None: + name = self.__class__.__name__ + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + else: + return self._name + + +class Metric(BaseObject): + pass + + +class Loss(BaseObject): + def __add__(self, other): + if isinstance(other, Loss): + return SumOfLosses(self, other) + else: + raise ValueError("Loss should be inherited from `Loss` class") + + def __radd__(self, other): + return self.__add__(other) + + def __mul__(self, value): + if isinstance(value, (int, float)): + return MultipliedLoss(self, value) + else: + raise ValueError("Loss should be inherited from `BaseLoss` class") + + def __rmul__(self, other): + return self.__mul__(other) + + +class SumOfLosses(Loss): + def __init__(self, l1, l2): + name = "{} + {}".format(l1.__name__, l2.__name__) + super().__init__(name=name) + self.l1 = l1 + self.l2 = l2 + + def __call__(self, *inputs): + return self.l1.forward(*inputs) + self.l2.forward(*inputs) + + +class MultipliedLoss(Loss): + def __init__(self, loss, multiplier): + + # resolve name + if len(loss.__name__.split("+")) > 1: + name = "{} * ({})".format(multiplier, loss.__name__) + else: + name = "{} * {}".format(multiplier, loss.__name__) + super().__init__(name=name) + self.loss = loss + self.multiplier = multiplier + + def __call__(self, *inputs): + return self.multiplier * self.loss.forward(*inputs) diff --git a/segmentation_models_pytorch/utils/functional.py b/segmentation_models_pytorch/utils/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec7b3ce75f2ca83a9d6bd771e9dbd06b722639e --- /dev/null +++ b/segmentation_models_pytorch/utils/functional.py @@ -0,0 +1,134 @@ +import torch + + +def _take_channels(*xs, ignore_channels=None): + 
if ignore_channels is None: + return xs + else: + channels = [ + channel + for channel in range(xs[0].shape[1]) + if channel not in ignore_channels + ] + xs = [ + torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) + for x in xs + ] + return xs + + +def _threshold(x, threshold=None): + if threshold is not None: + return (x > threshold).type(x.dtype) + else: + return x + + +def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate Intersection over Union between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: IoU (Jaccard) score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + intersection = torch.sum(gt * pr) + union = torch.sum(gt) + torch.sum(pr) - intersection + eps + return (intersection + eps) / union + + +jaccard = iou + + +def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate F-score between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + beta (float): positive constant + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: F score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fp = torch.sum(pr) - tp + fn = torch.sum(gt) - tp + + score = ((1 + beta ** 2) * tp + eps) / ( + (1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps + ) + + return score + + +def accuracy(pr, gt, threshold=0.5, ignore_channels=None): + """Calculate accuracy score between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + threshold: threshold for outputs binarization + ignore_channels: channels to skip during calculation + Returns: + float: accuracy score + """ + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt == pr, dtype=pr.dtype) + score = tp / gt.view(-1).shape[0] + return score + + +def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate precision score between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: precision score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fp = torch.sum(pr) - tp + + score = (tp + eps) / (tp + fp + eps) + + return score + + +def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate Recall between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: recall score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fn = torch.sum(gt) - tp + + score = (tp + eps) / (tp + fn + eps) + + return score diff --git a/segmentation_models_pytorch/utils/losses.py 
b/segmentation_models_pytorch/utils/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a87c8eb972696db63f08b903a1b84168e1aea6 --- /dev/null +++ b/segmentation_models_pytorch/utils/losses.py @@ -0,0 +1,69 @@ +import torch.nn as nn + +from . import base +from . import functional as F +from ..base.modules import Activation + + +class JaccardLoss(base.Loss): + def __init__(self, eps=1.0, activation=None, ignore_channels=None, **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return 1 - F.jaccard( + y_pr, + y_gt, + eps=self.eps, + threshold=None, + ignore_channels=self.ignore_channels, + ) + + +class DiceLoss(base.Loss): + def __init__( + self, eps=1.0, beta=1.0, activation=None, ignore_channels=None, **kwargs + ): + super().__init__(**kwargs) + self.eps = eps + self.beta = beta + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return 1 - F.f_score( + y_pr, + y_gt, + beta=self.beta, + eps=self.eps, + threshold=None, + ignore_channels=self.ignore_channels, + ) + + +class L1Loss(nn.L1Loss, base.Loss): + pass + + +class MSELoss(nn.MSELoss, base.Loss): + pass + + +class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss): + pass + + +class NLLLoss(nn.NLLLoss, base.Loss): + pass + + +class BCELoss(nn.BCELoss, base.Loss): + pass + + +class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss): + pass diff --git a/segmentation_models_pytorch/utils/meter.py b/segmentation_models_pytorch/utils/meter.py new file mode 100644 index 0000000000000000000000000000000000000000..9faa720cef0368474b8238d6a530cb2bfd5164f3 --- /dev/null +++ b/segmentation_models_pytorch/utils/meter.py @@ -0,0 +1,61 @@ +import numpy as np + + +class Meter(object): + """Meters provide a way to keep track of important statistics in an online manner. + This class is abstract, but provides a standard interface for all meters to follow. + """ + + def reset(self): + """Reset the meter to default settings.""" + pass + + def add(self, value): + """Log a new value to the meter + Args: + value: Next result to include. 
+ """ + pass + + def value(self): + """Get the value of the meter in the current state.""" + pass + + +class AverageValueMeter(Meter): + def __init__(self): + super(AverageValueMeter, self).__init__() + self.reset() + self.val = 0 + + def add(self, value, n=1): + self.val = value + self.sum += value + self.var += value * value + self.n += n + + if self.n == 0: + self.mean, self.std = np.nan, np.nan + elif self.n == 1: + self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy + self.std = np.inf + self.mean_old = self.mean + self.m_s = 0.0 + else: + self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) + self.m_s += (value - self.mean_old) * (value - self.mean) + self.mean_old = self.mean + self.std = np.sqrt(self.m_s / (self.n - 1.0)) + + def value(self): + return self.mean, self.std + + def reset(self): + self.n = 0 + self.sum = 0.0 + self.var = 0.0 + self.val = 0.0 + self.mean = np.nan + self.mean_old = 0.0 + self.m_s = 0.0 + self.std = np.nan diff --git a/segmentation_models_pytorch/utils/metrics.py b/segmentation_models_pytorch/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a7dadc00b713631a7e0937a0cd1f5f01c19fc7 --- /dev/null +++ b/segmentation_models_pytorch/utils/metrics.py @@ -0,0 +1,111 @@ +from . import base +from . import functional as F +from ..base.modules import Activation + + +class IoU(base.Metric): + __name__ = "iou_score" + + def __init__( + self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs + ): + super().__init__(**kwargs) + self.eps = eps + self.threshold = threshold + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.iou( + y_pr, + y_gt, + eps=self.eps, + threshold=self.threshold, + ignore_channels=self.ignore_channels, + ) + + +class Fscore(base.Metric): + def __init__( + self, + beta=1, + eps=1e-7, + threshold=0.5, + activation=None, + ignore_channels=None, + **kwargs + ): + super().__init__(**kwargs) + self.eps = eps + self.beta = beta + self.threshold = threshold + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.f_score( + y_pr, + y_gt, + eps=self.eps, + beta=self.beta, + threshold=self.threshold, + ignore_channels=self.ignore_channels, + ) + + +class Accuracy(base.Metric): + def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs): + super().__init__(**kwargs) + self.threshold = threshold + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.accuracy( + y_pr, y_gt, threshold=self.threshold, ignore_channels=self.ignore_channels, + ) + + +class Recall(base.Metric): + def __init__( + self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs + ): + super().__init__(**kwargs) + self.eps = eps + self.threshold = threshold + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.recall( + y_pr, + y_gt, + eps=self.eps, + threshold=self.threshold, + ignore_channels=self.ignore_channels, + ) + + +class Precision(base.Metric): + def __init__( + self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs + ): + super().__init__(**kwargs) + self.eps = eps + self.threshold = threshold + 
self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.precision( + y_pr, + y_gt, + eps=self.eps, + threshold=self.threshold, + ignore_channels=self.ignore_channels, + ) diff --git a/segmentation_models_pytorch/utils/train.py b/segmentation_models_pytorch/utils/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6c49e2612efd590f202d6df2e2cae0fae1f051 --- /dev/null +++ b/segmentation_models_pytorch/utils/train.py @@ -0,0 +1,117 @@ +import sys +import torch +from tqdm import tqdm as tqdm +from .meter import AverageValueMeter + + +class Epoch: + def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True): + self.model = model + self.loss = loss + self.metrics = metrics + self.stage_name = stage_name + self.verbose = verbose + self.device = device + + self._to_device() + + def _to_device(self): + self.model.to(self.device) + self.loss.to(self.device) + for metric in self.metrics: + metric.to(self.device) + + def _format_logs(self, logs): + str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()] + s = ", ".join(str_logs) + return s + + def batch_update(self, x, y): + raise NotImplementedError + + def on_epoch_start(self): + pass + + def run(self, dataloader): + + self.on_epoch_start() + + logs = {} + loss_meter = AverageValueMeter() + metrics_meters = { + metric.__name__: AverageValueMeter() for metric in self.metrics + } + + with tqdm( + dataloader, + desc=self.stage_name, + file=sys.stdout, + disable=not (self.verbose), + ) as iterator: + for x, y in iterator: + x, y = x.to(self.device), y.to(self.device) + loss, y_pred = self.batch_update(x, y) + + # update loss logs + loss_value = loss.cpu().detach().numpy() + loss_meter.add(loss_value) + loss_logs = {self.loss.__name__: loss_meter.mean} + logs.update(loss_logs) + + # update metrics logs + for metric_fn in self.metrics: + metric_value = metric_fn(y_pred, y).cpu().detach().numpy() + metrics_meters[metric_fn.__name__].add(metric_value) + metrics_logs = {k: v.mean for k, v in metrics_meters.items()} + logs.update(metrics_logs) + + if self.verbose: + s = self._format_logs(logs) + iterator.set_postfix_str(s) + + return logs + + +class TrainEpoch(Epoch): + def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True): + super().__init__( + model=model, + loss=loss, + metrics=metrics, + stage_name="train", + device=device, + verbose=verbose, + ) + self.optimizer = optimizer + + def on_epoch_start(self): + self.model.train() + + def batch_update(self, x, y): + self.optimizer.zero_grad() + prediction = self.model.forward(x) + loss = self.loss(prediction, y) + loss.backward() + self.optimizer.step() + return loss, prediction + + +class ValidEpoch(Epoch): + def __init__(self, model, loss, metrics, device="cpu", verbose=True): + super().__init__( + model=model, + loss=loss, + metrics=metrics, + stage_name="valid", + device=device, + verbose=verbose, + ) + + def on_epoch_start(self): + self.model.eval() + + def batch_update(self, x, y): + with torch.no_grad(): + prediction = self.model.forward(x) + loss = self.loss(prediction, y) + return loss, prediction diff --git a/sub_model.pth b/sub_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..ba660072a50774e2c67af98f3c1cec1b484f444e --- /dev/null +++ b/sub_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a8cc413267ffa5eed6a9a889ef6d4b96f6eb15aded22041a14173a3283eb96e 
+size 485716201
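
A minimal usage sketch for the confusion-matrix metrics defined earlier in this diff (f1_score, iou_score, balanced_accuracy and friends), assuming they are importable as `segmentation_models_pytorch.metrics`, matching the upstream package layout. The `(N, C)` tp/fp/fn/tn tensors are built by hand from binary predictions, since no helper for computing them appears in this part of the diff.

import torch
from segmentation_models_pytorch.metrics import f1_score, iou_score, balanced_accuracy

# Binary predictions/targets for N=4 images, C=3 classes (random, illustration only).
n, c, h, w = 4, 3, 64, 64
pred = (torch.rand(n, c, h, w) > 0.5).long()
target = (torch.rand(n, c, h, w) > 0.5).long()

# Per-image, per-class confusion-matrix counts of shape (N, C), as the docstring expects.
dims = (2, 3)
tp = (pred * target).sum(dims)
fp = (pred * (1 - target)).sum(dims)
fn = ((1 - pred) * target).sum(dims)
tn = ((1 - pred) * (1 - target)).sum(dims)

print(iou_score(tp, fp, fn, tn, reduction="micro"))             # one scalar over all pixels and classes
print(f1_score(tp, fp, fn, tn, reduction="macro-imagewise"))    # per-class scores averaged per image, then over images
print(balanced_accuracy(tp, fp, fn, tn, reduction=None).shape)  # torch.Size([4, 3]), no reduction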
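
A similar sketch for the deprecated `smp.utils` training loop added in this diff (TrainEpoch / ValidEpoch with DiceLoss and the IoU metric). The model choice and the random tensors used to build the loaders are placeholders for illustration; any (image, mask) dataloader with matching shapes would work.

import torch
from torch.utils.data import DataLoader, TensorDataset
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils.losses import DiceLoss
from segmentation_models_pytorch.utils.metrics import IoU
from segmentation_models_pytorch.utils.train import TrainEpoch, ValidEpoch

# Toy binary-segmentation data (8 RGB images, 64x64, one-channel masks).
x = torch.rand(8, 3, 64, 64)
y = (torch.rand(8, 1, 64, 64) > 0.5).float()
train_loader = DataLoader(TensorDataset(x, y), batch_size=2)
valid_loader = DataLoader(TensorDataset(x, y), batch_size=2)

model = smp.MAnet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)
loss = DiceLoss(activation="sigmoid")   # logged under 'dice_loss' via BaseObject.__name__
metrics = [IoU(threshold=0.5)]          # logged under 'iou_score'
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"

train_epoch = TrainEpoch(model, loss=loss, metrics=metrics, optimizer=optimizer, device=device)
valid_epoch = ValidEpoch(model, loss=loss, metrics=metrics, device=device)

for epoch in range(2):
    train_logs = train_epoch.run(train_loader)   # e.g. {'dice_loss': ..., 'iou_score': ...}
    valid_logs = valid_epoch.run(valid_loader)
    print(epoch, train_logs, valid_logs)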