|
import os |
|
import json |
|
import pickle |
|
import random |
|
import time |
|
import itertools |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import skimage.io as io |
|
import matplotlib.pyplot as plt |
|
from matplotlib.collections import PatchCollection |
|
from matplotlib.patches import Polygon, Rectangle |
|
from torch.utils.data import Dataset |
|
import webdataset as wds |
|
|
|
from minigpt4.datasets.datasets.base_dataset import BaseDataset |
|
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset |
|
|
|
|
|
class GroundedDetailDataset(Dataset):
    """Map-style dataset pairing COCO train2014 images with grounded captions.

    Each annotation record must provide ``image_id`` and ``grounded_caption``.
    The image for a record is read from
    ``<vis_root>/COCO_train2014_<image_id>.jpg``. The instruction is drawn at
    random from a fixed pool of ``[grounding]`` detailed-description prompts.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        Args:
            vis_processor: callable applied to each loaded RGB PIL image.
            text_processor: stored on the instance; not used by this class's
                ``__getitem__``.
            vis_root (string): root directory of images (e.g. coco/images/).
            ann_path (string): path to the JSON annotation file. It is indexed
                by integer position in ``__getitem__``, so presumably it
                decodes to a list of records — confirm against the data.
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.instruction_pool = [
            '[grounding] please describe this image in details',
            '[grounding] describe this image as detailed as possible',
            '[grounding] summarize this image in details',
            '[grounding] give a thorough description of what you see in this image',
        ]

        # JSON is UTF-8 by specification; be explicit so decoding does not
        # depend on the process locale.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.ann)

    def __getitem__(self, index):
        """Build one training sample.

        Returns:
            dict with keys:
                ``image``: output of ``vis_processor`` on the RGB image.
                ``instruction_input``: a randomly chosen prompt wrapped in the
                    ``<Img><ImageHere></Img>`` placeholder.
                ``answer``: the record's ``grounded_caption``.
                ``image_id``: the record's ``image_id``.
        """
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['image_id'])
        image_path = os.path.join(self.vis_root, image_file)
        # Close the file handle as soon as the RGB copy exists instead of
        # leaving it to garbage collection (convert() returns a new, fully
        # loaded image, so it stays valid after the file is closed).
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.vis_processor(image)

        answer = info['grounded_caption']

        instruction = random.choice(self.instruction_pool)
        # NOTE(review): the trailing space is preserved as-is; downstream
        # prompt templating may depend on it — confirm before changing.
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['image_id'],
        }
|
|