stevengrove
initial commit
186701e
raw
history blame
4.68 kB
# Copyright (c) Tencent Inc. All rights reserved.
import json
import random
from typing import Tuple
import numpy as np
from mmyolo.registry import TRANSFORMS
@TRANSFORMS.register_module()
class RandomLoadText:
    """Sample a per-image set of class texts and remap labels onto it.

    For every call the transform keeps the classes present in the ground
    truth ("positives"), draws a random number of absent classes
    ("negatives"), shuffles the union, rewrites the ground-truth labels as
    positions inside that shuffled list and stores one formatted caption per
    sampled class in ``results['texts']``.

    Args:
        text_path: Optional path to a JSON file holding the class texts. When
            given, the parsed content is kept in ``self.class_texts`` and is
            used whenever ``results`` carries no ``'texts'`` entry.
        prompt_format: ``str.format`` template applied to the chosen caption.
        num_neg_samples: Inclusive ``(low, high)`` bounds passed to
            ``random.randint`` to draw the requested number of negatives.
        max_num_samples: Upper bound on positives, and on positives plus
            negatives together.
        padding_to_max: If True, pad the text list up to ``max_num_samples``.
        padding_value: Text used for padding entries.
    """

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 num_neg_samples: Tuple[int, int] = (80, 80),
                 max_num_samples: int = 80,
                 padding_to_max: bool = False,
                 padding_value: str = '') -> None:
        self.prompt_format = prompt_format
        self.num_neg_samples = num_neg_samples
        self.max_num_samples = max_num_samples
        self.padding_to_max = padding_to_max
        self.padding_value = padding_value
        # NOTE: ``class_texts`` is intentionally left undefined when no path
        # is given; ``__call__`` relies on ``hasattr`` to detect that case.
        if text_path is not None:
            # JSON is UTF-8 by specification; do not depend on the platform's
            # locale encoding, which breaks non-ASCII class names.
            with open(text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)

    def __call__(self, results: dict) -> dict:
        """Sample class texts for one item and relabel its annotations.

        Args:
            results: Pipeline dict. Must contain ``'gt_bboxes'`` and either
                ``'gt_labels'`` or ``'gt_bboxes_labels'``; may contain
                ``'texts'`` and ``'instances'``. Labels are used as indices
                into the class texts (assumes integer class ids - confirm
                against the dataset).

        Returns:
            The same dict, modified in place: boxes/labels (and instances, if
            present) are restricted to the sampled classes, labels are
            replaced by their position in the sampled list, and ``'texts'``
            holds the formatted captions (plus padding when enabled).

        Raises:
            AssertionError: If no class texts are available, or a sampled
                class has no captions.
            ValueError: If neither label key is present in ``results``.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        # Per-item texts take precedence over the file loaded in __init__.
        class_texts = results.get(
            'texts', getattr(self, 'class_texts', None))
        num_classes = len(class_texts)

        if 'gt_labels' in results:
            gt_label_tag = 'gt_labels'
        elif 'gt_bboxes_labels' in results:
            gt_label_tag = 'gt_bboxes_labels'
        else:
            raise ValueError('No valid labels found in results.')

        # Classes that occur in this item; randomly drop some when there are
        # more than the text budget allows.
        positive_labels = set(results[gt_label_tag])
        if len(positive_labels) > self.max_num_samples:
            positive_labels = set(
                random.sample(list(positive_labels),
                              k=self.max_num_samples))

        # Negatives fill what is left of the budget, further limited by the
        # randomly drawn request.
        num_neg_samples = min(
            min(num_classes, self.max_num_samples) - len(positive_labels),
            random.randint(*self.num_neg_samples))
        candidate_neg_labels = [
            idx for idx in range(num_classes) if idx not in positive_labels
        ]
        negative_labels = random.sample(
            candidate_neg_labels, k=num_neg_samples)

        sampled_labels = list(positive_labels) + list(negative_labels)
        random.shuffle(sampled_labels)
        # Original class id -> position in the shuffled text list.
        label2ids = {label: i for i, label in enumerate(sampled_labels)}

        # Relabel the kept annotations in place and mask out those whose
        # class was dropped by the positive sub-sampling above.
        gt_valid_mask = np.zeros(len(results['gt_bboxes']), dtype=bool)
        for idx, label in enumerate(results[gt_label_tag]):
            if label in label2ids:
                gt_valid_mask[idx] = True
                results[gt_label_tag][idx] = label2ids[label]
        results['gt_bboxes'] = results['gt_bboxes'][gt_valid_mask]
        results[gt_label_tag] = results[gt_label_tag][gt_valid_mask]

        if 'instances' in results:
            retaged_instances = []
            for inst in results['instances']:
                label = inst['bbox_label']
                if label in label2ids:
                    inst['bbox_label'] = label2ids[label]
                    retaged_instances.append(inst)
            results['instances'] = retaged_instances

        # One randomly chosen caption per sampled class, in sampled order.
        texts = []
        for label in sampled_labels:
            cls_caps = class_texts[label]
            assert len(cls_caps) > 0, (
                f'Class {label} has no text candidates.')
            cap_id = random.randrange(len(cls_caps))
            texts.append(self.prompt_format.format(cls_caps[cap_id]))

        if self.padding_to_max:
            num_valid_labels = len(positive_labels) + len(negative_labels)
            num_padding = self.max_num_samples - num_valid_labels
            if num_padding > 0:
                texts += [self.padding_value] * num_padding

        results['texts'] = texts
        return results
@TRANSFORMS.register_module()
class LoadText:
    """Attach one formatted caption per class to ``results['texts']``.

    Unlike a random sampler, this transform keeps every class in its original
    order and always takes the first caption of each class.

    Args:
        text_path: Optional path to a JSON file holding the class texts. When
            given, the parsed content is kept in ``self.class_texts`` and is
            used whenever ``results`` carries no ``'texts'`` entry.
        prompt_format: ``str.format`` template applied to each caption.
        multi_prompt_flag: Stored on the instance; not read anywhere in this
            class as written.
    """

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 multi_prompt_flag: str = '/') -> None:
        self.prompt_format = prompt_format
        self.multi_prompt_flag = multi_prompt_flag
        # NOTE: ``class_texts`` is intentionally left undefined when no path
        # is given; ``__call__`` relies on ``hasattr`` to detect that case.
        if text_path is not None:
            # JSON is UTF-8 by specification; do not depend on the platform's
            # locale encoding, which breaks non-ASCII class names.
            with open(text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)

    def __call__(self, results: dict) -> dict:
        """Replace ``results['texts']`` with the formatted first captions.

        Args:
            results: Pipeline dict; may contain ``'texts'``, an iterable in
                which every entry is a non-empty indexable of captions.

        Returns:
            The same dict with ``'texts'`` set to a list of formatted strings,
            one per class.

        Raises:
            AssertionError: If no class texts are available, or a class has
                no captions.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        # Per-item texts take precedence over the file loaded in __init__.
        class_texts = results.get(
            'texts', getattr(self, 'class_texts', None))

        texts = []
        for cls_caps in class_texts:
            assert len(cls_caps) > 0, 'Found a class without text candidates.'
            # Deterministic choice: always the first caption of the class.
            texts.append(self.prompt_format.format(cls_caps[0]))

        results['texts'] = texts
        return results