import re
import unicodedata
from collections import Counter, defaultdict


def get_stats(ids):
    """
    Count consecutive integer pairs across a collection of weighted chunks.

    `ids` is an iterable of (chunk, num) 2-tuples, where `chunk` is an
    indexable/sliceable sequence of ints (e.g. a list of token ids or a bytes
    object) and `num` is an int weight. Every occurrence of a consecutive
    pair (chunk[k], chunk[k + 1]) adds `num` to that pair's count.

    Pairs are never formed across chunk boundaries (the last element of one
    chunk is not paired with the first element of the next). Overlapping
    occurrences are all counted, e.g. b'eee' contributes (101, 101) twice.

    Returns a defaultdict(int) mapping (int, int) -> weighted count.

    Example:
        get_stats([(b'abc', 2), (b'bcd', 1), (b'eee', 1)])
        -> defaultdict(<class 'int'>,
                       {(97, 98): 2, (98, 99): 3, (99, 100): 1, (101, 101): 2})
    """
    counts = defaultdict(int)
    for chunk, num in ids:
        # zip stops at the shorter argument, so chunks of length 0 or 1
        # contribute no pairs.
        for pair in zip(chunk, chunk[1:]):
            counts[pair] += num
    return counts


def merge_batch_get_stats(ids, pairs):
    """
    Apply a batch of merges to every chunk in place and return the new pair counts.

    `ids` is an iterable of (chunk, num) 2-tuples where `chunk` is a mutable
    list of ints (it is modified in place) and `num` is an int weight.
    `pairs` maps an (int, int) pair to the token that replaces it.

    Each chunk is scanned left to right. When (chunk[i], chunk[i + 1]) is a
    key of `pairs`, the two elements are replaced by the mapped token, and
    scanning resumes after the new token. A freshly created token is
    therefore never merged again within the same call.

    Returns a defaultdict(int) with the weighted counts of all consecutive
    pairs in the merged chunks. This equals get_stats(ids) evaluated after
    the merge, but is computed in the same pass.
    """
    counts = defaultdict(int)
    for chunk, num in ids:
        last_index = len(chunk) - 1
        i = 0
        while i < last_index:
            j = i + 1
            token = pairs.get((chunk[i], chunk[j]))
            if token is not None:
                chunk[i] = token
                del chunk[j]
                last_index -= 1
            if i:
                # Later iterations only touch positions > i, so both
                # chunk[i - 1] and chunk[i] are final here; count their pair.
                counts[(chunk[i - 1], chunk[i])] += num
            i = j
        # The loop stops before counting the pair that ends at the final
        # element. If that element was consumed by a merge (i > last_index),
        # its pair was already counted inside the loop.
        if i and i == last_index:
            counts[(chunk[i - 1], chunk[i])] += num
    return counts


def merge(ids, pair, idx, len_ids=None):
    """
    Replace every non-overlapping occurrence of `pair` in `ids` with `idx`, in place.

    `ids` is a mutable list of ints. It is scanned left to right, and after
    a replacement the scan continues after the new token.
    `len_ids` is the current length of `ids`. It defaults to len(ids) when
    omitted.

    Returns the new length of `ids`. The list itself is modified in place;
    it is not returned.

    Example:
        ids = [1, 2, 3, 1, 2]
        merge(ids, (1, 2), 4, 5) -> 3, and ids is now [4, 3, 4]
    """
    if len_ids is None:
        len_ids = len(ids)
    first, second = pair[0], pair[1]
    i = 0
    while i + 1 < len_ids:
        j = i + 1
        if ids[i] == first and ids[j] == second:
            ids[i] = idx
            del ids[j]
            len_ids -= 1
        i = j
    return len_ids


def replace_control_characters(s: str) -> str:
    """Return `s` with every Unicode control-category char escaped as \\uXXXX."""
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch)  # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape
    return "".join(chars)


def render_token(t: bytes) -> str:
    """Pretty-print a token: decode as UTF-8 (replacing invalid bytes), escape control chars."""
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s


def _process_dicts(batch, compiled_pattern):
    """
    Count regex matches over every item in `batch`.

    Each item is passed directly to re.findall, so items must be strings.
    NOTE(review): the original comment says "for raw datasets.Dataset". The
    exact batch shape (and why the helper is named "dicts") is not visible
    here; confirm against callers.

    Returns a Counter mapping each match to its number of occurrences.
    """
    counter = Counter()
    for item in batch:
        counter.update(re.findall(compiled_pattern, item))
    return counter


def _process_string_scalar(batch, compiled_pattern):
    """
    Count regex matches over every item in `batch`, unwrapping each item with `.as_py()`.

    NOTE(review): the items are presumably scalar wrappers whose .as_py()
    returns a str (e.g. Arrow string scalars). Confirm against callers.

    Returns a Counter mapping each match to its number of occurrences.
    """
    counter = Counter()
    for item in batch:
        counter.update(re.findall(compiled_pattern, item.as_py()))
    return counter