KingNish's picture
Upload 110 files
e6af450 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import random
from PIL import Image, ImageFile, PngImagePlugin
from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
from ..data_utils import pil_img2rgb
# Pillow's decompression-bomb guard: allow images of up to 200M pixels
# (above Pillow's default) so very large training images are not rejected.
Image.MAX_IMAGE_PIXELS = 200_000_000
# Decode images whose byte stream is cut short instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Pillow's cap on decompressed PNG text-chunk size, expressed here as
# MaximumDecompressedSize megabytes (1024 MiB).
MaximumDecompressedSize = 1024
MegaByte = 1 << 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
    """Iterable dataset that turns multi-step image-edit rows into samples.

    Each row is read through ``row["image_list"]`` (encoded image bytes) and
    ``row["instruction_list"]``, where entry ``k`` holds the candidate
    instructions for the edit from image ``k`` to image ``k + 1``; one
    candidate is picked per step with ``random.choice``.
    """

    def parse_row(self, row):
        """Build one interleaved text/image sample from a random window of ``row``.

        A start image is chosen uniformly among all images except the last,
        and the window covers one or two edit steps after it (two only when
        enough images remain). The start image is added as conditioning
        (no loss, with VAE and ViT inputs). Then either:

        * for two-step windows, with probability 0.5, the step instructions
          are merged into one text (each terminated by ``"."`` and separated
          by a space) followed by the window's final image as the loss
          target; or
        * each step's instruction is added followed by its resulting image.
          Every resulting image carries loss; intermediate images also get
          VAE and ViT inputs, the final image does not.

        Args:
            row: Record indexable by ``"image_list"`` (sequence of encoded
                image bytes) and ``"instruction_list"`` (per-step sequences
                of candidate instructions). Its concrete type is decided by
                the parent dataset class -- confirm against the caller.

        Returns:
            The sample accumulated via ``self._init_data``, ``self._add_text``
            and ``self._add_image``.

        Raises:
            IndexError: If ``row["image_list"]`` has fewer than two images
                (``random.choice`` is then given an empty range).
        """
        image_list = row["image_list"]
        image_num = len(image_list)

        def load_image(idx):
            # Decode the raw bytes of image ``idx`` and normalise them to RGB.
            return pil_img2rgb(Image.open(io.BytesIO(image_list[idx])))

        # Randomly choose start and end; yields [0, 1] when only two images.
        start_idx = random.choice(range(image_num - 1))
        max_end = min(start_idx + 3, image_num)
        end_idx = random.choice(range(start_idx + 1, max_end))

        data = self._init_data()
        data = self._add_image(
            data,
            load_image(start_idx),
            need_loss=False,
            need_vae=True,
            need_vit=True,
        )

        if end_idx - start_idx > 1 and random.random() < 0.5:
            # Concatenate multiple instructions into a single edit step.
            # NOTE(review): when the window would end on the last image it is
            # shortened back to a single step (only one instruction is used).
            # Kept unchanged from the original -- confirm this is intended.
            if end_idx == image_num - 1:
                end_idx -= 1
            instruction = " ".join(
                random.choice(row["instruction_list"][idx - 1]) + "."
                for idx in range(start_idx + 1, end_idx + 1)
            )
            data = self._add_text(data, instruction, need_loss=False)
            data = self._add_image(
                data,
                load_image(end_idx),
                need_loss=True,
                need_vae=False,
                need_vit=False,
            )
        else:
            for idx in range(start_idx + 1, end_idx + 1):
                instruction = random.choice(row["instruction_list"][idx - 1])
                data = self._add_text(data, instruction, need_loss=False)
                # Intermediate results also get VAE/ViT inputs (presumably so
                # later steps can condition on them); the final target only
                # contributes the loss.
                is_last = idx == end_idx
                data = self._add_image(
                    data,
                    load_image(idx),
                    need_loss=True,
                    need_vae=not is_last,
                    need_vit=not is_last,
                )
        return data