# GECO2-demo / models/box_corr.py
# Author: jerpelhan — "Initial commit" (6146368)
import numpy as np
import skimage
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torchvision.ops import roi_align
from torchvision.transforms import Resize
from .query_generator import C_base
from .sam_mask import MaskProcessor
class Box_correction(nn.Module):
    """Refine detector outputs with a mask processor.

    Wraps a ``MaskProcessor``. On each forward pass it overwrites the
    ``"scores"`` and ``"pred_boxes"`` entries of every per-image output dict
    with the IoU predictions and corrected boxes returned by that processor.
    """

    def __init__(
        self,
        reduction,
        image_size,
        emb_dim,
    ):
        """Build the wrapped mask processor.

        Args:
            reduction: Forwarded to ``MaskProcessor`` (presumably the backbone
                feature stride — TODO confirm against ``MaskProcessor``).
            image_size: Forwarded to ``MaskProcessor``.
            emb_dim: Forwarded to ``MaskProcessor``.
        """
        super().__init__()
        # MaskProcessor takes these arguments in the reverse positional order
        # relative to this constructor's signature.
        self.sam_mask = MaskProcessor(emb_dim, image_size, reduction)
        # Flag exposed for external code; it is not read inside this class.
        self.sam_corr = True

    def forward(self, feats, outputs, x):
        """Replace scores and boxes in ``outputs`` with mask-refined values.

        Args:
            feats: Features passed straight through to ``self.sam_mask``.
            outputs: Sequence of per-image dicts. Each must hold a
                ``"pred_boxes"`` tensor. The dicts are mutated in place.
            x: Input tensor. Only ``x.shape[-1]`` is used, as the divisor that
                normalises the corrected boxes (presumably the image width;
                this implicitly assumes square inputs — TODO confirm).

        Returns:
            Tuple ``(outputs, masks)``: the same ``outputs`` object after
            mutation, and the masks produced by ``self.sam_mask``.
        """
        masks, ious, corrected_bboxes = self.sam_mask(feats, outputs)

        for idx, out in enumerate(outputs):
            out["scores"] = ious[idx]
            # Move the refined boxes onto the device of the boxes they replace,
            # prepend an axis of size 1, then normalise by x's last dimension.
            target_device = out["pred_boxes"].device
            refined = corrected_bboxes[idx].to(target_device).unsqueeze(0)
            out["pred_boxes"] = refined / x.shape[-1]

        return outputs, masks