# Copyright (c) OpenMMLab. All rights reserved.
import random
from abc import abstractmethod
from collections import Counter
from typing import List

import mmengine
import numpy as np
from mmengine.dataset import BaseDataset
from pycocotools.coco import COCO

from mmpretrain.registry import DATASETS
from .coco_vqa import COCOVQA


class FlamingoFewShotMixin:
    """Flamingo few-shot eval dataset mixin.

    Stores the few-shot evaluation settings and provides the random split of
    example indices into a support subset (the pool few-shot examples are
    drawn from) and a query subset (the examples that are evaluated).
    Subclasses supply the dataset-specific annotation parsing.

    Args:
        num_shots (int): Number of shots to perform evaluation.
            Defaults to 0.
            Note: 0 does not mean a strict zero-shot in Flamingo setting.
            It will use 2 text-only prompts without in-context images.
        num_support_examples (int): Number of support examples to get the
            few shots from. Defaults to 2048.
        num_query_examples (int): Number of query examples to perform the
            final evaluation. Defaults to 5000.
        incontext_prompt_temp (str): In-context prompt template for few-shot
            examples. Defaults to ''.
        final_prompt_temp (str): Final query prompt template. Defaults to ''.
        **kwarg: Other keyword arguments, forwarded unchanged to the next
            class in the MRO (e.g. :class:`BaseDataset`).
    """

    def __init__(self,
                 num_shots: int = 0,
                 num_support_examples: int = 2048,
                 num_query_examples: int = 5000,
                 incontext_prompt_temp: str = '',
                 final_prompt_temp: str = '',
                 **kwarg):
        # The few-shot settings are assigned before delegating to the next
        # ``__init__``. NOTE(review): presumably because the dataset base
        # class may call ``load_data_list()`` during construction, which
        # reads these attributes -- confirm against the base class.
        self.num_shots = num_shots
        self.num_support_examples = num_support_examples
        self.num_query_examples = num_query_examples
        self.incontext_prompt_temp = incontext_prompt_temp
        self.final_prompt_temp = final_prompt_temp
        super().__init__(**kwarg)

    def get_subset_idx(self, total_num: int) -> tuple:
        """Randomly split example indices into support and query subsets.

        Indices are drawn without replacement from ``range(total_num)`` using
        NumPy's global random state, so the two subsets never overlap.

        Args:
            total_num (int): Total number of available examples.

        Returns:
            tuple: ``(support_idx, query_idx)``, two index arrays holding
            ``num_support_examples`` and ``num_query_examples`` entries.

        Raises:
            ValueError: If ``total_num`` is smaller than
                ``num_support_examples + num_query_examples``.
        """
        num_required = self.num_support_examples + self.num_query_examples
        # ``np.random.choice(..., replace=False)`` raises ValueError for this
        # case as well, but without saying which sizes were involved.
        if total_num < num_required:
            raise ValueError(
                f'Need at least {num_required} examples to draw '
                f'{self.num_support_examples} support and '
                f'{self.num_query_examples} query examples, but only '
                f'{total_num} are available.')

        random_idx = np.random.choice(total_num, num_required, replace=False)

        support_idx = random_idx[:self.num_support_examples]
        query_idx = random_idx[self.num_support_examples:]
        return support_idx, query_idx

    # ``abstractmethod`` is only enforced when the concrete class uses an
    # ABCMeta-based metaclass; otherwise it acts as a marker that subclasses
    # are expected to override the hook, and the stub still returns ``None``.
    @abstractmethod
    def parse_basic_anno(self, anno: dict) -> dict:
        """Parse basic annotation for support and query set."""
        pass

    @abstractmethod
    def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict:
        """Parse fewshot related annotation for query set with support list."""
        pass


# NOTE(review): ``DATASETS`` is imported at the top of this file but no
# registration decorator is visible on this class -- confirm whether a
# ``@DATASETS.register_module()`` line was lost before relying on
# config-based construction.
class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA):
    """Flamingo few shot VQAv2 dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        question_file (str): Question file path.
        ann_file (str): Annotation file path. Defaults to ''. An annotation
            file is required when ``num_shots > 0``.
        num_shots (int): Number of shots to perform evaluation.
            Defaults to 0.
            Note: 0 does not mean a strict zero-shot in Flamingo setting.
            It will use 2 text-only prompts without in-context images.
        num_support_examples (int): Number of support examples to get the
            few shots from. Defaults to 2048.
        num_query_examples (int): Number of query examples to perform the
            final evaluation. Defaults to 5000.
        **kwarg: Other keyword arguments, forwarded to the parent classes.
    """

    def __init__(self,
                 data_root: str,
                 question_file: str,
                 ann_file: str = '',
                 num_shots: int = 0,
                 num_support_examples: int = 2048,
                 num_query_examples: int = 5000,
                 **kwarg):
        super().__init__(
            data_root=data_root,
            question_file=question_file,
            ann_file=ann_file,
            num_shots=num_shots,
            num_support_examples=num_support_examples,
            num_query_examples=num_query_examples,
            **kwarg)

    def parse_basic_anno(self, ann: dict) -> dict:
        """Parse basic annotation for support and query set.

        Args:
            ann (dict, optional): Annotation for a single example, read via
                ``ann['answers']`` (a list of dicts with an ``'answer'``
                key). ``None`` means no annotation is available.

        Returns:
            dict: An empty dict when ``ann`` is ``None``; otherwise
            ``gt_answer`` (the distinct answers in first-seen order) and
            ``gt_answer_weight`` (each answer's share of all answers).
        """
        if ann is None:
            return {}

        answers = [item['answer'] for item in ann['answers']]
        count = Counter(answers)
        num_answers = len(answers)
        return {
            'gt_answer': list(count.keys()),
            'gt_answer_weight': [num / num_answers for num in count.values()],
        }

    def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
        """Parse fewshot related annotation for query set with support list.

        Args:
            query (dict): Parsed info of a single query example. It is
                modified in place and returned.
            support_list (List): List of support subset to subsample few
                shots.

        Returns:
            dict: ``query`` with ``img_path`` replaced by the list of shot
            image paths followed by the query image path, and with a
            ``shots`` list of ``question``/``answer`` dicts.
        """
        # prepare n shots examples
        shots = random.sample(support_list, self.num_shots)

        # the shot images come first, the query image is always the last one
        query['img_path'] = [shot['img_path'] for shot in shots
                             ] + [query['img_path']]

        # only the first listed ground-truth answer of each shot is used
        query['shots'] = [
            dict(question=shot['question'], answer=shot['gt_answer'][0])
            for shot in shots
        ]
        return query

    def _build_vqa_example(self, question: dict, ann: dict) -> dict:
        """Merge a question with its parsed annotation and image path."""
        data_info = {**question, **self.parse_basic_anno(ann)}
        data_info['img_path'] = self.image_index[question['image_id']]
        return data_info

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One entry per sampled query example. When
            ``num_shots > 0`` each entry also carries its few-shot examples.

        Raises:
            ValueError: If ``num_shots > 0`` but no annotation file is set.
        """
        questions = mmengine.load(self.question_file)['questions']
        if self.ann_file:
            annotations = mmengine.load(self.ann_file)['annotations']
            # questions and annotations are paired by position below
            assert len(questions) == len(annotations), (
                'The question file and the annotation file must contain '
                'the same number of entries.')
        else:
            annotations = [None] * len(questions)
            if self.num_shots > 0:
                raise ValueError('Unable to construct few-shot examples '
                                 'since no annotation file.')

        # The original VQAv2 annotation file and question file includes
        # only image id but no image file paths.
        self.image_index = self._create_image_index()

        support_idx, query_idx = self.get_subset_idx(len(questions))

        # prepare support subset (only needed when shots are requested)
        support_list = []
        if self.num_shots > 0:
            support_list = [
                self._build_vqa_example(questions[idx], annotations[idx])
                for idx in support_idx
            ]

        # prepare query subset
        data_list = []
        for idx in query_idx:
            data_info = self._build_vqa_example(questions[idx],
                                                annotations[idx])
            if self.num_shots > 0:
                data_info = self.parse_fewshot_anno(data_info, support_list)
            data_list.append(data_info)

        return data_list


# NOTE(review): ``DATASETS`` is imported at the top of this file but no
# registration decorator is visible on this class -- confirm whether a
# ``@DATASETS.register_module()`` line was lost before relying on
# config-based construction.
class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset):
    """Flamingo few shot COCO Caption dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        ann_file (str): Annotation file path.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        num_shots (int): Number of shots to perform evaluation.
            Defaults to 0.
        num_support_examples (int): Number of support examples to get the
            few shots from. Defaults to 2048.
        num_query_examples (int): Number of query examples to perform the
            final evaluation. Defaults to 5000.
        **kwarg: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 num_shots: int = 0,
                 num_support_examples: int = 2048,
                 num_query_examples: int = 5000,
                 **kwarg):
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            num_shots=num_shots,
            num_support_examples=num_support_examples,
            num_query_examples=num_query_examples,
            **kwarg)

    def parse_basic_anno(self, ann: dict, coco: COCO) -> dict:
        """Parse basic annotation for support and query set.

        Args:
            ann (dict): Annotation for a single example; its ``image_id``
                and ``caption`` entries are read.
            coco (COCO): The coco dataset, used to look up the image record
                of ``ann['image_id']``.

        Returns:
            dict: Parsed annotation with ``img_path``, ``gt_caption`` and
            ``image_id``.
        """
        image_id = ann['image_id']
        file_name = coco.imgs[image_id]['file_name']
        return dict(
            img_path=mmengine.join_path(self.data_prefix['img_path'],
                                        file_name),
            gt_caption=ann['caption'],
            image_id=image_id,
        )

    def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
        """Parse fewshot related annotation for query set with support list.

        Args:
            query (dict): Parsed info of a single query example. It is
                modified in place and returned.
            support_list (List): List of support subset to subsample few
                shots.

        Returns:
            dict: ``query`` with ``img_path`` replaced by the list of shot
            image paths followed by the query image path, and with a
            ``shots`` list of ``caption`` dicts.
        """
        # prepare n shots examples
        shots = random.sample(support_list, self.num_shots)

        # the shot images come first, the query image is always the last one
        query['img_path'] = [shot['img_path'] for shot in shots
                             ] + [query['img_path']]

        query['shots'] = [dict(caption=shot['gt_caption']) for shot in shots]
        return query

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One entry per sampled query annotation. When
            ``num_shots > 0`` each entry also carries its few-shot examples.
        """
        with mmengine.get_local_path(self.ann_file) as ann_file:
            coco = COCO(ann_file)

        # Sampling is done over caption annotations, addressed by their
        # position in ``ann_ids``.
        ann_ids = list(coco.anns)
        support_idx, query_idx = self.get_subset_idx(len(ann_ids))

        # prepare support subset (only needed when shots are requested)
        support_list = []
        if self.num_shots > 0:
            support_list = [
                self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
                for idx in support_idx
            ]

        # prepare query subset
        query_list = []
        for idx in query_idx:
            ann = coco.anns[ann_ids[idx]]
            data_info = self.parse_basic_anno(ann, coco)
            if self.num_shots > 0:
                data_info = self.parse_fewshot_anno(data_info, support_list)
            query_list.append(data_info)

        return query_list