# community-events/jax-controlnet-sprint/dataset_tools/coyo_1m_dataset_preprocess.py
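"""Preprocess a ~1M-example subset of COYO-700M for ControlNet training.

The script filters COYO-700M by quality scores, downloads the surviving
images, computes Canny edge maps for them, and records everything in
`meta.jsonl` under `--train_data_dir`.

Example invocation (a sketch only; the paths and process count below are
placeholders, not values from the original sprint):

    python coyo_1m_dataset_preprocess.py \
        --train_data_dir ./coyo_1m \
        --cache_dir ./hf_cache \
        --max_train_samples 1000000 \
        --num_proc 16

Resulting layout under `--train_data_dir`:
    images/              original downloaded images
    processed_images/    Canny edge maps (conditioning images)
    meta.jsonl           one JSON line per example with image paths and caption
"""
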
import argparse
import logging
import os
import random

import cv2
import jsonlines
import numpy as np
import requests
from datasets import load_dataset
from PIL import Image

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Example of a data preprocessing script.")
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=True,
        help="The directory where the dataset will be stored.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        required=True,
        help="The directory where the download cache will be stored.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help="Number of examples to keep in the final dataset.",
    )
    parser.add_argument(
        "--num_proc",
        type=int,
        default=1,
        help="Number of processes to use in `dataset.map()`.",
    )
    args = parser.parse_args()
    return args


# keep only examples with high CLIP similarity, a low watermark score
# and a high LAION aesthetic score
def filter_function(example):
    if example["clip_similarity_vitb32"] < 0.3:
        return False
    if example["watermark_score"] > 0.4:
        return False
    if example["aesthetic_score_laion_v2"] < 6.0:
        return False
    return True


# select the first `max_train_samples` examples and apply the quality filter
def filter_dataset(dataset, max_train_samples):
    small_dataset = dataset.select(range(max_train_samples)).filter(filter_function)
    return small_dataset


if __name__ == "__main__":
    args = parse_args()

    # make sure the output directories exist before any worker starts writing
    os.makedirs(f"{args.train_data_dir}/images", exist_ok=True)
    os.makedirs(f"{args.train_data_dir}/processed_images", exist_ok=True)

    # load coyo-700m
    dataset = load_dataset(
        "kakaobrain/coyo-700m",
        cache_dir=args.cache_dir,
        split="train",
    )

    # estimate the fraction of examples that survive the quality filter
    filter_ratio = len(filter_dataset(dataset, 20000)) / 20000

    # estimate how many raw examples to select, based on
    # (1) the filter_ratio calculated on 20k examples
    # (2) the assumption that only ~80% of the URLs are still valid
    max_train_samples = int(args.max_train_samples / filter_ratio / 0.8)

    # filter the dataset down to roughly `args.max_train_samples` (~1 million) examples
    small_dataset = filter_dataset(dataset, max_train_samples)
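
    # Worked example with hypothetical numbers (not measured on the real data):
    # if filter_ratio came out to 0.05 and --max_train_samples is 1_000_000,
    # max_train_samples becomes int(1_000_000 / 0.05 / 0.8) = 25_000_000 raw
    # examples selected before filtering and downloading.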

    def preprocess_and_save(example):
        image_url = example["url"]
        try:
            # download the original image and convert to RGB so that both the
            # saved PNG and the Canny input have a consistent channel layout
            image = Image.open(requests.get(image_url, stream=True, timeout=5).raw).convert("RGB")
            image_path = f"{args.train_data_dir}/images/{example['id']}.png"
            image.save(image_path)

            # generate and save the canny edge image
            processed_image = np.array(image)
            # apply random thresholds;
            # note that this should normally be done on the fly during training,
            # but precomputing once is fine with a dataset as large as this one
            thresholds = (
                random.randint(0, 255),
                random.randint(0, 255),
            )
            # the smaller value is the lower hysteresis threshold, the larger the upper one
            processed_image = cv2.Canny(processed_image, min(thresholds), max(thresholds))
            # replicate the single-channel edge map into a 3-channel image
            processed_image = processed_image[:, :, None]
            processed_image = np.concatenate(
                [processed_image, processed_image, processed_image], axis=2
            )
            processed_image = Image.fromarray(processed_image)
            processed_image_path = (
                f"{args.train_data_dir}/processed_images/{example['id']}.png"
            )
            processed_image.save(processed_image_path)

            # write the example's metadata to meta.jsonl
            meta = {
                "image": image_path,
                "conditioning_image": processed_image_path,
                "caption": example["text"],
            }
            with jsonlines.open(f"{args.train_data_dir}/meta.jsonl", "a") as writer:
                writer.write(meta)
        except Exception as e:
            logger.error(f"Failed to process image {image_url}: {str(e)}")

    # preprocess -> image, processed image and meta.jsonl
    small_dataset.map(preprocess_and_save, num_proc=args.num_proc)
    print(f"created data folder at: {args.train_data_dir}")