|
import sys |
|
import logging |
|
import warnings |
|
from typing import Dict, Any, Sequence |
|
|
|
import torch |
|
from torchvision.ops import box_iou |
|
|
|
from ..utils import ( |
|
MInstrDataset, |
|
BaseComputeMetrics, |
|
) |
|
|
|
from ..process_function import ( |
|
BoxFormatter, |
|
) |
|
|
|
from ..root import ( |
|
DATASETS, |
|
METRICS, |
|
IMAGE_PLACEHOLDER, |
|
BOXES_PLACEHOLDER, |
|
EXPR_PLACEHOLDER, |
|
) |
|
|
|
# Route records to stdout with a timestamped format.
# NOTE(review): basicConfig here runs at import time and configures the *root*
# logger for the whole process; it does nothing if the root logger already has
# handlers (no force=True is passed).
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Per-module logger, pinned to INFO so this module's warnings/infos are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
|
|
@DATASETS.register_module()
class RECDataset(MInstrDataset):
    """Referring-expression dataset: one expression in, one box out.

    Each raw item must provide ``img_path``, ``expression`` and ``bbox``. The
    sample is a two-turn conversation: the human turn is the instruction
    template with ``EXPR_PLACEHOLDER`` replaced by the expression, and the gpt
    turn answers with ``BOXES_PLACEHOLDER`` referring to ``target['boxes'][0]``.
    """

    def __init__(self, *args, **kwargs):
        # The templates used by this dataset must expose an image slot and an
        # expression slot; callers may not override ``placeholders``.
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, EXPR_PLACEHOLDER))

    def __getitem__(self, index):
        """Build the conversation sample for raw item ``index``."""
        record = self.get_raw_item(index)
        img_path, expression, box = record['img_path'], record['expression'], record['bbox']

        picture = self.get_image(img_path)
        prompt = self.get_template().replace(EXPR_PLACEHOLDER, expression)

        human_turn = {'from': 'human', 'value': prompt}
        # boxes_seq [[0]]: the single BOXES_PLACEHOLDER in the answer refers to
        # box index 0 of target['boxes'].
        gpt_turn = {
            'from': 'gpt',
            'value': f'Answer: {BOXES_PLACEHOLDER} .',
            'boxes_seq': [[0]],
        }

        return {
            'image': picture,
            'target': {'boxes': [box]},
            'conversations': [human_turn, gpt_turn],
        }
|
|
|
|
|
@METRICS.register_module()
class RECComputeMetrics(BaseComputeMetrics):
    """Box metrics for single-box referring-expression predictions.

    Each prediction and target string must contain exactly one 4-value box
    extractable by ``self.box_formatter`` (see ``extract_ans``). Reported:

    * ``accuracy``: pairs with IoU > 0.5, divided by ``len(targets)``. Targets
      whose box cannot be extracted are skipped but still counted in the
      denominator, so they count as misses.
    * ``iou``: mean IoU over the scored (prediction, target) pairs.
    * ``target_failed`` / ``failed``: number of target / prediction strings
      from which no single box could be extracted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Formatter registered for the target boxes; used to parse box strings.
        self.box_formatter: BoxFormatter = self.preprocessor['target']['boxes']

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        """Score ``preds`` against ``targets`` pairwise (zipped by position).

        Args:
            preds: model outputs, one per sample.
            targets: reference strings, one per sample.

        Returns:
            dict with ``accuracy``, ``target_failed``, ``failed``, ``iou`` and
            ``warning``. ``accuracy`` and ``iou`` are 0.0 when there is nothing
            to score (empty input, or no target could be parsed).
        """
        failed = 0
        target_failed = 0

        pred_boxes, target_boxes = [], []
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                target_failed += 1
                logger.warning("failed to extract ans for target: %s", target)
                continue
            if extract_pred is None:
                failed += 1
                logger.warning("failed to extract ans for pred: %s", pred)
                # Unparseable prediction -> degenerate box, which has IoU 0
                # against any target with positive area.
                extract_pred = [0, 0, 0, 0]
            target_boxes.append(extract_target)
            pred_boxes.append(extract_pred)

        with torch.no_grad():
            # reshape(-1, 4) keeps the tensors 2-D even when no pair was kept;
            # torch.tensor([]) alone would be 1-D and break the box indexing.
            target_tensor = torch.tensor(target_boxes, dtype=torch.float).reshape(-1, 4)
            pred_tensor = torch.tensor(pred_boxes, dtype=torch.float).reshape(-1, 4)
            ious = self._paired_iou(pred_tensor, target_tensor)

        iou = ious.mean().item() if ious.numel() > 0 else 0.0
        correct = (ious > 0.5).sum().item()
        accuracy = correct / len(targets) if len(targets) > 0 else 0.0

        # NOTE(review): IoU is invariant to independent x/y scaling, so this
        # caveat only bites if box normalization is not a per-axis linear map
        # (e.g. padding to square) -- confirm against the BoxFormatter in use.
        warn_message = "this iou is calculated on normalized box. just for non-rigorous training progress checking. " \
                       "the value is consistent with real iou only if image.width == image.height."
        warnings.warn(warn_message)

        return {
            'accuracy': accuracy,
            'target_failed': target_failed,
            'failed': failed,
            'iou': iou,
            'warning': warn_message,
        }

    @staticmethod
    def _paired_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Return IoU of ``boxes1[i]`` with ``boxes2[i]`` for each row ``i``.

        Inputs are ``(N, 4)`` tensors in ``(x1, y1, x2, y2)`` order (the format
        ``torchvision.ops.box_iou`` expects). Same arithmetic as the diagonal of
        ``box_iou(boxes1, boxes2)``, but O(N) memory instead of an N x N matrix.
        A pair whose union area is 0 yields NaN, as ``box_iou`` would.
        """
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        top_left = torch.max(boxes1[:, :2], boxes2[:, :2])
        bottom_right = torch.min(boxes1[:, 2:], boxes2[:, 2:])
        wh = (bottom_right - top_left).clamp(min=0)
        inter = wh[:, 0] * wh[:, 1]
        return inter / (area1 + area2 - inter)

    def extract_ans(self, string: str):
        """Extract the single box from ``string``, or ``None`` on failure.

        Succeeds only when the formatter yields exactly one group containing
        exactly one box of length 4. Any exception from the formatter is
        logged and treated as a failed extraction (best-effort parsing).
        """
        try:
            list_of_boxes = self.box_formatter.extract(string)
            if len(list_of_boxes) != 1 or len(list_of_boxes[0]) != 1:
                return None
            box = list_of_boxes[0][0]
            if len(box) != 4:
                return None
            return box
        except Exception as e:
            logger.warning("extract_ans for %s but get exception: %s", string, e)
            return None
|
|