Upload scripts/sample_with_temperature.py with huggingface_hub
Browse files
scripts/sample_with_temperature.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Samples with temperature, grouping by language code. Assumes the input file is sorted by language group."""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
import random
|
6 |
+
import sys
|
7 |
+
|
8 |
+
def parse_args():
    """Collect and parse the script's command-line arguments.

    Returns an ``argparse.Namespace`` with ``corpus_filepath`` and
    ``linecounts_filepath`` (both positional, both strings).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "corpus_filepath",
        type=str,
        help="path to input corpus to sample",
    )
    ap.add_argument(
        "linecounts_filepath",
        type=str,
        help="path to file containing line counts of input corpus (from 'uniq -c')",
    )
    return ap.parse_args()
|
13 |
+
|
14 |
+
# def count_lines(file):
|
15 |
+
# def blocks(files, size=65536):
|
16 |
+
# while True:
|
17 |
+
# b = files.read(size)
|
18 |
+
# if not b: break
|
19 |
+
# yield b
|
20 |
+
# with open(file, "r",encoding="utf-8",errors='ignore') as f:
|
21 |
+
# return (sum(bl.count("\n") for bl in blocks(f)))
|
22 |
+
|
23 |
+
def _flush_group(lines, lang_entry, langcode, logger):
    """Sample one completed language group and write it to stdout.

    lines      -- all stripped corpus lines belonging to the group
    lang_entry -- the lc_lookup record holding 'raw_lines' / 'lines_to_sample'
    langcode   -- language code of the group (for messages only)
    logger     -- logger for progress messages
    Raises ValueError when the group's size disagrees with the count file.
    """
    raw_lines_in_lang = len(lines)
    if raw_lines_in_lang != lang_entry["raw_lines"]:
        # sanity check that the corpus and the line-count file describe the
        # same data (raise, not assert, so the check survives 'python -O')
        raise ValueError(
            f"{langcode}: corpus has {raw_lines_in_lang} lines but "
            f"line-count file says {lang_entry['raw_lines']}")
    num_lines_to_keep = lang_entry["lines_to_sample"]
    logger.info(f"finished reading {langcode}: read in {raw_lines_in_lang}, writing {num_lines_to_keep}")
    if raw_lines_in_lang > num_lines_to_keep:
        # downsampling: sample without replacement
        sampled = random.sample(lines, num_lines_to_keep)
    else:
        # need to oversample, so sample with replacement
        sampled = random.choices(lines, k=num_lines_to_keep)
    sys.stdout.writelines(f"{out}\n" for out in sampled)


def main():
    """Temperature-sample a tab-separated, language-sorted corpus to stdout.

    Reads per-language line counts ('uniq -c' output), computes how many
    lines to draw from each language with temperature sampling (exponent
    0.3, so smaller languages are upsampled and larger ones downsampled),
    then streams the corpus — which must be sorted so each language's lines
    are contiguous — sampling and emitting each group as it completes.
    Corpus lines are expected to have three tab-separated fields with the
    language code second.
    """
    logging.basicConfig(
        level=logging.INFO,
        filename='sampling.log',
        filemode='w',
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')
    logger = logging.getLogger(__name__)

    args = parse_args()

    logger.info(f"creating counts lookup dict from {args.linecounts_filepath}")
    total_raw_lines = 0
    lc_lookup = {}
    with open(args.linecounts_filepath) as f:
        for line in f:
            # 'uniq -c' pads the count with leading spaces; split() with no
            # argument tolerates any run of whitespace (split(' ') did not)
            count, lang = line.split()
            count = int(count)
            lc_lookup[lang] = {"raw_lines": count}
            total_raw_lines += count

    logger.info(f"lookup dict finished ({len(lc_lookup)} entries)")
    logger.info(f"dataset contains {total_raw_lines} lines")

    # lines to keep per language:
    # (((raw_lines_in_lang / total_line_count) ** 0.3) / total_factors) * total_lines
    logger.info("calculating sampling factors")
    total_sampling_factors = 0
    for entry in lc_lookup.values():
        # exponent < 1 flattens the distribution: we sample proportional to
        # this, so smaller langs are upsampled and larger langs downsampled
        entry["sampling_factor"] = (entry["raw_lines"] / total_raw_lines) ** 0.3
        total_sampling_factors += entry["sampling_factor"]

    logger.info(f"sampling factor total is {total_sampling_factors}")
    logger.info("calculating number of lines to sample")
    total_lines_to_sample = 0
    for entry in lc_lookup.values():
        entry["lines_to_sample"] = round(
            entry["sampling_factor"] / total_sampling_factors * total_raw_lines)
        total_lines_to_sample += entry["lines_to_sample"]
    prop_size_difference = abs(
        (total_raw_lines - total_lines_to_sample) / total_lines_to_sample)
    # sense check that the sampled corpus keeps roughly the original size
    # (raise, not assert, so the check survives 'python -O')
    if prop_size_difference >= 0.01:
        raise ValueError(
            f"sampled size deviates {prop_size_difference:.3%} from raw size")
    logger.info(
        f"total raw lines is {total_raw_lines}, total sampled lines is {total_lines_to_sample} ({prop_size_difference:.3%} difference)")

    # assume input file is sorted by group, so each language is contiguous
    logger.info(f"sampling from {args.corpus_filepath}")
    with open(args.corpus_filepath, "r") as f:
        single_lang_line_store = []
        langcode = ""
        for raw in f:
            line = raw.strip()
            _, nextlang, _ = line.split('\t')
            if langcode in ("", nextlang):  # still inside the same group
                single_lang_line_store.append(line)
            else:  # language change: sample and write out the finished group
                _flush_group(single_lang_line_store, lc_lookup[langcode],
                             langcode, logger)
                logger.info(f"finished writing {langcode} to stdout, now collecting lines for {nextlang}")
                single_lang_line_store = [line]
            langcode = nextlang
        # BUG FIX: the final language group was previously dropped silently,
        # because flushing only happened on a language *change*
        if single_lang_line_store:
            _flush_group(single_lang_line_store, lc_lookup[langcode],
                         langcode, logger)
            logger.info(f"finished writing {langcode} to stdout")
    logger.info("sampling complete!")
|
95 |
+
|
96 |
+
|
97 |
+
# Standard script entry point: run main() only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    main()
|