from typing import Dict, List, Optional, Tuple
import torch

from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN


@META_ARCH_REGISTRY.register()
class GRiT(GeneralizedRCNN):
    """GRiT meta-architecture: a GeneralizedRCNN whose ROI heads also run a
    text decoder and are conditioned on a per-batch 'task' label."""

    @configurable
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # GRiT relies on generated proposals; precomputed proposals are not supported.
        assert self.proposal_generator is not None

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        return ret
    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        assert not self.training
        # Running the heads on externally detected instances is not supported.
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(features, proposals)
        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            # Rescale predictions back to the original image sizes.
            return GRiT._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results
    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

        # All images in a batch must share the same task, since the ROI heads
        # and text decoder are conditioned on it.
        targets_task = batched_inputs[0]['task']
        for anno_per_image in batched_inputs:
            assert targets_task == anno_per_image['task']

        features = self.backbone(images.tensor)
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
        proposals, roihead_textdecoder_losses = self.roi_heads(
            features, proposals, gt_instances, targets_task=targets_task)

        losses = {}
        losses.update(roihead_textdecoder_losses)
        losses.update(proposal_losses)
        return losses
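

# --- Minimal usage sketch (not part of the original file) ---------------------
# This assumes a GRiT project config that wires up the custom ROI heads / text
# decoder expected by the calls above; the YAML path and the "GRiT" meta-arch
# override below are illustrative, and with a plain detectron2 base config the
# roi_heads call signatures in inference()/forward() would not match.
if __name__ == "__main__":
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model

    cfg = get_cfg()
    cfg.merge_from_file("configs/GRiT_B_DenseCap.yaml")  # hypothetical project config path
    cfg.MODEL.META_ARCHITECTURE = "GRiT"  # dispatch build_model() to the class registered above

    model = build_model(cfg)
    model.eval()

    # Standard detectron2 batched-inputs format: one dict per image with a CHW tensor.
    inputs = [{"image": torch.zeros(3, 480, 640), "height": 480, "width": 640}]
    with torch.no_grad():
        predictions = model(inputs)  # eval mode -> GRiT.inference -> postprocessed Instances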