# EditAnything / vlpart / vlpart_fast_rcnn.py
# Repository page header: avatar "shgao"; last commit "update new demo" (0c7479d)
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union
import math
import copy
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from fvcore.nn import sigmoid_focal_loss_jit, giou_loss, smooth_l1_loss
import fvcore.nn.weight_init as weight_init
from detectron2.config import configurable
from detectron2.data.detection_utils import get_fed_loss_cls_weights
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss
from detectron2.structures import Boxes, Instances, BitMasks, pairwise_iou, pairwise_ioa
from detectron2.utils.events import get_event_storage
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
class TexeEmbedClassifier(nn.Module):
    """Score feature vectors against a matrix of text embeddings.

    Each input row is projected by a single linear layer to ``zs_weight_dim``
    channels, optionally L2-normalised and scaled by ``norm_temperature``,
    and then matrix-multiplied with the ``text_embed`` matrix supplied to
    :meth:`forward`.

    NOTE: the class name (presumably a typo of "TextEmbedClassifier") is kept
    as-is because ``VLMFastRCNNOutputLayers`` in this file refers to it.
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        zs_weight_dim: int = 1024,
        norm_weight: bool = True,
        norm_temperature: float = 50.0,
    ):
        """
        Args:
            input_shape: shape of the incoming feature. Its ``channels`` and
                (optional) ``width`` / ``height`` are multiplied to get the
                linear layer's input size. A plain ``int`` is also accepted
                and treated as a channel count.
            zs_weight_dim: output size of the linear projection; ``forward``
                multiplies the projection with ``text_embed``, so this must
                equal the first dimension of ``text_embed``.
            norm_weight: if True, L2-normalise the projected features along
                dim 1 and scale them by ``norm_temperature``.
            norm_temperature: scale applied to the normalised features.
        """
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        # A missing (None / 0) width or height counts as 1.
        input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)

        self.norm_weight = norm_weight
        self.norm_temperature = norm_temperature
        self.linear = nn.Linear(input_size, zs_weight_dim)

    def forward(self, x: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and compute its product with ``text_embed``.

        Args:
            x: 2-D tensor of shape ``(N, input_size)`` (``torch.mm`` below
                requires a 2-D left operand).
            text_embed: 2-D tensor of shape ``(zs_weight_dim, C)``.

        Returns:
            Tensor of shape ``(N, C)``.
        """
        x = self.linear(x)
        if self.norm_weight:
            x = self.norm_temperature * F.normalize(x, p=2, dim=1)
        return torch.mm(x, text_embed)
class VLMFastRCNNOutputLayers(nn.Module):
    """Output layers producing box deltas and text-embedding class scores.

    Two heads share the same (flattened) input feature:

    * ``bbox_pred``: Linear -> ReLU -> Linear, emitting 4 values per row.
    * ``cls_score``: a ``TexeEmbedClassifier`` scoring each row against the
      ``text_embed`` matrix passed to :meth:`forward`.
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        box2box_transform,
        use_sigmoid_ce: bool = True,
        test_score_thresh: float = 0.0,
        test_nms_thresh: float = 0.5,
        test_topk_per_image: int = 100,
    ):
        """
        Args:
            input_shape: shape of the incoming feature; a plain ``int`` is
                accepted and treated as a channel count.
            box2box_transform: object whose ``apply_deltas(deltas, boxes)``
                is used by :meth:`predict_boxes`.
            use_sigmoid_ce: if True, :meth:`predict_probs` applies a sigmoid
                to the class scores; otherwise a softmax over the last dim.
            test_score_thresh: stored on the instance; not used by the
                methods defined in this class body.
            test_nms_thresh: stored on the instance; not used by the methods
                defined in this class body.
            test_topk_per_image: stored on the instance; not used by the
                methods defined in this class body.
        """
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)

        self.box2box_transform = box2box_transform
        self.use_sigmoid_ce = use_sigmoid_ce
        self.test_score_thresh = test_score_thresh
        self.test_nms_thresh = test_nms_thresh
        self.test_topk_per_image = test_topk_per_image

        # A missing (None / 0) width or height counts as 1.
        input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)

        # bbox_pred
        self.bbox_pred = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(inplace=True),
            nn.Linear(input_size, 4),
        )

        # cls_score
        self.cls_score = TexeEmbedClassifier(input_shape)

    def forward(self, x, text_embed):
        """Compute class scores and box deltas for a batch of features.

        Args:
            x: feature tensor; anything with more than two dimensions is
                flattened from dim 1 onwards first.
            text_embed: matrix forwarded unchanged to ``self.cls_score``.

        Returns:
            ``(cls_scores, proposal_deltas)``.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        cls_scores = self.cls_score(x, text_embed)
        proposal_deltas = self.bbox_pred(x)
        return cls_scores, proposal_deltas

    def predict_boxes(self, predictions, proposals):
        """Apply the predicted deltas to the proposal boxes.

        Args:
            predictions: the ``(cls_scores, proposal_deltas)`` pair returned
                by :meth:`forward`; only the deltas are used.
            proposals: per-image sequence of objects that support ``len()``
                and expose ``proposal_boxes.tensor``.

        Returns:
            An empty list when ``proposals`` is empty; otherwise the
            sequence of per-image box tensors returned by ``Tensor.split``.
        """
        if not len(proposals):
            return []
        _, proposal_deltas = predictions
        boxes_per_image = [len(p) for p in proposals]
        proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)
        boxes = self.box2box_transform.apply_deltas(proposal_deltas, proposal_boxes)
        # The boxes were concatenated over all images; split them back.
        return boxes.split(boxes_per_image)

    def predict_probs(self, predictions, proposals):
        """Turn raw class scores into per-image probabilities.

        Args:
            predictions: the ``(cls_scores, proposal_deltas)`` pair returned
                by :meth:`forward`; only the scores are used.
            proposals: per-image sequence of objects that support ``len()``;
                their lengths give the number of score rows per image.

        Returns:
            A list with one tensor per image: sigmoid of the scores when
            ``self.use_sigmoid_ce`` is true, otherwise softmax over the last
            dimension.
        """
        cls_scores, _ = predictions
        rows_per_image = [len(p) for p in proposals]
        return [
            scores.sigmoid() if self.use_sigmoid_ce else F.softmax(scores, dim=-1)
            for scores in cls_scores.split(rows_per_image, dim=0)
        ]