import json
import os
from collections import Counter

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder

from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL

class CaptionDataset(Dataset):
    def __init__(
        self,
        image_train_dir_path,
        annotations_path,
        is_train,
        dataset_name,
        image_val_dir_path=None,
        which_gt=None,
        best_gt_caption_path=None,
    ):
        self.image_train_dir_path = image_train_dir_path
        self.image_val_dir_path = image_val_dir_path
        self.annotations = []
        self.is_train = is_train
        self.dataset_name = dataset_name

        # Keep only the annotations belonging to the requested split.
        with open(annotations_path) as f:
            full_annotations = json.load(f)["images"]
        for annotation in full_annotations:
            if self.is_train and annotation["split"] != "train":
                continue
            elif not self.is_train and annotation["split"] != "test":
                continue
            self.annotations.append(annotation)

        # which_gt selects the ground-truth caption per image: an int index into
        # the caption list, a dict mapping image id to index, the string "best"
        # (looked up in best_gt_caption_path), or None for the first caption.
        # Digit strings are coerced to int.
        if isinstance(which_gt, str):
            self.which_gt = int(which_gt) if which_gt.isdigit() else which_gt
        else:
            self.which_gt = which_gt

        if best_gt_caption_path is not None:
            with open(best_gt_caption_path, "r") as f:
                self.best_gt_captions = json.load(f)
        else:
            self.best_gt_captions = None

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        if self.dataset_name == "coco":
            # COCO images live in train2014 or val2014, indicated by "filepath".
            image_dir = (
                self.image_train_dir_path
                if annotation["filepath"] == "train2014"
                else self.image_val_dir_path
            )
            image = Image.open(os.path.join(image_dir, annotation["filename"]))
        elif self.dataset_name == "flickr":
            image = Image.open(
                os.path.join(self.image_train_dir_path, annotation["filename"])
            )
        else:
            raise Exception(f"Unknown caption dataset {self.dataset_name}")
        image.load()
        image_id = (
            annotation["cocoid"]
            if self.dataset_name == "coco"
            else annotation["filename"].split(".")[0]
        )

        # Pick the ground-truth caption according to which_gt.
        if isinstance(self.which_gt, int):
            cpt_idx = self.which_gt
        elif isinstance(self.which_gt, dict):
            cpt_idx = self.which_gt[image_id]
        elif self.which_gt == "best":
            cpt_idx = self.best_gt_captions[str(image_id)]
        else:
            assert self.which_gt is None
            cpt_idx = 0
        caption = annotation["sentences"][cpt_idx]["raw"]
        return {
            "image": image,
            "caption": caption,
            "image_id": image_id,
        }
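
# Usage sketch for CaptionDataset. The paths below are hypothetical
# placeholders; the annotation file is assumed to follow the Karpathy-split
# schema this class parses (an "images" list whose entries carry "split",
# "filepath", "filename", "sentences", and "cocoid"):
#
#   dataset = CaptionDataset(
#       image_train_dir_path="/data/coco/train2014",
#       annotations_path="/data/coco/dataset_coco.json",
#       is_train=False,
#       dataset_name="coco",
#       image_val_dir_path="/data/coco/val2014",
#   )
#   sample = dataset[0]
#   # sample == {"image": <PIL.Image>, "caption": <str>, "image_id": <int>}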

class VQADataset(Dataset):
    def __init__(
        self,
        image_dir_path,
        question_path,
        annotations_path,
        is_train,
        dataset_name,
        which_gt="all",
        is_tensor=False,
    ):
        with open(question_path, "r") as f:
            self.questions = json.load(f)["questions"]
        if annotations_path is not None:
            with open(annotations_path, "r") as f:
                self.answers = json.load(f)["annotations"]
        else:
            self.answers = None
        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # COCO-style filenames embed the split, taken from the directory name.
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            assert self.img_coco_split in {"train2014", "val2014", "test2015"}
        self.which_gt = which_gt
        self.is_tensor = is_tensor

    def __len__(self):
        return len(self.questions)

    def get_img_path(self, question):
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # The same COCO naming scheme applies to both train and eval splits.
            return os.path.join(
                self.image_dir_path,
                f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg",
            )
        elif self.dataset_name == "vizwiz":
            return os.path.join(self.image_dir_path, question["image_id"])
        elif self.dataset_name == "textvqa":
            return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
        else:
            raise Exception(f"Unknown VQA dataset {self.dataset_name}")

    def get_from_id(self, question_id):
        # Load a precomputed image tensor by question id (textvqa eval only).
        assert not self.is_train
        assert self.dataset_name == "textvqa"
        prefix = ""
        image_path = f"{self.image_dir_path}/{prefix}{str(question_id).zfill(12)}.pt"
        image = torch.load(image_path)
        return image

    def __getitem__(self, idx):
        question = self.questions[idx]
        img_path = self.get_img_path(question)
        if self.is_tensor:
            # Swap the extension to load a precomputed image tensor instead.
            image_path = img_path.replace(".jpg", ".pt")
            image = torch.load(image_path)
        else:
            image = Image.open(img_path)
            image.load()
        results = {
            "image": image,
            "question": question["question"],
            "question_id": question["question_id"],
        }
        if self.answers is not None:
            answers = [a["answer"] for a in self.answers[idx]["answers"]]
            if self.which_gt in ["all", None]:
                results["answers"] = answers
            elif isinstance(self.which_gt, (int, dict)):
                which_gt = (
                    self.which_gt[question["question_id"]]
                    if isinstance(self.which_gt, dict)
                    else self.which_gt
                )
                # Return the nth most common answer among the annotators, or an
                # empty list if there are fewer than n distinct answers.
                most_common = Counter(answers).most_common()
                if which_gt >= len(most_common):
                    results["answers"] = []
                else:
                    results["answers"] = [most_common[which_gt][0]]
            else:
                raise ValueError(f"Unknown which_gt: {self.which_gt}")
        return results
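
# Usage sketch for VQADataset. The paths are hypothetical placeholders; the
# question and annotation files are assumed to use the standard VQAv2 JSON
# layout ("questions" and "annotations" lists). Passing which_gt=1 would
# return the second most common annotator answer instead of all answers:
#
#   dataset = VQADataset(
#       image_dir_path="/data/coco/val2014",
#       question_path="/data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json",
#       annotations_path="/data/vqav2/v2_mscoco_val2014_annotations.json",
#       is_train=False,
#       dataset_name="vqav2",
#   )
#   sample = dataset[0]  # {"image", "question", "question_id", "answers"}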

class ImageNetDataset(ImageFolder):
    """Class to represent the ImageNet-1k dataset."""

    def __init__(self, root, **kwargs):
        super().__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        sample, target = super().__getitem__(idx)
        target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
        return {
            "id": idx,
            "image": sample,
            "class_id": target,  # numeric ID of the ImageNet class
            "class_name": target_label,  # human-readable name of the ImageNet class
        }
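
# Usage sketch for ImageNetDataset (hypothetical root path; expects the
# standard torchvision ImageFolder layout, one subdirectory per class):
#
#   dataset = ImageNetDataset(root="/data/imagenet/val")
#   sample = dataset[0]  # {"id", "image", "class_id", "class_name"}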

class HatefulMemesDataset(Dataset):
    def __init__(self, image_dir_path, annotations_path):
        self.image_dir_path = image_dir_path
        # Annotations are stored as JSON Lines, one example per line.
        with open(annotations_path, "r") as f:
            self.annotations = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        img_path = os.path.join(self.image_dir_path, annotation["img"].split("/")[-1])
        image = Image.open(img_path)
        image.load()
        return {
            "id": annotation["id"],
            "image": image,
            "ocr": annotation["text"],
            "class_name": "yes" if annotation["label"] == 1 else "no",
            "class_id": annotation["label"],
        }
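
# Usage sketch for HatefulMemesDataset (hypothetical paths; the annotation
# file is assumed to be the standard Hateful Memes JSONL, with "id", "img",
# "text", and "label" fields per line, matching what __getitem__ reads):
#
#   dataset = HatefulMemesDataset(
#       image_dir_path="/data/hateful_memes/img",
#       annotations_path="/data/hateful_memes/dev.jsonl",
#   )
#   sample = dataset[0]  # {"id", "image", "ocr", "class_name", "class_id"}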

class TensorCaptionDataset(CaptionDataset):
    """CaptionDataset variant that loads precomputed image tensors (.pt files)."""

    def get_from_id(self, image_id):
        assert self.dataset_name == "coco"
        assert not self.is_train
        # prefix = 'COCO_val2014_'
        prefix = ""
        image_path = f"{self.image_val_dir_path}/{prefix}{str(image_id).zfill(12)}.pt"
        image = torch.load(image_path)
        return image

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        if self.dataset_name == "coco":
            image_path = os.path.join(
                self.image_train_dir_path
                if annotation["filepath"] == "train2014"
                else self.image_val_dir_path,
                annotation["filename"],
            )
            image_path = image_path.replace(".jpg", ".pt")
            image = torch.load(image_path)
        elif self.dataset_name == "flickr":
            # Tensor loading is not implemented for flickr images.
            raise NotImplementedError
        else:
            raise Exception(f"Unknown caption dataset {self.dataset_name}")
        caption = annotation["sentences"][0]["raw"]
        return {
            "image": image,
            "caption": caption,
            "image_id": annotation["cocoid"]
            if self.dataset_name == "coco"
            else annotation["filename"].split(".")[0],
        }
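
# Usage sketch for TensorCaptionDataset (hypothetical paths; assumes each COCO
# image has a matching precomputed tensor stored under the same name with a
# .pt extension, e.g. produced earlier via torch.save(tensor, path)):
#
#   dataset = TensorCaptionDataset(
#       image_train_dir_path="/data/coco_tensors/train2014",
#       annotations_path="/data/coco/dataset_coco.json",
#       is_train=False,
#       dataset_name="coco",
#       image_val_dir_path="/data/coco_tensors/val2014",
#   )
#   image_tensor = dataset.get_from_id(391895)  # loads .../000000391895.pt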