import gzip
import hashlib
import json
import multiprocessing
import os
import re
import shutil
import time
from pathlib import Path

import numpy as np
from arguments import PreprocessingArguments
from datasets import load_dataset
from minhash_deduplication import deduplicate_dataset
from transformers import AutoTokenizer, HfArgumentParser


# Matches any run of whitespace; used to normalise content before hashing.
PATTERN = re.compile(r"\s+")


def get_hash(example):
    """Get hash of content field.

    All whitespace is stripped from ``example["content"]`` before hashing, so
    two files that differ only in whitespace get the same MD5 hex digest.
    """
    normalized = PATTERN.sub("", example["content"])
    return {"hash": hashlib.md5(normalized.encode("utf-8")).hexdigest()}


def line_stats(example):
    """Calculates mean and max line length of file.

    Content without any line (an empty string) yields ``0.0`` / ``0`` instead
    of failing on ``max()`` of an empty sequence.
    """
    line_lengths = [len(line) for line in example["content"].splitlines()]
    if not line_lengths:
        return {"line_mean": 0.0, "line_max": 0}
    return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}


def alpha_stats(example):
    """Calculates the fraction of alphanumeric characters in the file.

    Empty content yields ``0.0`` rather than NaN, so such files are compared
    against the ``alpha_frac`` threshold like any other file.
    """
    content = example["content"]
    if not content:
        return {"alpha_frac": 0.0}
    alpha_frac = np.mean([c.isalnum() for c in content])
    return {"alpha_frac": alpha_frac}


def check_uniques(example, uniques):
    """Check if current hash is still in set of unique hashes and remove if true.

    Returns True the first time a hash is seen and False afterwards. ``uniques``
    is mutated in place: the hash is removed on its first occurrence so that
    later examples with the same hash are rejected.
    """
    if example["hash"] not in uniques:
        return False
    uniques.remove(example["hash"])
    return True


def is_autogenerated(example, scan_width=5):
    """Check if file is autogenerated by looking for keywords in the first few lines of the file."""
    keywords = ["auto-generated", "autogenerated", "automatically generated"]
    lines = example["content"].splitlines()
    # zip with range() limits the scan to at most `scan_width` leading lines.
    for _, line in zip(range(scan_width), lines):
        lowered = line.lower()
        if any(keyword in lowered for keyword in keywords):
            return {"autogenerated": True}
    return {"autogenerated": False}


def is_config_or_test(example, scan_width=5, coeff=0.05):
    """Check if file is a configuration file or a unit test by :
    1- looking for keywords in the first few lines of the file.
    2- counting number of occurrences of the words 'config' and 'test' with respect to number of lines.

    For the second test, the threshold is ``int(coeff * number_of_newlines)``;
    the file is flagged as soon as either running count exceeds it.
    """
    keywords = ["unit tests", "test file", "configuration file"]
    lines = example["content"].splitlines()

    # first test: explicit markers near the top of the file
    for _, line in zip(range(scan_width), lines):
        lowered = line.lower()
        if any(keyword in lowered for keyword in keywords):
            return {"config_or_test": True}

    # second test: density of 'config' / 'test' relative to file length
    nlines = example["content"].count("\n")
    threshold = int(coeff * nlines)
    count_config = 0
    count_test = 0
    for line in lines:
        lowered = line.lower()
        count_config += lowered.count("config")
        count_test += lowered.count("test")
        if count_config > threshold or count_test > threshold:
            return {"config_or_test": True}
    return {"config_or_test": False}


def has_no_keywords(example):
    """Check if a python file has none of the keywords for: function, class, for loop, while loop."""
    keywords = ["def ", "class ", "for ", "while "]
    for line in example["content"].splitlines():
        lowered = line.lower()
        if any(keyword in lowered for keyword in keywords):
            return {"has_no_keywords": False}
    return {"has_no_keywords": True}


def has_few_assignments(example, minimum=4):
    """Check if file uses symbol '=' at most `minimum` times.

    The file stops counting as having "few" assignments as soon as the running
    count of '=' characters exceeds `minimum`.
    """
    counter = 0
    for line in example["content"].splitlines():
        counter += line.count("=")
        if counter > minimum:
            return {"has_few_assignments": False}
    return {"has_few_assignments": True}


def char_token_ratio(example):
    """Compute character/token ratio of the file with tokenizer.

    Uses the module-level ``tokenizer`` created in the settings section below.
    If the tokenizer produces no tokens, the ratio is reported as ``0.0``
    instead of raising ZeroDivisionError.
    """
    input_ids = tokenizer(example["content"], truncation=False)["input_ids"]
    if len(input_ids) == 0:
        return {"ratio": 0.0}
    ratio = len(example["content"]) / len(input_ids)
    return {"ratio": ratio}


def preprocess(example):
    """Chain all preprocessing steps into one function to not fill cache.

    Returns a single dict holding every computed column: hash, line_mean,
    line_max, alpha_frac, ratio, autogenerated, config_or_test,
    has_no_keywords and has_few_assignments.
    """
    results = {}
    for step in (
        get_hash,
        line_stats,
        alpha_stats,
        char_token_ratio,
        is_autogenerated,
        is_config_or_test,
        has_no_keywords,
        has_few_assignments,
    ):
        results.update(step(example))
    return results


# NOTE: this name shadows the builtin `filter`; it is kept unchanged so that
# existing references to it keep working.
def filter(example, uniques, args):
    """Filter dataset with heuristics. Config, test and has_no_keywords files are removed with a given probability.

    Returns True when the example should be kept. The checks run in a fixed
    order and stop at the first one that rejects the example, so a random
    number is only drawn when the corresponding flag is set.
    """
    if not check_uniques(example, uniques):
        return False
    if example["autogenerated"]:
        return False
    if example["line_max"] > args.line_max:
        return False
    if example["line_mean"] > args.line_mean:
        return False
    if example["alpha_frac"] < args.alpha_frac:
        return False
    if example["ratio"] < args.min_token_ratio:
        return False
    if example["config_or_test"] and np.random.rand() <= args.filter_proba:
        return False
    if example["has_no_keywords"] and np.random.rand() <= args.filter_proba:
        return False
    if example["has_few_assignments"]:
        return False
    return True


def compress_file(file_path):
    """Compress a file with g-zip.

    Writes ``<file_path>.gz`` next to the input and then deletes the
    uncompressed original.
    """
    with open(file_path, "rb") as f_in, gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.unlink(file_path)


# Settings
parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args()
if args.num_workers is None:
    args.num_workers = multiprocessing.cpu_count()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

# Load dataset
t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
print(f"Time to load dataset: {time.time()-t_start:.2f}")

# Run preprocessing
t_start = time.time()
ds = ds.map(preprocess, num_proc=args.num_workers)
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")

# Deduplicate hashes
uniques = set(ds.unique("hash"))
frac = len(uniques) / len(ds)
print(f"Fraction of duplicates: {1-frac:.2%}")

# Deduplicate data and apply heuristics
# No num_proc is passed here: `filter` mutates the shared `uniques` set in place.
t_start = time.time()
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")

# Deduplicate with minhash and jaccard similarity
if args.near_deduplication:
    t_start = time.time()
    ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
    print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
    print(f"Size of deduplicate dataset: {len(ds_filter)}")

# Save data in batches of samples_per_file
output_dir = Path(args.output_dir)
# parents=True so a nested output path does not fail when its parent is missing
output_dir.mkdir(parents=True, exist_ok=True)

# save duplicate_clusters in the output_dir as artifacts
# not sure it is the right place to save it
if args.near_deduplication:
    with open(output_dir / "duplicate_clusters.json", "w", encoding="utf-8") as f:
        json.dump(duplicate_clusters, f)

data_dir = output_dir / "data"
data_dir.mkdir(exist_ok=True)

t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
    file_path = str(data_dir / f"file-{file_number+1:012}.json")
    end_index = min(len(ds_filter), index + args.samples_per_file)
    ds_filter.select(list(range(index, end_index))).to_json(file_path)
    compress_file(file_path)
print(f"Time to save dataset: {time.time()-t_start:.2f}")