OpenLID-v2 / scripts /sample_with_temperature.py
laurievb's picture
Upload scripts/sample_with_temperature.py with huggingface_hub
895b334 verified
raw
history blame
4.47 kB
"""samples with temperature, grouping by language code. assumes input files is sorted by language group"""
import argparse
import logging
import random
import sys
def parse_args():
    """Return the parsed command-line arguments.

    Positional arguments:
        corpus_filepath     -- corpus to sample (tab-separated, sorted by language)
        linecounts_filepath -- per-language line counts, as produced by ``uniq -c``
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "corpus_filepath", type=str,
        help="path to input corpus to sample")
    arg_parser.add_argument(
        "linecounts_filepath", type=str,
        help="path to file containing line counts of input corpus (from 'uniq -c')")
    return arg_parser.parse_args()
# def count_lines(file):
# def blocks(files, size=65536):
# while True:
# b = files.read(size)
# if not b: break
# yield b
# with open(file, "r",encoding="utf-8",errors='ignore') as f:
# return (sum(bl.count("\n") for bl in blocks(f)))
def _load_linecounts(linecounts_filepath, logger):
    """Read ``uniq -c`` output into ``{lang: {"raw_lines": count}}``.

    Returns:
        (lc_lookup, total_raw_lines) -- the per-language lookup dict and the
        sum of all per-language counts.
    """
    logger.info(f"creating counts lookup dict from {linecounts_filepath}")
    total_raw_lines = 0
    lc_lookup = dict()
    with open(linecounts_filepath) as f:
        for line in f:
            # split() (not split(' ')) tolerates the variable-width space
            # padding that 'uniq -c' puts before the count.
            count, lang = line.split()
            count = int(count)
            lc_lookup[lang] = {"raw_lines": count}
            total_raw_lines += count
    logger.info(f"lookup dict finished ({len(lc_lookup)} entries)")
    logger.info(f"dataset contains {total_raw_lines} lines")
    return lc_lookup, total_raw_lines


def _plan_sampling(lc_lookup, total_raw_lines, logger):
    """Fill ``sampling_factor`` and ``lines_to_sample`` for every language.

    Lines to keep per language = (((raw_lines / total) ** 0.3) / sum_of_factors)
    * total, i.e. temperature sampling with T = 0.3: small languages are
    upsampled, large languages downsampled, total size roughly preserved.

    Returns the total number of lines that will be sampled.
    """
    logger.info("calculating sampling factors")
    total_sampling_factors = 0.0
    for entry in lc_lookup.values():
        # sample proportional to this so smaller langs are upsampled
        # and larger langs are downsampled
        sampling_factor = (entry["raw_lines"] / total_raw_lines) ** 0.3
        entry["sampling_factor"] = sampling_factor
        total_sampling_factors += sampling_factor
    logger.info(f"sampling factor total is {total_sampling_factors}")

    logger.info("calculating number of lines to sample")
    total_lines_to_sample = 0
    for entry in lc_lookup.values():
        lines_to_sample = round(
            entry["sampling_factor"] / total_sampling_factors * total_raw_lines)
        entry["lines_to_sample"] = lines_to_sample
        total_lines_to_sample += lines_to_sample

    prop_size_difference = abs(
        (total_raw_lines - total_lines_to_sample) / total_lines_to_sample)
    # sense check that sampled corpus is right size (rounding error < 1%)
    assert prop_size_difference < 0.01
    logger.info(
        f"total raw lines is {total_raw_lines}, total sampled lines is "
        f"{total_lines_to_sample} ({prop_size_difference:.3%} difference)")
    return total_lines_to_sample


def _emit_group(langcode, group_lines, lc_lookup, logger):
    """Sample one language's collected lines and write them to stdout."""
    raw_lines_in_lang = len(group_lines)
    # sanity check it's the same data the line-count file described
    assert raw_lines_in_lang == lc_lookup[langcode]["raw_lines"]
    num_lines_to_keep = lc_lookup[langcode]["lines_to_sample"]
    logger.info(
        f"finished reading {langcode}: read in {raw_lines_in_lang}, "
        f"writing {num_lines_to_keep}")
    if raw_lines_in_lang > num_lines_to_keep:
        # downsampling: without replacement
        sampled_lines = random.sample(group_lines, num_lines_to_keep)
    else:
        # oversampling: with replacement
        sampled_lines = random.choices(group_lines, k=num_lines_to_keep)
    sys.stdout.write("".join(f"{out}\n" for out in sampled_lines))


def _sample_corpus(corpus_filepath, lc_lookup, logger):
    """Stream the corpus (assumed sorted by language), sampling each group.

    Each corpus line is expected to have exactly three tab-separated fields
    with the language code in the middle.
    """
    logger.info(f"sampling from {corpus_filepath}")
    with open(corpus_filepath, "r") as f:
        single_lang_line_store = []
        langcode = ""
        while line := f.readline():
            line = line.strip()
            _, nextlang, _ = line.split('\t')
            if langcode == nextlang or langcode == "":  # same language
                single_lang_line_store.append(line)
            else:  # language change, time to sample and write out
                _emit_group(langcode, single_lang_line_store, lc_lookup, logger)
                logger.info(
                    f"finished writing {langcode} to stdout, "
                    f"now collecting lines for {nextlang}")
                single_lang_line_store = [line]
            langcode = nextlang
        # BUG FIX: flush the final language group -- previously the loop
        # exited at EOF with the last language still buffered, so it was
        # silently dropped from the output.
        if single_lang_line_store:
            _emit_group(langcode, single_lang_line_store, lc_lookup, logger)
            logger.info(f"finished writing {langcode} to stdout")


def main(args=None):
    """Temperature-sample the corpus and write the result to stdout.

    Args:
        args: parsed argparse namespace with ``corpus_filepath`` and
            ``linecounts_filepath`` attributes; when ``None`` (the default,
            preserving the original CLI behaviour) arguments are parsed
            from ``sys.argv``.
    """
    logging.basicConfig(
        level=logging.INFO,
        filename='sampling.log',
        filemode='w',
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')
    logger = logging.getLogger(__name__)
    if args is None:
        args = parse_args()

    lc_lookup, total_raw_lines = _load_linecounts(args.linecounts_filepath, logger)
    _plan_sampling(lc_lookup, total_raw_lines, logger)
    # assume input file is sorted by language group
    _sample_corpus(args.corpus_filepath, lc_lookup, logger)
    logger.info("sampling complete!")
# Script entry point: run the sampler when executed directly.
if __name__ == "__main__":
    main()