from .vqa_dataset import VQADataset class LlavaDataset(VQADataset): def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs): super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs) def _add_instance_ids(self, key="id"): for idx, ann in enumerate(self.annotation): ann[key] = str(idx) def process_text(self, ann): question = ann["conversations"][0]["value"] # remove '' tag and '\n' question = question.replace("", "").replace("\n", "") answer = ann["conversations"][1]["value"] instruction = self.prompter(question) return dict(instruction=instruction, answer=answer)