Movie / test_tars.py
Mudrock's picture
Upload 18 files
4c94b0e
raw
history blame contribute delete
No virus
3.29 kB
import webdataset as wds
import soundfile as sf
import io
import os
import random
import copy
from tqdm import tqdm
import shutil
import argparse
import traceback
import logging
import json
from open_clip import tokenize
def parse_args(argv=None):
    """Parse the command-line options of the tar-checking script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which
            is what the original zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``tar_path``, ``start``,
        ``end``, ``exclude``, ``batch_size`` and ``order``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tar-path",
        type=str,
        default=None,
        help="Path to the tars",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="start from tar-path + start",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=99999,
        help="end with tar-path + end",
    )
    parser.add_argument(
        "--exclude",
        nargs='+',
        # Shard indices are integers (they are compared against
        # range(start, end)); without type=int argparse yields strings
        # that can never match an index.
        type=int,
        default=None,
        help="exclude tar-path + exclude",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--order",
        default=False,
        action='store_true',
        help="if keep the search order accendingly",
    )
    args = parser.parse_args(argv)
    return args
def log_and_continue(exn):
    """Error handler that reports the exception as a warning and carries on.

    Intended to be passed as a ``handler`` so that a failing item is
    ignored rather than aborting the run.

    Args:
        exn: The exception that was raised.

    Returns:
        Always ``True``.
    """
    message = f"Handling webdataset error ({repr(exn)}). Ignoring."
    logging.warning(message)
    return True
def preprocess(
    sample,
):
    """
    Preprocess a single sample for wdsdataloader.

    Reads the bytes stored under the "flac" key with soundfile, parses the
    bytes stored under the "json" key as UTF-8 JSON, and adds three entries
    to the sample:

    * "waveform": the decoded audio data,
    * "raw_text": the value of the JSON "text" field (one randomly chosen
      entry when that field is a list of several strings),
    * "text": the result of ``tokenize`` applied to "raw_text".

    The sample dict is modified in place and also returned.
    """
    audio_ext = "flac"
    text_ext = "json"
    # sf.read also returns the sample rate; it is not used here.
    audio_data, _ = sf.read(io.BytesIO(sample[audio_ext]))
    json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
    sample["waveform"] = audio_data
    texts = json_dict_raw["text"]
    # Check the length before indexing so an empty caption list does not
    # raise IndexError on texts[0].
    if isinstance(texts, list) and len(texts) > 1 and isinstance(texts[0], str):
        texts = random.choice(texts)
    sample["raw_text"] = texts
    sample["text"] = tokenize(texts)
    return sample
def build_input_shards(tar_path, start, end, exclude=None, keep_order=False):
    """Build the list of shard locations ``<tar_path>/<index>.tar`` to read.

    Args:
        tar_path: Directory (or command prefix) the shard names are joined to.
        start: First shard index (inclusive).
        end: Last shard index (exclusive).
        exclude: Optional iterable of shard indices to skip. Entries may be
            ints or numeric strings; indices outside ``[start, end)`` are
            silently ignored.
        keep_order: When False the indices are shuffled with ``random``;
            when True they stay in ascending order.

    Returns:
        A list of shard strings. When ``"aws"`` occurs in ``tar_path`` each
        entry ends with ``".tar -"`` instead of ``".tar"``.

    Raises:
        ValueError: If an ``exclude`` entry cannot be converted to int.
    """
    # Normalise to ints: values coming from argparse may be strings, which
    # would never compare equal to the integer indices.
    excluded = {int(x) for x in exclude} if exclude else set()
    idx_list = [i for i in range(start, end) if i not in excluded]
    if not keep_order:
        random.shuffle(idx_list)
    # The original code read an undefined ``args.local`` for non-aws paths;
    # locality is derived purely from the path here.
    # NOTE(review): the trailing " -" presumably completes a piped
    # "aws s3 cp ... -" command - confirm against how tar_path is passed.
    suffix = ".tar -" if "aws" in tar_path else ".tar"
    return [os.path.join(tar_path, str(i) + suffix) for i in idx_list]


if __name__ == "__main__":
    args = parse_args()
    if args.tar_path is None:
        # Fail with a clear message instead of a TypeError further down.
        raise SystemExit("--tar-path is required")

    input_shards = build_input_shards(
        args.tar_path,
        args.start,
        args.end,
        exclude=args.exclude,
        keep_order=args.order,
    )

    pipeline = [wds.SimpleShardList(input_shards)]
    pipeline.extend(
        [
            wds.split_by_node,
            wds.split_by_worker,
            # Errors while reading a tar are logged and skipped.
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.map(preprocess),
            wds.to_tuple("__url__", "__key__", "waveform"),
            wds.batched(1),
        ]
    )
    dataset = wds.DataPipeline(*pipeline)
    dataloader = wds.WebLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Remember the last batch that was read successfully so that a failure
    # can be located relative to it.
    old_k = 0
    old_batch = None
    try:
        for k, batch in tqdm(enumerate(dataloader)):
            print("k:", k)
            print("batch:", batch)
            old_k = k
            old_batch = copy.deepcopy(batch)
    except Exception:
        # Append the traceback to the log file, then show the last good batch.
        # KeyboardInterrupt / SystemExit are deliberately not caught.
        with open("check_tar_log.txt", "a") as file:
            traceback.print_exc(file=file)
        print("old_k:", old_k)
        print("old_batch:", old_batch)