laurievb committed on
Commit
895b334
·
verified ·
1 Parent(s): d88461d

Upload scripts/sample_with_temperature.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/sample_with_temperature.py +98 -0
scripts/sample_with_temperature.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Samples a corpus with temperature, grouping by language code. Assumes the input file is sorted by language group."""
2
+
3
+ import argparse
4
+ import logging
5
+ import random
6
+ import sys
7
+
8
def parse_args():
    """Parse the command-line arguments of the sampling script.

    Two positional string arguments are accepted, in this order:
    ``corpus_filepath`` and ``linecounts_filepath``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    positional_arguments = (
        ("corpus_filepath", "path to input corpus to sample"),
        (
            "linecounts_filepath",
            "path to file containing line counts of input corpus (from 'uniq -c')",
        ),
    )
    cli = argparse.ArgumentParser()
    for argument_name, help_text in positional_arguments:
        cli.add_argument(argument_name, type=str, help=help_text)
    return cli.parse_args()
13
+
14
+ # def count_lines(file):
15
+ # def blocks(files, size=65536):
16
+ # while True:
17
+ # b = files.read(size)
18
+ # if not b: break
19
+ # yield b
20
+ # with open(file, "r",encoding="utf-8",errors='ignore') as f:
21
+ # return (sum(bl.count("\n") for bl in blocks(f)))
22
+
23
def _read_line_counts(linecounts_filepath):
    """Load per-language line counts from ``uniq -c`` style output.

    Every non-blank line must hold an integer count followed by a language
    code, separated by whitespace.

    Returns:
        A ``(lc_lookup, total_raw_lines)`` tuple, where ``lc_lookup`` maps each
        language code to ``{"raw_lines": count}`` and ``total_raw_lines`` is
        the sum of all counts.
    """
    lc_lookup = {}
    total_raw_lines = 0
    with open(linecounts_filepath) as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines (e.g. a trailing newline)
                continue
            # 'uniq -c' pads the count, so split on any run of whitespace
            count, lang = line.split(maxsplit=1)
            count = int(count)
            lc_lookup[lang] = {"raw_lines": count}
            total_raw_lines += count
    return lc_lookup, total_raw_lines


def _add_sampling_factors(lc_lookup, total_raw_lines, exponent=0.3):
    """Store a ``sampling_factor`` per language and return the sum of factors.

    The factor is ``(raw_lines / total_raw_lines) ** exponent``; with an
    exponent below 1 smaller languages are upsampled and larger languages are
    downsampled relative to their raw share.
    """
    total_sampling_factors = 0
    for stats in lc_lookup.values():
        sampling_factor = (stats["raw_lines"] / total_raw_lines) ** exponent
        stats["sampling_factor"] = sampling_factor
        total_sampling_factors += sampling_factor
    return total_sampling_factors


def _add_sample_sizes(lc_lookup, total_sampling_factors, total_raw_lines):
    """Store ``lines_to_sample`` per language and return the overall total.

    Each language gets its normalised sampling factor times the raw corpus
    size, rounded to the nearest integer.
    """
    total_lines_to_sample = 0
    for stats in lc_lookup.values():
        lines_to_sample = round(
            stats["sampling_factor"] / total_sampling_factors * total_raw_lines
        )
        stats["lines_to_sample"] = lines_to_sample
        total_lines_to_sample += lines_to_sample
    return total_lines_to_sample


def _write_sample(lines, langcode, lc_lookup, logger):
    """Sample one language group and write the sampled lines to stdout.

    Raises:
        ValueError: if ``langcode`` is missing from ``lc_lookup`` or the number
            of lines read differs from the recorded raw line count.
    """
    raw_lines_in_lang = len(lines)
    if langcode not in lc_lookup:
        raise ValueError(f"language {langcode!r} not found in line counts file")
    expected_lines = lc_lookup[langcode]["raw_lines"]
    if raw_lines_in_lang != expected_lines:  # sanity check it's same data
        raise ValueError(
            f"read {raw_lines_in_lang} lines for {langcode!r} but line counts "
            f"file says {expected_lines} (is the corpus sorted by language?)"
        )
    num_lines_to_keep = lc_lookup[langcode]["lines_to_sample"]
    logger.info(f"finished reading {langcode}: read in {raw_lines_in_lang}, writing {num_lines_to_keep}")
    if raw_lines_in_lang >= num_lines_to_keep:
        # enough data: sample without replacement (keeps every line when equal)
        sampled_lines = random.sample(lines, num_lines_to_keep)
    else:  # need to oversample, so now use sampling with replacement
        sampled_lines = random.choices(lines, k=num_lines_to_keep)
    sys.stdout.writelines(f"{out}\n" for out in sampled_lines)


def main():
    """Temperature-sample a language-grouped corpus and write it to stdout.

    Reads per-language line counts, computes how many lines to keep per
    language, then streams the corpus one language group at a time.

    The corpus must have three tab-separated fields per line with the language
    code in the second field, and must be sorted so that each language forms
    one contiguous group. Progress is logged to ``sampling.log``.

    Raises:
        ValueError: if the line counts file is empty, the sampled corpus size
            deviates from the raw size by 1% or more, or a language group does
            not match the line counts file.
    """
    logging.basicConfig(
        level=logging.INFO,
        filename='sampling.log',
        filemode='w',
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')
    logger = logging.getLogger(__name__)

    args = parse_args()

    logger.info(f"creating counts lookup dict from {args.linecounts_filepath}")
    lc_lookup, total_raw_lines = _read_line_counts(args.linecounts_filepath)
    logger.info(f"lookup dict finished ({len(lc_lookup)} entries)")
    logger.info(f"dataset contains {total_raw_lines} lines")
    if total_raw_lines <= 0:
        # would otherwise surface as an opaque ZeroDivisionError below
        raise ValueError(f"no line counts found in {args.linecounts_filepath}")

    # calculate lines to keep with
    # (((raw_lines_in_lang / total_line_count) ** 0.3) / total_proportions) * total lines

    # calculate proportions
    logger.info("calculating sampling factors")
    total_sampling_factors = _add_sampling_factors(lc_lookup, total_raw_lines)
    logger.info(f"sampling factor total is {total_sampling_factors}")

    logger.info("calculating number of lines to sample")
    total_lines_to_sample = _add_sample_sizes(
        lc_lookup, total_sampling_factors, total_raw_lines)
    if total_lines_to_sample == 0:
        raise ValueError("sampling would produce an empty corpus")
    prop_size_difference = abs((total_raw_lines - total_lines_to_sample) / total_lines_to_sample)
    # sense check that sampled corpus is right size (explicit raise survives -O)
    if not prop_size_difference < 0.01:
        raise ValueError(
            f"sampled corpus size {total_lines_to_sample} differs from raw size "
            f"{total_raw_lines} by {prop_size_difference:.3%}"
        )
    logger.info(
        f"total raw lines is {total_raw_lines}, total sampled lines is {total_lines_to_sample} ({prop_size_difference:.3%} difference)")

    # assume input file is sorted by group
    logger.info(f"sampling from {args.corpus_filepath}")
    with open(args.corpus_filepath, "r") as f:
        single_lang_line_store = []
        langcode = None  # None means no language group has been started yet
        for line in f:
            line = line.strip()
            _, nextlang, _ = line.split('\t')
            if langcode is not None and nextlang != langcode:
                # language change, time to sample and write out
                _write_sample(single_lang_line_store, langcode, lc_lookup, logger)
                logger.info(f"finished writing {langcode} to stdout, now collecting lines for {nextlang}")
                single_lang_line_store = []
            single_lang_line_store.append(line)
            langcode = nextlang
        # the final group has no following language change to trigger a write,
        # so it has to be flushed explicitly once the file is exhausted
        if single_lang_line_store:
            _write_sample(single_lang_line_store, langcode, lc_lookup, logger)
            logger.info(f"finished writing {langcode} to stdout")
    logger.info("sampling complete!")
95
+
96
+
97
+ if __name__ == "__main__":
98
+ main()