# Source path: OMG_Seg/seg/models/data_preprocessor/ovsam_preprocessor.py
# Hugging Face file-viewer header: uploaded by HarborYuan, commit b34d1d6
# ("add omg code"), 20.5 kB, virus scan: clean.
import copy
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from mmdet.models import DetDataPreprocessor
from mmdet.registry import MODELS
from kornia.contrib import distance_transform
from mmengine.structures import InstanceData
from seg.models.data_preprocessor import VideoSegDataPreprocessor
def get_center_coords(gt_instances, rescale_shape=None, device='cpu'):
    """Pick one interior "center" point for every ground-truth mask.

    For each mask the pixel with the largest distance-transform response is
    selected.  The transform is applied to the inverted mask, so (per kornia's
    documentation of ``distance_transform``, distance to the nearest non-zero
    pixel) the chosen pixel is presumably the one farthest from the background.

    Args:
        gt_instances: Object whose ``masks`` attribute supports
            ``rescale(shape)`` and ``to_tensor(dtype=..., device=...)``
            (presumably an mmdet mask structure -- confirm against callers).
        rescale_shape: Optional shape forwarded to ``masks.rescale`` before
            the points are computed.  ``None`` keeps the masks unchanged.
        device: Device the masks are processed on and the result lives on.

    Returns:
        torch.Tensor: Tensor of shape ``(num_masks, 1, 2)`` holding one
        ``(x, y)`` coordinate per mask.  It is int64 when at least one mask
        exists and an empty int32 tensor otherwise.
    """
    masks = gt_instances.masks
    if rescale_shape is not None:
        masks = masks.rescale(rescale_shape)
    # (N, H, W) -> (N, 1, H, W)
    masks = masks.to_tensor(dtype=torch.bool, device=device)[:, None]

    point_coords = []
    for mask in masks:
        mask = mask[None]  # (1, 1, H, W)
        width = mask.shape[-1]
        # Pad with one ring of background so the image border also counts as
        # a mask boundary, then crop the ring away again after the transform.
        padded = F.pad(mask, pad=(1, 1, 1, 1), mode='constant', value=0)
        mask_dt = distance_transform((~padded).float())[:, :, 1:-1, 1:-1]

        # One argmax, decoded with integer arithmetic.  The previous
        # ``argmax() / w`` went through float32 and relied on truncation.
        flat_idx = mask_dt.argmax()
        row = torch.div(flat_idx, width, rounding_mode='floor')
        col = flat_idx % width
        # Stored as (x, y), i.e. (column, row).
        point_coords.append(torch.stack([col, row]).long().to(device))

    if len(point_coords) > 0:
        point_coords = torch.stack(point_coords)[:, None]
    else:
        point_coords = torch.empty((0, 1, 2), dtype=torch.int32).to(device=device)
    return point_coords
def get_random_points(gt_instances, device='cpu'):
    """Sample one random foreground pixel for every ground-truth mask.

    Args:
        gt_instances: Sized container of instances; ``masks.masks[i]`` must be
            the binary mask of instance ``i`` in a form ``torch.tensor``
            accepts (presumably a NumPy array -- confirm against callers).
        device: Device used for the sampling and for the returned tensor.

    Returns:
        torch.Tensor: Tensor of shape ``(num_instances, 1, 2)`` holding one
        ``(x, y)`` coordinate per instance.  It is int64 when at least one
        instance exists and an empty int32 tensor otherwise.

    Raises:
        AssertionError: If a mask has no foreground pixel.
    """
    point_coords = []
    for instance_idx in range(len(gt_instances)):
        mask = gt_instances.masks.masks[instance_idx]
        candidate_indices = torch.tensor(mask, device=device).nonzero()
        num_candidates = len(candidate_indices)
        assert num_candidates > 0, 'cannot sample a point from an empty mask'
        # Draw a single index instead of shuffling every candidate with
        # ``randperm`` only to keep the first element.
        pick = torch.randint(num_candidates, (1,), device=device)[0]
        # ``nonzero`` yields (row, col); flip to (x, y).
        point_coords.append(candidate_indices[pick].flip(0))

    if len(point_coords) > 0:
        point_coords = torch.stack(point_coords)[:, None]
    else:
        point_coords = torch.empty((0, 1, 2), dtype=torch.int32).to(device=device)
    return point_coords
@MODELS.register_module()
class OVSAMDataPreprocessor(DetDataPreprocessor):
    """Detection data preprocessor that derives prompt fields from the GT.

    ``DetDataPreprocessor.forward`` runs first.  The resulting data samples
    are then rewritten according to the ``data_tag`` they share:

    * ``'sam_mul'``: ``gt_instances`` is rebuilt from
      ``gt_instances_collected`` (grouped masks plus their point coords).
    * ``'sam'``: prompt fields are built from the first enabled ``use_*``
      option, checked in the order custom bbox, image center, custom point,
      det, point+det, center point+det, point, center point, pseudo box.
    * ``'coco'``: samples are passed through untouched.
    * ``'img'``: every sample gets one box covering the whole image.
    * ``'mosaic_img'``: every 16 images are tiled into one 4x4 mosaic.

    Args:
        use_det (bool): Keep the ground-truth boxes (rescaled by
            ``scale_factor`` outside training) and add a zero ``bp`` field.
        use_point (bool): One random foreground point per instance
            (training only).
        use_center_point (bool): One distance-transform center point per
            instance.
        use_point_det (bool): Random points plus an alternating 0/1 ``bp``
            field; instances are repeated up to ``num_proposals``
            (training only).
        use_center_point_det (bool): Center points; outside training the
            boxes are additionally rescaled by ``scale_factor``.
        use_point_pseudo_box (bool): Store ``(x - 3, y - 3, x + 3, y + 3)``
            around a sampled point in ``point_coords``.
        use_img_center (bool): Use the image center as the only point.
        use_custom_bbox (tuple, optional): ``(x1, y1, x2, y2)`` given as
            fractions of the image width / height.
        use_custom_point (tuple, optional): ``(x, y)`` used verbatim as the
            only point.
        num_proposals (int): Maximum number of instances kept per image
            during training.
        default_mode (str): Tag assigned when the samples carry no
            ``data_tag``.
    """

    def __init__(self, *args,
                 use_det: bool = False,
                 use_point: bool = False,
                 use_center_point: bool = False,
                 use_point_det: bool = False,
                 use_center_point_det: bool = False,
                 use_point_pseudo_box: bool = False,
                 use_img_center: bool = False,
                 use_custom_bbox: Optional[Tuple] = None,
                 use_custom_point: Optional[Tuple] = None,
                 num_proposals: int = 60,
                 default_mode: str = 'sam',
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.num_proposals = num_proposals
        self.use_det = use_det
        self.use_point = use_point
        self.use_center_point = use_center_point
        self.use_point_det = use_point_det
        self.use_center_point_det = use_center_point_det
        self.use_point_pseudo_box = use_point_pseudo_box
        self.use_img_center = use_img_center
        self.use_custom_bbox = use_custom_bbox
        self.use_custom_point = use_custom_point
        self.default_mode = default_mode

    def forward(self, data: dict, training: bool = False) -> dict:
        """Run the parent preprocessing, then build the tag-specific fields.

        Args:
            data (dict): Batch as expected by ``DetDataPreprocessor.forward``.
            training (bool): Whether the model is in training mode.

        Returns:
            dict: ``inputs`` and ``data_samples`` after preprocessing.

        Raises:
            NotImplementedError: For an unknown ``data_tag``, for the
                ``'sam'`` tag without any enabled option, or for random-point
                options outside training.
        """
        data = super().forward(data, training=training)
        inputs, data_samples = data['inputs'], data['data_samples']
        data_tag = self._resolve_data_tag(data_samples)
        device = inputs.device

        if data_tag == 'sam_mul':
            for data_sample in data_samples:
                self._prepare_sam_mul(data_sample)
        elif data_tag == 'sam':
            self._prepare_sam(inputs, data_samples, training, device)
        elif data_tag == 'coco':
            pass
        elif data_tag == 'img':
            for data_sample in data_samples:
                self._prepare_whole_image(data_sample)
        elif data_tag == 'mosaic_img':
            inputs, data_samples = self._prepare_mosaic(inputs, data_samples)
        else:
            raise NotImplementedError
        return dict(inputs=inputs, data_samples=data_samples)

    def _resolve_data_tag(self, data_samples) -> str:
        """Return the tag shared by the batch, assigning the default if unset."""
        if 'data_tag' in data_samples[0]:
            data_tag = data_samples[0].data_tag
            for data_sample in data_samples:
                assert data_sample.data_tag == data_tag
        else:
            data_tag = self.default_mode
            for data_sample in data_samples:
                data_sample.data_tag = data_tag
        return data_tag

    @staticmethod
    def _prepare_sam_mul(data_sample) -> None:
        """Replace ``gt_instances`` with the groups in ``gt_instances_collected``."""
        collected = data_sample.gt_instances_collected
        gt_instances = data_sample.gt_instances
        masks_list = [
            gt_instances.masks[collected.sub_instances[idx]]
            for idx in range(len(collected))
        ]
        data_sample.gt_instances = InstanceData(
            labels=torch.zeros_like(collected.idx),
            masks=masks_list,
            point_coords=collected.point_coords,
            bp=torch.zeros_like(collected.idx),  # all box
        )
        # all points
        del data_sample.gt_instances_collected

    @staticmethod
    def _repeat_instances(gt_instances, num_proposals):
        """Repeat the instances until ``num_proposals`` exist, then truncate."""
        num_ins = len(gt_instances)
        # ``0 <`` guards the division for images without any instance.
        if 0 < num_ins < num_proposals:
            num_copy = num_proposals // num_ins + 1
            gt_instances = InstanceData.cat(
                [copy.deepcopy(gt_instances) for _ in range(num_copy)])
        return gt_instances[:num_proposals]

    @staticmethod
    def _rescale_bboxes(gt_instances, scale_factor) -> None:
        """Multiply the boxes by ``scale_factor`` repeated for both corners."""
        bboxes = gt_instances.bboxes
        gt_instances.bboxes = bboxes * bboxes.new_tensor(scale_factor).repeat(2)

    def _prepare_sam(self, inputs, data_samples, training, device) -> None:
        """Build the prompt fields for the ``'sam'`` tag in place."""
        # Outside training the number of instances is effectively uncapped.
        num_proposals = self.num_proposals if training else 10000000
        if self.use_custom_bbox:
            frac = self.use_custom_bbox
            for data_sample in data_samples:
                img_shape = data_sample.img_shape
                data_sample.gt_instances = InstanceData(
                    bboxes=inputs.new_tensor([[img_shape[1] * frac[0],
                                               img_shape[0] * frac[1],
                                               img_shape[1] * frac[2],
                                               img_shape[0] * frac[3]]])
                )
        elif self.use_img_center:
            for data_sample in data_samples:
                center = [data_sample.img_shape[1] / 2, data_sample.img_shape[0] / 2]
                data_sample.gt_instances = InstanceData(
                    point_coords=inputs.new_tensor([[center]])
                )
        elif self.use_custom_point:
            point = [self.use_custom_point[0], self.use_custom_point[1]]
            for data_sample in data_samples:
                data_sample.gt_instances = InstanceData(
                    point_coords=inputs.new_tensor([[point]])
                )
        elif self.use_det:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances[:num_proposals]
                if not training:
                    self._rescale_bboxes(gt_instances, data_sample.scale_factor)
                gt_instances.bp = torch.zeros((len(gt_instances),)).to(device=device)
                data_sample.gt_instances = gt_instances
        elif self.use_point_det:
            for data_sample in data_samples:
                # Fail before replicating: outside training the cap is huge
                # and the replication below would be prohibitively expensive.
                if not training:
                    raise NotImplementedError
                gt_instances = self._repeat_instances(
                    data_sample.gt_instances, num_proposals)
                gt_instances.point_coords = get_random_points(gt_instances, device=device)
                num_ins = len(gt_instances)
                # First half 0, second half 1, assigned after shuffling.
                bp_indicator = torch.arange(2).repeat_interleave((num_ins // 2) + 1)[:num_ins]
                gt_instances = gt_instances[torch.randperm(num_ins, device=device)]
                gt_instances.bp = bp_indicator.to(device=device)
                data_sample.gt_instances = gt_instances
        elif self.use_center_point_det:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances[:num_proposals]
                if training:
                    gt_instances.point_coords = get_center_coords(gt_instances, device=device)
                else:
                    gt_instances.point_coords = get_center_coords(
                        gt_instances, rescale_shape=data_sample.img_shape, device=device
                    )
                    self._rescale_bboxes(gt_instances, data_sample.scale_factor)
                data_sample.gt_instances = gt_instances
        elif self.use_point:
            for data_sample in data_samples:
                if not training:
                    raise NotImplementedError
                gt_instances = data_sample.gt_instances[:num_proposals]
                gt_instances.point_coords = get_random_points(gt_instances, device=device)
                data_sample.gt_instances = gt_instances
        elif self.use_center_point:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances[:num_proposals]
                if training:
                    gt_instances.point_coords = get_center_coords(gt_instances, device=device)
                else:
                    gt_instances.point_coords = get_center_coords(
                        gt_instances, rescale_shape=data_sample.img_shape, device=device
                    )
                data_sample.gt_instances = gt_instances
        elif self.use_point_pseudo_box:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances
                if training:
                    gt_instances = self._repeat_instances(gt_instances, num_proposals)
                    points = get_random_points(gt_instances, device=device)
                else:
                    points = get_center_coords(
                        gt_instances, rescale_shape=data_sample.img_shape, device=device
                    )
                points = points.squeeze(1)
                # (x - 3, y - 3, x + 3, y + 3) around every point.
                gt_instances.point_coords = torch.cat([points - 3, points + 3], 1)
                gt_instances.bp = torch.zeros_like(gt_instances.labels)  # bug to match sam_mul
                data_sample.gt_instances = gt_instances
        else:
            raise NotImplementedError

    @staticmethod
    def _prepare_whole_image(data_sample) -> None:
        """Give the sample a single box spanning the whole image."""
        gt_instances = data_sample.gt_instances
        h, w = data_sample.img_shape
        label_device = gt_instances.labels.device
        # Boxes are (x1, y1, x2, y2), so the far corner is (w, h).
        gt_instances.bboxes = torch.tensor(
            [[0., 0., w, h]], dtype=torch.float32, device=label_device
        )
        gt_instances.bp = torch.zeros((1,), dtype=torch.int32, device=label_device)

    @staticmethod
    def _prepare_mosaic(inputs, data_samples):
        """Tile every 16 images into a 4x4 mosaic and merge their instances.

        The first data sample of each group is reused as the sample of the
        mosaic; only its ``gt_instances`` is replaced here.
        """
        grid = 4
        b, three, h, w = inputs.shape
        num_img_per_batch = grid * grid
        assert b % num_img_per_batch == 0
        new_b = b // num_img_per_batch
        result_input = inputs.new_empty(new_b, three, h * grid, w * grid)
        result_data_samples = []
        cnt = 0
        for id_b in range(new_b):
            cur_data_sample = data_samples[cnt]
            cur_gt_instances = []
            for row in range(grid):
                for col in range(grid):
                    top, left = row * h, col * w
                    result_input[id_b, :, top: top + h, left: left + w] = inputs[cnt]
                    img_gt_instances = data_samples[cnt].gt_instances
                    # Shift (x1, y1, x2, y2) by the tile origin: x moves with
                    # the column, y with the row.
                    img_gt_instances.bboxes += img_gt_instances.bboxes.new_tensor([
                        left, top, left, top
                    ])
                    cur_gt_instances.append(img_gt_instances)
                    cnt += 1
            cur_data_sample.gt_instances = InstanceData.cat(cur_gt_instances)
            result_data_samples.append(cur_data_sample)
        return result_input, result_data_samples
@MODELS.register_module()
class OVSAMVideoSegDataPreprocessor(VideoSegDataPreprocessor):
    """Video-segmentation data preprocessor that derives prompt fields.

    ``VideoSegDataPreprocessor.forward`` runs first.  The resulting data
    samples are then rewritten according to the ``data_tag`` they share
    (``'sam'`` when the samples carry none):

    * ``'sam_mul'``: ``gt_instances`` is rebuilt from
      ``gt_instances_collected`` (grouped masks plus their point coords).
    * ``'sam'``: prompt fields are built from the first enabled ``use_*``
      option, checked in the order det, point+det, center point+det, point,
      center point, pseudo box.
    * ``'coco'``: samples are passed through untouched.
    * ``'img'``: every sample gets one box covering the whole image.

    Args:
        use_det (bool): Keep the ground-truth boxes (rescaled by
            ``scale_factor`` outside training).
        use_point (bool): One random foreground point per instance
            (training only).
        use_center_point (bool): One distance-transform center point per
            instance.
        use_point_det (bool): Random points plus an alternating 0/1 ``bp``
            field; instances are repeated up to ``num_proposals``
            (training only).
        use_center_point_det (bool): Center points; outside training the
            boxes are additionally rescaled by ``scale_factor``.
        use_point_pseudo_box (bool): Store ``(x - 3, y - 3, x + 3, y + 3)``
            around a sampled point in ``point_coords``.
        num_proposals (int): Maximum number of instances kept per image
            during training.
    """

    def __init__(self, *args,
                 use_det: bool = False,
                 use_point: bool = False,
                 use_center_point: bool = False,
                 use_point_det: bool = False,
                 use_center_point_det: bool = False,
                 use_point_pseudo_box: bool = False,
                 num_proposals: int = 60,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.num_proposals = num_proposals
        self.use_det = use_det
        self.use_point = use_point
        self.use_center_point = use_center_point
        self.use_point_det = use_point_det
        self.use_center_point_det = use_center_point_det
        self.use_point_pseudo_box = use_point_pseudo_box

    def forward(self, data: dict, training: bool = False) -> dict:
        """Run the parent preprocessing, then build the tag-specific fields.

        Args:
            data (dict): Batch as expected by
                ``VideoSegDataPreprocessor.forward``.
            training (bool): Whether the model is in training mode.

        Returns:
            dict: ``inputs`` and ``data_samples`` after preprocessing.

        Raises:
            NotImplementedError: For an unknown ``data_tag``, for the
                ``'sam'`` tag without any enabled option, or for random-point
                options outside training.
        """
        data = super().forward(data, training=training)
        inputs, data_samples = data['inputs'], data['data_samples']
        data_tag = self._resolve_data_tag(data_samples)
        device = inputs.device

        if data_tag == 'sam_mul':
            for data_sample in data_samples:
                self._prepare_sam_mul(data_sample)
        elif data_tag == 'sam':
            self._prepare_sam(data_samples, training, device)
        elif data_tag == 'coco':
            pass
        elif data_tag == 'img':
            for data_sample in data_samples:
                self._prepare_whole_image(data_sample)
        else:
            raise NotImplementedError
        return dict(inputs=inputs, data_samples=data_samples)

    @staticmethod
    def _resolve_data_tag(data_samples) -> str:
        """Return the tag shared by the batch, assigning ``'sam'`` if unset."""
        if 'data_tag' in data_samples[0]:
            data_tag = data_samples[0].data_tag
            for data_sample in data_samples:
                assert data_sample.data_tag == data_tag
        else:
            data_tag = 'sam'
            for data_sample in data_samples:
                data_sample.data_tag = data_tag
        return data_tag

    @staticmethod
    def _prepare_sam_mul(data_sample) -> None:
        """Replace ``gt_instances`` with the groups in ``gt_instances_collected``."""
        collected = data_sample.gt_instances_collected
        gt_instances = data_sample.gt_instances
        masks_list = [
            gt_instances.masks[collected.sub_instances[idx]]
            for idx in range(len(collected))
        ]
        data_sample.gt_instances = InstanceData(
            labels=torch.zeros_like(collected.idx),
            masks=masks_list,
            point_coords=collected.point_coords,
            bp=torch.zeros_like(collected.idx),  # all box
        )
        # all points
        del data_sample.gt_instances_collected

    @staticmethod
    def _repeat_instances(gt_instances, num_proposals):
        """Repeat the instances until ``num_proposals`` exist, then truncate."""
        num_ins = len(gt_instances)
        # ``0 <`` guards the division for images without any instance.
        if 0 < num_ins < num_proposals:
            num_copy = num_proposals // num_ins + 1
            gt_instances = InstanceData.cat(
                [copy.deepcopy(gt_instances) for _ in range(num_copy)])
        return gt_instances[:num_proposals]

    @staticmethod
    def _rescale_bboxes(gt_instances, scale_factor) -> None:
        """Multiply the boxes by ``scale_factor`` repeated for both corners."""
        bboxes = gt_instances.bboxes
        gt_instances.bboxes = bboxes * bboxes.new_tensor(scale_factor).repeat(2)

    def _prepare_sam(self, data_samples, training, device) -> None:
        """Build the prompt fields for the ``'sam'`` tag in place."""
        # Outside training the number of instances is effectively uncapped.
        num_proposals = self.num_proposals if training else 10000000
        if self.use_det:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances[:num_proposals]
                if not training:
                    self._rescale_bboxes(gt_instances, data_sample.scale_factor)
                data_sample.gt_instances = gt_instances
        elif self.use_point_det:
            for data_sample in data_samples:
                # Fail before replicating: outside training the cap is huge
                # and the replication below would be prohibitively expensive.
                if not training:
                    raise NotImplementedError
                gt_instances = self._repeat_instances(
                    data_sample.gt_instances, num_proposals)
                gt_instances.point_coords = get_random_points(gt_instances, device=device)
                num_ins = len(gt_instances)
                # First half 0, second half 1, assigned after shuffling.
                bp_indicator = torch.arange(2).repeat_interleave((num_ins // 2) + 1)[:num_ins]
                gt_instances = gt_instances[torch.randperm(num_ins, device=device)]
                gt_instances.bp = bp_indicator.to(device=device)
                data_sample.gt_instances = gt_instances
        elif self.use_center_point_det:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances[:num_proposals]
                if training:
                    gt_instances.point_coords = get_center_coords(gt_instances, device=device)
                else:
                    gt_instances.point_coords = get_center_coords(
                        gt_instances, rescale_shape=data_sample.img_shape, device=device
                    )
                    self._rescale_bboxes(gt_instances, data_sample.scale_factor)
                data_sample.gt_instances = gt_instances
        elif self.use_point:
            for data_sample in data_samples:
                if not training:
                    raise NotImplementedError
                gt_instances = data_sample.gt_instances[:num_proposals]
                gt_instances.point_coords = get_random_points(gt_instances, device=device)
                data_sample.gt_instances = gt_instances
        elif self.use_center_point:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances[:num_proposals]
                if training:
                    gt_instances.point_coords = get_center_coords(gt_instances, device=device)
                else:
                    gt_instances.point_coords = get_center_coords(
                        gt_instances, rescale_shape=data_sample.img_shape, device=device
                    )
                data_sample.gt_instances = gt_instances
        elif self.use_point_pseudo_box:
            for data_sample in data_samples:
                gt_instances = data_sample.gt_instances
                if training:
                    gt_instances = self._repeat_instances(gt_instances, num_proposals)
                    points = get_random_points(gt_instances, device=device)
                else:
                    points = get_center_coords(
                        gt_instances, rescale_shape=data_sample.img_shape, device=device
                    )
                points = points.squeeze(1)
                # (x - 3, y - 3, x + 3, y + 3) around every point.
                gt_instances.point_coords = torch.cat([points - 3, points + 3], 1)
                gt_instances.bp = torch.zeros_like(gt_instances.labels)  # bug to match sam_mul
                data_sample.gt_instances = gt_instances
        else:
            raise NotImplementedError

    @staticmethod
    def _prepare_whole_image(data_sample) -> None:
        """Give the sample a single box spanning the whole image."""
        gt_instances = data_sample.gt_instances
        h, w = data_sample.img_shape
        label_device = gt_instances.labels.device
        # Boxes are (x1, y1, x2, y2), so the far corner is (w, h).
        gt_instances.bboxes = torch.tensor(
            [[0., 0., w, h]], dtype=torch.float32, device=label_device
        )
        gt_instances.bp = torch.zeros((1,), dtype=torch.int32, device=label_device)