# NOTE(review): removed non-code artifacts ("Spaces:", "Runtime error") left here by the paste/export tool.
import torch | |
from torch.utils.data import Dataset | |
from PIL import Image | |
from PIL.Image import Image as img | |
from PIL.Image import DecompressionBombError | |
from PIL import UnidentifiedImageError | |
import json | |
from pathlib import Path | |
from tqdm import tqdm | |
from typing import List, Tuple, Generator | |
import random | |
from multiprocessing import Pool, cpu_count | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from typing import Tuple | |
from torchtyping import TensorType | |
import traceback | |
def read_jsonl(filename: str) -> Generator[List, None, None]:
    """
    Yield one parsed JSON object per line of a jsonl file.
    """
    with open(filename) as handle:
        yield from (json.loads(raw.rstrip("\n|\r")) for raw in handle)
def read_img_captions(filename: str) -> List[Tuple[str, str]]:
    """
    Collect (image_path, image_caption) pairs from a cc jsonl file.

    Records whose last two fields contain the placeholder "N/A" are skipped.
    """
    return [
        (record[-1], record[-2])
        for record in read_jsonl(filename)
        if "N/A" not in record[-2:]
    ]
def load_json(filename):
    """
    Parse a JSON file, returning None (and printing a traceback) on failure.

    The broad except is deliberate: this runs inside a multiprocessing pool
    (see _read_image_data) where one corrupt file must not abort the whole
    dataset load; callers check for a None return.
    """
    try:
        with open(filename) as f:
            return json.load(f)
    except Exception:
        # Bug fix: the original f-string had no placeholder and always
        # printed the literal "(unknown)" instead of the offending path.
        print(f"ERROR: Error loading json file {filename}")
        traceback.print_exc()
def _read_image_data(data_dir):
    """
    Load every per-image json file under data_dir/image_data into a list.

    Files are parsed in parallel with a process pool; entries for which
    load_json failed (returned None) are dropped.

    Args:
        data_dir: dataset root Path containing an `image_data` subdirectory.
    Returns:
        List of parsed json objects.
    """
    # (removed an unused `img_data_dir` local that the original computed
    # and never read — _load_paths derives the same path itself)
    image_data = []
    paths = _load_paths(data_dir)
    pbar = tqdm(
        paths,
        desc=f"loading dataset from {str(data_dir)}",
    )
    # read data with multiprocessing; load_json returns None for bad files
    with Pool(cpu_count()) as pool:
        for img_data in pool.imap(load_json, pbar):
            if img_data is not None:
                image_data.append(img_data)
    return image_data
def _load_paths(data_dir, sort=True):
    """
    Collect the paths of all image_data/*/*.json files under data_dir.

    Args:
        data_dir: dataset root Path (must support the / operator).
        sort: when True (the default), return the paths sorted.
            Bug fix: the original accepted this flag but ignored it and
            always sorted; passing sort=False now skips the sort.
    Returns:
        List of Path objects, sorted unless sort is False.
    """
    paths = []
    img_data_dir = data_dir / "image_data"
    for p in tqdm(
        Path(img_data_dir).glob("*/*.json"),
        desc=f"loading dataset paths from {str(data_dir)}",
    ):
        paths.append(p)
    return sorted(paths) if sort else paths
class LazyLoader:
    """
    Sequence-like view over the per-image json files of a dataset directory.

    Paths are gathered eagerly, but each file is parsed only when indexed.
    If a file fails to parse (load_json returns None), a uniformly random
    other entry is returned instead, so indexing always yields data.
    """

    def __init__(self, data_dir):
        self.paths = _load_paths(data_dir)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        entry = load_json(self.paths[idx])
        # retry with random indices instead of failing on corrupt files
        while entry is None:
            entry = load_json(self.paths[random.randint(0, len(self) - 1)])
        return entry
class ImgCptDataset(Dataset):
    """
    Dataset which loads image caption data from our standard format and transforms them into tensors that can be input to the model.
    Images are expected to be stored in data_dir/images, image data in data_dir/image_data and each data item is a json file with format {"image_path": img_path, "captions": [caption1, caption2,...], "metadata":{...}}
    """
    def __init__(
        self, data_dir, tokenizer, transforms, seq_len=2048, load_data_in_memory=False
    ):
        """
        Args:
            data_dir: dataset root containing `images/` and `image_data/`.
            tokenizer: object exposing .encode(text, return_tensors="pt",
                max_length=..., padding=..., truncation=...) — presumably a
                HuggingFace-style tokenizer, judging by the call in __getitem__.
            transforms: callable mapping a PIL image to a tensor.
            seq_len: caption length in tokens; captions are padded/truncated
                to exactly this length by the tokenizer call.
            load_data_in_memory: if True, parse all json metadata up front
                (via a process pool); otherwise parse lazily per access.
        """
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer
        self.transforms = transforms
        self.seq_len = seq_len
        self.load_data_in_memory = load_data_in_memory
        if self.load_data_in_memory:
            # eager: a plain list of parsed json dicts (bad files dropped)
            self.data = _read_image_data(self.data_dir)
        else:
            # lazy: json parsed on each __getitem__ via LazyLoader
            self.data = LazyLoader(self.data_dir)
    def __len__(self):
        return len(self.data)
    def __getitem__(
        self, idx
    ) -> Tuple[TensorType["b", "c", "h", "w"], TensorType["b", "s"]]:
        """
        Return (image_tensor, caption_token_ids) for item `idx`.

        A random caption is chosen from the item's "captions" list on every
        access. If the image is corrupt or the captions list is empty, a
        uniformly random other item is returned instead.
        """
        img_data = self.data[idx]
        try:
            try:
                img_path = self.data_dir / img_data["image_path"]
            except KeyError as e:
                # if no image path is found, assume path is same as .json, but .jpg
                if not self.load_data_in_memory:
                    p = self.data.paths[idx]
                    img_path = (
                        self.data_dir
                        / "images"
                        / Path(p.parent).name
                        / Path(p.name).with_suffix(".jpg")
                    )
                else:
                    # in-memory data keeps no source json path to fall back on.
                    # NOTE(review): KeyError is not in the except tuple below,
                    # so this propagates to the caller rather than retrying.
                    raise e
            img = Image.open(img_path)
            img_tensor = self.transforms(img)
            # IndexError from an empty "captions" list is caught below
            caption = random.choice(img_data["captions"])
            caption_tensor = self.tokenizer.encode(
                caption,
                return_tensors="pt",
                max_length=self.seq_len,
                padding="max_length",
                truncation=True,
            )
            return img_tensor, caption_tensor
        except (
            UnidentifiedImageError,
            OSError,
            DecompressionBombError,
            IndexError,
        ) as e:
            # return random index if image is corrupt
            print(f"Warning: Could not load image {str(img_path)}")
            return self[random.randint(0, len(self) - 1)]
def collate_fn(batch_data: List[Tuple[torch.Tensor, torch.Tensor]], seq_len=2048):
    """
    Collate (image, caption) pairs into two batched tensors.

    Each caption tensor is clipped to at most seq_len tokens along its last
    dimension, then images and captions are concatenated along dim 0.
    """
    images = [pair[0] for pair in batch_data]
    clipped_captions = [pair[1][:, :seq_len] for pair in batch_data]
    return torch.cat(images), torch.cat(clipped_captions)