"""Checks the shard count reported by ``get_dataset_size`` for several shard specs."""

import pytest

from training.data import get_dataset_size

# Each entry pairs a shard specification with the shard count the test expects
# as the second element of the value returned by ``get_dataset_size``.
_SHARD_COUNT_CASES = [
    # A single literal path.
    ('/path/to/shard.tar', 1),
    # A brace range covering exactly one value.
    ('/path/to/shard_{000..000}.tar', 1),
    # A brace range covering ten values.
    ('/path/to/shard_{000..009}.tar', 10),
    # Two brace ranges in one path: expected 10 x 10 combinations.
    ('/path/to/shard_{000..009}_{000..009}.tar', 100),
    # Specs joined by '::': expected counts add up (1 + 10, then 10 + 10).
    ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11),
    ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20),
    # Explicit lists of paths: expected one shard per list entry.
    (['/path/to/shard.tar'], 1),
    (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2),
]


@pytest.mark.parametrize("shards,expected_size", _SHARD_COUNT_CASES)
def test_num_shards(shards, expected_size):
    """Assert that ``get_dataset_size`` reports ``expected_size`` shards.

    Args:
        shards: Shard specification handed to ``get_dataset_size``; either a
            single string or a list of path strings.
        expected_size: Shard count the second element of the result must equal.
    """
    # Only the second element of the returned pair is checked here.
    _, size = get_dataset_size(shards)
    assert size == expected_size, (
        f'Expected {expected_size} for {shards} but found {size} instead.'
    )