Spaces:
Runtime error
Runtime error
from typing import Dict, List, Optional, Tuple | |
import torch | |
from detectron2.config import configurable | |
from detectron2.structures import ImageList, Instances, Boxes | |
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY | |
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN | |
@META_ARCH_REGISTRY.register()
class GRiT(GeneralizedRCNN):
    """GRiT meta-architecture built on top of detectron2's ``GeneralizedRCNN``.

    The backbone, proposal generator and ROI heads are all inherited from
    ``GeneralizedRCNN``; this subclass only customises how they are wired:

    * a proposal generator is mandatory;
    * during training, a per-batch ``task`` value is read from the inputs
      and forwarded to the ROI heads;
    * the training output is the merged dict of ROI-head and
      proposal-generator losses.
    """

    @configurable
    def __init__(
            self,
            **kwargs):
        """Build the model from the keyword arguments of ``GeneralizedRCNN``.

        ``@configurable`` additionally allows construction from a config
        object (``GRiT(cfg)``), in which case the keyword arguments are
        produced by :meth:`from_config`.
        """
        super().__init__(**kwargs)
        # inference() and forward() call the proposal generator
        # unconditionally, so a model without one is unusable.
        assert self.proposal_generator is not None

    @classmethod
    def from_config(cls, cfg):
        """Return the constructor kwargs for ``cfg``.

        Delegates entirely to the parent class; kept as an explicit hook so
        GRiT-specific arguments can be added in one place.
        """
        ret = super().from_config(cfg)
        return ret

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """Run the model in evaluation mode.

        Args:
            batched_inputs: one dict per image, in the format accepted by
                ``self.preprocess_image``.
            detected_instances: accepted for signature compatibility with
                the parent class only; must be ``None`` (pre-computed
                detections are not supported here).
            do_postprocess: when true, pass the raw ROI-head results through
                ``GRiT._postprocess`` together with the inputs and the
                image sizes; otherwise return the raw results.

        Returns:
            The post-processed results if ``do_postprocess`` is true, else
            the first element returned by ``self.roi_heads``.
        """
        assert not self.training
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        # No ground truth at inference time; the loss outputs are discarded.
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(features, proposals)

        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            return GRiT._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """Dispatch to inference in eval mode, otherwise compute losses.

        Args:
            batched_inputs: one dict per image. In training mode each dict
                must provide ``"instances"`` (ground truth, moved to
                ``self.device``) and ``"task"``; every image in the batch
                must carry the same ``"task"`` value.

        Returns:
            In eval mode, whatever :meth:`inference` returns. In training
            mode, a dict combining the ROI-head losses and the
            proposal-generator losses.
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

        # A single task is forwarded to the ROI heads for the whole batch,
        # so mixed-task batches are rejected up front.
        targets_task = batched_inputs[0]['task']
        for anno_per_image in batched_inputs:
            assert targets_task == anno_per_image['task']

        features = self.backbone(images.tensor)
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
        proposals, roihead_textdecoder_losses = self.roi_heads(
            features, proposals, gt_instances, targets_task=targets_task)

        # Merge order is preserved from the original: on a key collision the
        # proposal-generator entry wins.
        losses = {}
        losses.update(roihead_textdecoder_losses)
        losses.update(proposal_losses)

        return losses