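"""Convert parquet shards of precomputed prompt embeddings and image latents
into tar archives of per-sample pickle files, and write a wdinfo.pkl index
of the finished tars.
"""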
import multiprocessing
import os
import pickle
import shutil
import tarfile
import time
from argparse import ArgumentParser
from glob import glob
from multiprocessing import Pool
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

# Expected float32 shapes of the raw byte columns stored in each parquet row;
# each column is deserialized with np.frombuffer and reshaped accordingly.
SHAPES = {
    'prompt_embeds': (77, 2048),
    'pooled_prompt_embeds': (1280,),
    'latents_256': (4, 32, 32),
}


def convert_single_parquet_to_tar(parquet_file):
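    """Convert one parquet shard into a tar archive of per-sample pickles.

    Each row's byte columns are decoded into float32 arrays (see SHAPES) and
    pickled as '<i>.pickle' inside '<output_folder>/<shard stem>.tar'. Relies
    on the module-level ``args``, which worker processes inherit when the
    pool uses the fork start method (the default on Linux).
    """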
    pf = pd.read_parquet(parquet_file)
    # Stage per-sample pickles in a temporary folder next to the shard.
    tmp_folder = Path(os.path.splitext(parquet_file)[0] + '-tmp-pickle-files')
    os.makedirs(tmp_folder, exist_ok=True)
    tar_file = Path(args.output_folder) / (Path(parquet_file).stem + '.tar')
    with tarfile.open(tar_file, 'w') as f:
        for i in range(len(pf.index)):
            data = pf.iloc[i]
            info = dict()
            for key, shape in SHAPES.items():
                info[key] = np.frombuffer(data[key], dtype=np.float32).reshape(shape)
            tmp_pickle_filename = f'{i}.pickle'
            # Close (and flush) each pickle before adding it to the tar.
            with open(tmp_folder / tmp_pickle_filename, 'wb') as pickle_file:
                pickle.dump(info, pickle_file)
            f.add(tmp_folder / tmp_pickle_filename, arcname=tmp_pickle_filename)
    shutil.rmtree(tmp_folder)


def generate_wdinfo(tar_folder: str, chunk_size: int, output_path: Optional[str]):
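    """Write a pickle at output_path indexing all complete tar shards.

    Assumes every tar holds exactly chunk_size samples, so total_key_count is
    only exact when all shards are full.
    """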
    if not output_path:
        return
    tar_files = []
    for fname in glob(os.path.join(tar_folder, '*.tar')):
        # Keep only non-empty tars that carry no .INCOMPLETE marker.
        if os.path.getsize(fname) > 0 and not os.path.exists(f"{fname}.INCOMPLETE"):
            tar_files.append(os.path.basename(fname))
    data = {'tar_files': sorted(tar_files), 'chunk_size': chunk_size, 'total_key_count': len(tar_files) * chunk_size}
    with open(output_path, 'wb') as f:
        pickle.dump(data, f)
    print("Generated", output_path)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--parquet_folder', type=str, default='data/parquet')
    parser.add_argument('--output_folder', type=str, default='data/output')
    parser.add_argument('--num_process', type=int, default=-1)
    parser.add_argument('--num_files', type=int, default=-1)
    args = parser.parse_args()

    PROFILE = True
    if PROFILE:
        # Start from a clean output folder so profiling runs are comparable.
        shutil.rmtree(args.output_folder, ignore_errors=True)

    os.makedirs(args.output_folder, exist_ok=True)
    parquets = glob(f'{args.parquet_folder}/*.parquet')
    if args.num_files > 0:
        parquets = parquets[: args.num_files]
    args.num_files = len(parquets)
    assert args.num_files > 0, f'No parquet files found in {args.parquet_folder}'
    print(f'Processing {args.num_files} files.')
    if args.num_process <= 0:
        args.num_process = min(len(parquets), multiprocessing.cpu_count())
    print(f'Converting using {args.num_process} processes.')
    assert args.num_process <= args.num_files

    t0 = time.time()
    with Pool(processes=args.num_process) as pool:
        pool.map(convert_single_parquet_to_tar, parquets)
    t1 = time.time()
    if PROFILE:
        print("====== Summary ======")
        print(f"{args.num_process} processes and {args.num_files} files.")
        print(f"Total time {t1 - t0:.2f}s")
        print(f"Time per file {(t1 - t0) / len(parquets):.2f}s")

    # chunk_size=5000 assumes each parquet shard contributes exactly 5000 samples.
    generate_wdinfo(args.output_folder, chunk_size=5000, output_path=os.path.join(args.output_folder, 'wdinfo.pkl'))
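# Example run (script filename hypothetical):
#   python convert_parquet_to_tar.py --parquet_folder data/parquet --output_folder data/output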