Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io | |
import random | |
from PIL import Image, ImageFile, PngImagePlugin | |
from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset | |
from ..data_utils import pil_img2rgb | |
# Raise Pillow's pixel-count limit (its decompression-bomb guard) to 200M
# pixels so very large images can still be opened.
Image.MAX_IMAGE_PIXELS = 200000000
# Ask Pillow to load image files whose data is truncated instead of failing.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Raise Pillow's size limit for PNG text chunks to 1024 MiB.
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
    """Dataset whose rows hold an ordered image list plus per-step instructions.

    ``parse_row`` converts one such row into a single interleaved
    image/text sample. ``_init_data``, ``_add_image`` and ``_add_text`` are
    not defined here; they are expected to come from the base classes.
    """

    def parse_row(self, row):
        """Build one interleaved editing sample from ``row``.

        A random window of two or three consecutive images is chosen. The
        first image of the window is always added with ``need_loss=False``,
        ``need_vae=True`` and ``need_vit=True``. What follows depends on the
        window:

        * If the window spans two steps, then with probability 0.5 the
          instructions are merged into one text ("a. b.") followed only by
          the target image (``need_loss=True``, ``need_vae=False``,
          ``need_vit=False``).
        * Otherwise each step contributes its instruction text followed by
          its image; every image has ``need_loss=True``, and only the last
          one is added with ``need_vae=False`` / ``need_vit=False``.

        Args:
            row: Mapping with ``"image_list"`` (encoded image bytes, in
                order) and ``"instruction_list"``, where entry ``i`` is a
                sequence of candidate instructions for the step from image
                ``i`` to image ``i + 1`` (presumably alternative phrasings;
                verify against the data schema).

        Returns:
            The object created by ``self._init_data()`` after all
            ``_add_image`` / ``_add_text`` calls have been applied.

        Raises:
            IndexError: If ``row["image_list"]`` has fewer than two images.
        """
        image_num = len(row["image_list"])
        if image_num < 2:
            # random.choice on the empty range below would raise an opaque
            # IndexError; keep the exception type but say what is wrong.
            raise IndexError(
                f"parse_row needs at least two images in row['image_list'], got {image_num}"
            )

        def load_rgb(index):
            # Decode the stored bytes at `index` and pass them to pil_img2rgb.
            return pil_img2rgb(Image.open(io.BytesIO(row["image_list"][index])))

        # Random start among all but the last image; the end is one or two
        # images later, capped at the last image ([0, 1] for two images).
        start_idx = random.choice(range(image_num - 1))
        max_end = min(start_idx + 3, image_num)
        end_idx = random.choice(range(start_idx + 1, max_end))

        data = self._init_data()
        data = self._add_image(
            data,
            load_rgb(start_idx),
            need_loss=False,
            need_vae=True,
            need_vit=True,
        )

        if end_idx - start_idx > 1 and random.random() < 0.5:
            # Merge the instructions of the window into a single prompt.
            # NOTE(review): when the window ends on the row's last image the
            # target is moved one image earlier; the reason is not visible
            # here - confirm this is intentional.
            if end_idx == image_num - 1:
                end_idx -= 1

            picked = [
                random.choice(row["instruction_list"][idx - 1])
                for idx in range(start_idx + 1, end_idx + 1)
            ]
            # Same text as appending ". " after each piece and stripping the
            # trailing space; `picked` always has at least one element here.
            instruction = ". ".join(picked) + "."
            data = self._add_text(data, instruction, need_loss=False)
            data = self._add_image(
                data,
                load_rgb(end_idx),
                need_loss=True,
                need_vae=False,
                need_vit=False,
            )
        else:
            # Step-by-step: instruction text, then the image for that step.
            for idx in range(start_idx + 1, end_idx + 1):
                instruction = random.choice(row["instruction_list"][idx - 1])
                data = self._add_text(data, instruction, need_loss=False)
                # Only images followed by a further step get VAE/ViT flags.
                has_next_step = idx != end_idx
                data = self._add_image(
                    data,
                    load_rgb(idx),
                    need_loss=True,
                    need_vae=has_next_step,
                    need_vit=has_next_step,
                )

        return data