|
import os |
|
import logging |
|
import json |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
import torch |
|
import lmdb |
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode |
|
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler |
|
from torch.utils.data.distributed import DistributedSampler |
|
from torch.utils.data.sampler import SequentialSampler |
|
import torchvision.datasets as datasets |
|
from clip import tokenize |
|
|
|
|
|
def _convert_to_rgb(image): |
|
return image.convert('RGB') |
|
|
|
|
|
def _preprocess_text(text): |
|
|
|
text = text.lower().replace("“", "\"").replace("”", "\"") |
|
return text |
|
|
|
|
|
class EvalTxtDataset(Dataset):
    """Text-side evaluation dataset.

    Reads a jsonl file where each line is an object with 'text_id' and 'text'
    fields, and yields (text_id, tokenized_text) pairs.
    """

    def __init__(self, jsonl_filename, max_txt_length=24):
        assert os.path.exists(jsonl_filename), "The annotation datafile {} not exists!".format(jsonl_filename)

        logging.debug(f'Loading jsonl data from {jsonl_filename}.')
        with open(jsonl_filename, "r", encoding="utf-8") as fin:
            records = [json.loads(raw.strip()) for raw in fin]
        # Keep (text_id, text) pairs; tokenization is deferred to __getitem__.
        self.texts = [(rec['text_id'], rec['text']) for rec in records]
        logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')

        self.max_txt_length = max_txt_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text_id, raw_text = self.texts[idx]
        # Normalize, then tokenize to a fixed context length.
        tokens = tokenize([_preprocess_text(str(raw_text))], context_length=self.max_txt_length)[0]
        return text_id, tokens
|
|
|
|
|
class EvalImgDataset(Dataset):
    """Image-side evaluation dataset backed by an LMDB store.

    The LMDB maps image-id keys to URL-safe base64-encoded image bytes, plus
    a reserved b'num_images' entry holding the total image count.
    """

    def __init__(self, lmdb_imgs, resolution=224):
        # lmdb_imgs: path to an LMDB *directory*; resolution: square side
        # length (pixels) images are resized to before normalization.
        assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} not exists!".format(lmdb_imgs)

        logging.debug(f'Loading image LMDB from {lmdb_imgs}.')

        # Read-only, lock-free environment; begin(buffers=True) makes reads
        # return memoryviews instead of bytes (hence the .tobytes() calls in
        # __getitem__).
        self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn_imgs = self.env_imgs.begin(buffers=True)
        self.cursor_imgs = self.txn_imgs.cursor()
        # Persistent iterator shared across __getitem__ calls: records are
        # consumed in LMDB key order, independent of the requested index.
        self.iter_imgs = iter(self.cursor_imgs)
        # Total count is stored under the reserved b'num_images' key.
        self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
        logging.info("The specified LMDB directory contains {} images.".format(self.number_images))

        self.transform = self._build_transform(resolution)

    def _build_transform(self, resolution):
        # CLIP's standard per-channel RGB normalization constants (mean, std).
        normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        return Compose([
            Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])

    def __len__(self):
        return self.number_images

    def __getitem__(self, idx):
        # NOTE(review): `idx` is ignored — items come from the shared cursor
        # in LMDB key order.  This is only correct with the sequential,
        # single-worker DataLoader built in get_eval_img_dataset below;
        # shuffling or num_workers > 0 would yield wrong/duplicated items.
        img_id, image_b64 = next(self.iter_imgs)
        if img_id == b"num_images":
            # Skip the metadata record when the cursor lands on it.
            img_id, image_b64 = next(self.iter_imgs)

        # memoryview -> bytes (because the transaction uses buffers=True).
        img_id = img_id.tobytes()
        image_b64 = image_b64.tobytes()

        img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
        image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
        # Stored images are URL-safe base64-encoded.
        image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64)))
        image = self.transform(image)

        return img_id, image
|
|
|
|
|
@dataclass
class DataInfo:
    """Bundle of a dataloader plus its sampler, returned by the get_*_dataset helpers."""
    dataloader: DataLoader
    # NOTE(review): annotated as DistributedSampler, but the helpers in this
    # file actually pass a SequentialSampler or None — the annotation is not
    # enforced at runtime; confirm intended type with callers.
    sampler: DistributedSampler
|
|
|
|
|
def get_eval_txt_dataset(args, max_txt_length=24):
    """Build a sequential, single-process DataLoader over the evaluation texts.

    Reads the jsonl file at args.text_data and batches by args.text_batch_size.
    Attaches num_samples / num_batches attributes to the returned dataloader.
    """
    dataset = EvalTxtDataset(args.text_data, max_txt_length=max_txt_length)
    sampler = SequentialSampler(dataset)

    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.text_batch_size,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    dataloader.num_samples = len(dataset)
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
|
|
|
|
|
def fetch_resolution(vision_model):
    """Read the input image resolution from the vision model's JSON config.

    Config files live at <repo>/clip/model_configs/<model-name>.json, with '/'
    in the model name replaced by '-'.
    """
    config_name = "{}.json".format(vision_model.replace('/', '-'))
    config_path = Path(__file__).parent.parent / "clip/model_configs" / config_name
    with open(config_path, 'r') as fv:
        model_info = json.load(fv)
    return model_info["image_resolution"]
|
|
|
|
|
def get_eval_img_dataset(args):
    """Build a sequential, single-process DataLoader over the evaluation images.

    Reads the LMDB at args.image_data, resizing images to the resolution taken
    from args.vision_model's config.  Attaches num_samples / num_batches
    attributes to the returned dataloader.
    """
    dataset = EvalImgDataset(args.image_data, resolution=fetch_resolution(args.vision_model))
    sampler = SequentialSampler(dataset)

    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.img_batch_size,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    dataloader.num_samples = len(dataset)
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
|
|
|
|
|
def get_zeroshot_dataset(args, preprocess_fn):
    """Build a DataLoader over an ImageFolder for zero-shot classification.

    Args:
        args: namespace providing datapath, img_batch_size and num_workers.
        preprocess_fn: transform applied to each loaded image.

    Returns:
        DataInfo holding the dataloader and a None sampler (default iteration
        order).
    """
    dataset = datasets.ImageFolder(args.datapath, transform=preprocess_fn)

    # Use the directly-imported DataLoader and pin host memory, consistent
    # with the other get_*_dataset helpers in this module.
    dataloader = DataLoader(
        dataset,
        batch_size=args.img_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=None,
    )
    # Expose the same bookkeeping attributes the other helpers attach.
    dataloader.num_samples = len(dataset)
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, None)