import os
import pytest
import util_test
import collections
import tarfile
import io
from PIL import Image

from training.data import get_wds_dataset
from training.params import parse_args
from training.main import random_seed

TRAIN_NUM_SAMPLES = 10_000
RTOL = 0.2

# NOTE: we use two test tar files, which are created on the fly and saved to data/input.
# 000.tar has 10 samples, and the captions are 000_0, 000_1, ..., 000_9
# 001.tar has 5 samples, and the captions are 001_0, 001_1, ..., 001_4
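# All tests rely on --dataset-resampled: shards are drawn with replacement, and each
# draw yields every sample in the shard, so per-caption counts should concentrate
# around the frequencies implied by the shard-selection probabilities (within RTOL).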

def build_inputs(test_name):
    base_input_dir, _ = util_test.get_data_dirs()
    input_dir = os.path.join(base_input_dir, test_name)
    os.makedirs(input_dir, exist_ok=True)

    def save_tar(idx, num_samples):
        filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar')
        tar = tarfile.open(filename, 'w')
        for sample_idx in range(num_samples):
            # Image
            image = Image.new('RGB', (32, 32))
            info = tarfile.TarInfo(f'{sample_idx}.png')
            bio = io.BytesIO()
            image.save(bio, format='png')
            size = bio.tell()
            bio.seek(0)
            info.size = size
            tar.addfile(info, bio)

            # Caption
            info = tarfile.TarInfo(f'{sample_idx}.txt')
            bio = io.BytesIO()
            bio.write(f'{idx:03d}_{sample_idx}'.encode('utf-8'))
            size = bio.tell()
            bio.seek(0)
            info.size = size
            tar.addfile(info, bio)
        tar.close()

    save_tar(0, 10)
    save_tar(1, 5)

    return input_dir
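
# Example (hypothetical name): build_inputs('my_test') writes the two shards under
# data/input/my_test and returns that directory, ready to be used as --train-data.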

def build_params(input_shards, seed=0):
    args = parse_args([])
    args.train_data = input_shards
    args.train_num_samples = TRAIN_NUM_SAMPLES
    args.dataset_resampled = True
    args.seed = seed
    args.workers = 1
    args.world_size = 1
    args.batch_size = 1
    random_seed(seed)

    preprocess_img = lambda x: x
    tokenizer = lambda x: [x.strip()]

    return args, preprocess_img, tokenizer
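
# NOTE: preprocess_img and tokenizer above are trivial stand-ins (the tokenizer just
# strips and wraps the caption), so captions reach the dataloader as plain strings
# that the tests can count directly.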

def get_dataloader(input_shards):
    args, preprocess_img, tokenizer = build_params(input_shards)
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader
    return dataloader

def test_single_source():
    """Test webdataset with a single tar file."""
    input_dir = build_inputs('single_source')
    input_shards = os.path.join(input_dir, 'test_data_000.tar')
    dataloader = get_dataloader(input_shards)
    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1
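    # 10 unique captions sampled uniformly: each should appear about
    # TRAIN_NUM_SAMPLES / 10 = 1,000 times, within the relative tolerance RTOL.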
    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL)

def test_two_sources():
    """Test webdataset with two tar files."""
    input_dir = build_inputs('two_sources')
    input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar')
    dataloader = get_dataloader(input_shards)
    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1
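    # Shards are drawn with equal probability, and each draw yields the whole shard,
    # so all 15 distinct samples are equally likely: ~TRAIN_NUM_SAMPLES / 15 each.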
    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}'

def test_two_sources_same_weights():
    """Test webdataset with two tar files, using --train-data-upsampling-factors=1::1."""
    input_dir = build_inputs('two_sources_same_weights')
    input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
    args, preprocess_img, tokenizer = build_params(input_shards)
    args.train_data_upsampling_factors = '1::1'
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader
    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1
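    # Equal upsampling factors reproduce the unweighted case (uniform shard choice),
    # so each of the 15 samples is still expected ~TRAIN_NUM_SAMPLES / 15 times.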
    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}'

def test_two_sources_with_upsampling():
    """Test webdataset with two tar files, upsampling the second via --train-data-upsampling-factors=1::2."""
    input_dir = build_inputs('two_sources_with_upsampling')
    input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
    args, preprocess_img, tokenizer = build_params(input_shards)
    args.train_data_upsampling_factors = '1::2'
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader
    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1
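    # 001 is drawn twice as often as 000, but 000 holds twice as many samples per
    # draw, so the stream splits ~50/50 between the shards: each of the 10 captions
    # from 000 lands near TRAIN_NUM_SAMPLES / 20, each of the 5 from 001 near
    # TRAIN_NUM_SAMPLES / 10.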
    for key, count in counts.items():
        if key.startswith('000'):
            assert count == pytest.approx(TRAIN_NUM_SAMPLES / 20, RTOL), f'{key}, {count}'
        else:
            assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}'