# multimodal/open_flamingo/tools/make_vqav2_ft_dataset.py
# Source: commit 5282eae ("init") by Li; 982 bytes.
import webdataset as wds
import os
from tqdm import tqdm
from PIL import Image
from io import BytesIO
import base64
# Destination directory for the webdataset shards (000000.tar, 000001.tar, ...).
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_train_wds"
# Expected number of rows in the input TSV; only used to size the progress bar.
TOTAL = 1828467
# Input TSV: one sample per line, tab-separated, image stored as URL-safe base64.
TSV_PATH = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_ofa/vqa_train.tsv"
# Maximum number of samples written into each .tar shard.
SHARD_MAXCOUNT = 10000
# Columns 0-6 are read from every row, so at least 7 are required.
MIN_FIELDS = 7


def parse_row(line):
    """Parse one TSV row into a webdataset key, base64 image string and caption.

    Columns read: 0 and 1 (ids), 2 (question), 3 (answer field),
    5 (URL-safe base64 image), 6 (third id). Column 4 is ignored.

    Args:
        line: A raw line from the TSV, with or without its trailing newline.

    Returns:
        ``(key, image_b64, caption)`` where ``key`` is
        ``"vqav2_<col0>_<col1>_<col6>"`` and ``caption`` is
        ``"Question: <question> Answer: <answer>"`` (both parts stripped).

    Raises:
        ValueError: If the row has fewer than ``MIN_FIELDS`` columns.
    """
    # rstrip() drops the newline (and any other trailing whitespace).
    fields = line.rstrip().split("\t")
    if len(fields) < MIN_FIELDS:
        raise ValueError(
            f"expected at least {MIN_FIELDS} tab-separated fields, "
            f"got {len(fields)}: {line[:80]!r}"
        )
    id1, id2, question, answer_field = fields[:4]
    image_b64, id3 = fields[5], fields[6]
    # Keep only the text after the last "|!+" separator (presumably the
    # OFA-style "<confidence>|!+<answer>" layout). NOTE(review): if some rows
    # join several answers in one field, only the last one survives here —
    # confirm against the TSV.
    answer = answer_field.split("|!+")[-1]
    caption = f"Question: {question.strip()} Answer: {answer.strip()}"
    return f"vqav2_{id1}_{id2}_{id3}", image_b64, caption


def decode_image(image_b64):
    """Decode a URL-safe base64 image string into an RGB PIL image."""
    return Image.open(BytesIO(base64.urlsafe_b64decode(image_b64))).convert("RGB")


def main(tsv_path=TSV_PATH, out_dir=OUT_DIR, total=TOTAL):
    """Convert the TSV at *tsv_path* into webdataset shards inside *out_dir*.

    Each written sample has ``__key__``, ``jpg`` (RGB PIL image) and ``txt``
    (question/answer caption) entries. *total* only sizes the progress bar.
    """
    # Create the output directory up front so the first shard can be written.
    os.makedirs(out_dir, exist_ok=True)
    # Open the TSV before the ShardWriter so a missing input fails before any
    # shard file is opened; both are closed when the block exits.
    with open(tsv_path, encoding="utf-8") as tsv, wds.ShardWriter(
        os.path.join(out_dir, "%06d.tar"), maxcount=SHARD_MAXCOUNT
    ) as sink:
        sink.verbose = False
        for line in tqdm(tsv, total=total):
            key, image_b64, caption = parse_row(line)
            sink.write({"__key__": key, "jpg": decode_image(image_b64), "txt": caption})


if __name__ == "__main__":
    main()