Spaces:
Sleeping
Sleeping
Upload TMIDIX.py
Browse files
TMIDIX.py
CHANGED
|
@@ -51,7 +51,7 @@ r'''############################################################################
|
|
| 51 |
|
| 52 |
###################################################################################
|
| 53 |
|
| 54 |
-
__version__ = "26.3.
|
| 55 |
|
| 56 |
print('=' * 70)
|
| 57 |
print('TMIDIX Python module')
|
|
@@ -1511,6 +1511,8 @@ from functools import reduce, lru_cache
|
|
| 1511 |
|
| 1512 |
import struct
|
| 1513 |
|
|
|
|
|
|
|
| 1514 |
import matplotlib.pyplot as plt
|
| 1515 |
|
| 1516 |
import psutil
|
|
@@ -1528,7 +1530,7 @@ from array import array
|
|
| 1528 |
from pathlib import Path
|
| 1529 |
from fnmatch import fnmatch
|
| 1530 |
|
| 1531 |
-
from typing import List, Optional, Tuple, Dict, Any
|
| 1532 |
|
| 1533 |
###################################################################################
|
| 1534 |
#
|
|
@@ -17851,6 +17853,377 @@ def distribute_k_values(k: float, n: int):
|
|
| 17851 |
|
| 17852 |
###################################################################################
|
| 17853 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17854 |
print('Module loaded!')
|
| 17855 |
print('=' * 70)
|
| 17856 |
print('Enjoy! :)')
|
|
|
|
| 51 |
|
| 52 |
###################################################################################
|
| 53 |
|
| 54 |
+
__version__ = "26.3.28"
|
| 55 |
|
| 56 |
print('=' * 70)
|
| 57 |
print('TMIDIX Python module')
|
|
|
|
| 1511 |
|
| 1512 |
import struct
|
| 1513 |
|
| 1514 |
+
import heapq
|
| 1515 |
+
|
| 1516 |
import matplotlib.pyplot as plt
|
| 1517 |
|
| 1518 |
import psutil
|
|
|
|
| 1530 |
from pathlib import Path
|
| 1531 |
from fnmatch import fnmatch
|
| 1532 |
|
| 1533 |
+
from typing import List, Optional, Tuple, Dict, Any, Optional, Iterable, Set
|
| 1534 |
|
| 1535 |
###################################################################################
|
| 1536 |
#
|
|
|
|
| 17853 |
|
| 17854 |
###################################################################################
|
| 17855 |
|
| 17856 |
+
def binary_rle_encoder(bits):

    """Run-length encode a binary sequence as gaps between set bits.

    The first element of the result is the absolute index of the first
    1-bit; each subsequent element is the number of 0-bits separating
    consecutive 1-bits. Trailing zeros after the last 1-bit are dropped.
    """

    gaps = []
    prev = -1

    for idx, bit in enumerate(bits):
        if bit == 1:
            # Gap since the previous set bit; with prev == -1 this
            # reduces to the bit's absolute index for the first 1-bit.
            gaps.append(idx - prev - 1)
            prev = idx

    return gaps
|
| 17872 |
+
|
| 17873 |
+
###################################################################################
|
| 17874 |
+
|
| 17875 |
+
def binary_rle_decoder(deltas):

    """Invert binary_rle_encoder: rebuild a bit list from gap values.

    The output is zero-padded so its length is the next multiple of 128
    strictly greater than the index of the last set bit. An empty delta
    list decodes to an empty list.
    """

    if not deltas:
        return []

    # Turn cumulative gaps back into absolute 1-bit positions.
    ones = []
    cursor = -1

    for gap in deltas:
        cursor += gap + 1
        ones.append(cursor)

    # Pad to the next 128-boundary past the highest set bit.
    total = ((ones[-1] + 1) // 128 + 1) * 128
    decoded = [0] * total

    for index in ones:
        decoded[index] = 1

    return decoded
|
| 17895 |
+
|
| 17896 |
+
###################################################################################
|
| 17897 |
+
|
| 17898 |
+
class _Node:
|
| 17899 |
+
__slots__ = ("token", "prev", "next", "seq_idx")
|
| 17900 |
+
def __init__(self, token: int, seq_idx: int):
|
| 17901 |
+
self.token = token
|
| 17902 |
+
self.prev: Optional["_Node"] = None
|
| 17903 |
+
self.next: Optional["_Node"] = None
|
| 17904 |
+
self.seq_idx = seq_idx
|
| 17905 |
+
def __repr__(self):
|
| 17906 |
+
return f"_Node(tok={self.token},seq={self.seq_idx})"
|
| 17907 |
+
# Use default object identity hashing (fast and unique)
|
| 17908 |
+
def __hash__(self):
|
| 17909 |
+
return id(self)
|
| 17910 |
+
def __eq__(self, other):
|
| 17911 |
+
return self is other
|
| 17912 |
+
|
| 17913 |
+
###################################################################################
|
| 17914 |
+
|
| 17915 |
+
def train_bpe(
    corpus: List[List[int]],
    target_vocab_size: int,
    min_frequency: int = 2,
    start_token_id: Optional[int] = None,
    verbose: bool = False,
    show_progress: bool = True
) -> Tuple[List[Tuple[int,int,int]], Dict[Any,int], Dict[int,Any]]:

    """
    Fast BPE trainer using node-based occurrences and incremental updates.

    Each sequence is held as a doubly-linked list of _Node objects, so a
    merge only touches the merged nodes and their immediate neighbours.
    Pair frequencies live in a Counter plus a lazy max-heap: stale heap
    entries (whose stored count no longer matches the live count) are
    simply skipped on pop.

    Args:
        corpus: sequences of integer tokens to learn merges from.
        target_vocab_size: stop once the vocabulary reaches this size.
        min_frequency: stop when the best pair occurs fewer times than this.
        start_token_id: base id for the compacted vocabulary (default 0).
        verbose: print each merge as it happens.
        show_progress: display a tqdm progress bar over merges.

    Returns:
        merges: list of (left_id, right_id, new_id)
        token_to_id: mapping from original token or structured rep -> id
        id_to_token: mapping from id -> structured rep
    """

    # Work on copies so the caller's corpus is never mutated.
    seqs: List[List[int]] = [list(s) for s in corpus]

    # Compact the observed tokens into a dense id range starting at base_id.
    orig_vocab = sorted({tok for s in seqs for tok in s})
    base_id = 0 if start_token_id is None else start_token_id
    orig_to_compact: Dict[int, int] = {tok: base_id + i for i, tok in enumerate(orig_vocab)}
    compact_to_orig: Dict[int, int] = {cid: tok for tok, cid in orig_to_compact.items()}

    for i, s in enumerate(seqs):
        seqs[i] = [orig_to_compact[t] for t in s]

    current_vocab_size = len(orig_vocab)

    # Nothing to learn: return identity mappings and an empty merge list.
    if target_vocab_size <= current_vocab_size:
        id_to_token = {cid: ('orig', compact_to_orig[cid]) for cid in compact_to_orig}
        token_to_id = {compact_to_orig[cid]: cid for cid in compact_to_orig}
        return [], token_to_id, id_to_token

    next_id = base_id + current_vocab_size

    merges: List[Tuple[int,int,int]] = []
    id_to_token: Dict[int, Any] = {cid: ('orig', compact_to_orig[cid]) for cid in compact_to_orig}
    token_to_id: Dict[Any, int] = {compact_to_orig[cid]: cid for cid in compact_to_orig}

    # pair -> frequency, and pair -> set of left-hand nodes where it occurs.
    pair_counts: Counter = Counter()
    occurrences: Dict[Tuple[int,int], Set[_Node]] = defaultdict(set)
    seq_nodes: List[List[_Node]] = []

    # Build one doubly-linked list per sequence and record every adjacent
    # pair together with its left node.
    for si, s in enumerate(seqs):

        nodes = [ _Node(tok, si) for tok in s ]

        for i in range(len(nodes)):
            if i > 0:
                nodes[i].prev = nodes[i-1]

            if i + 1 < len(nodes):
                nodes[i].next = nodes[i+1]

        seq_nodes.append(nodes)

        for i in range(len(nodes)-1):
            left = nodes[i]
            pair = (left.token, left.next.token)
            pair_counts[pair] += 1
            occurrences[pair].add(left)

    # Max-heap via negated counts; entries are validated lazily on pop.
    heap: List[Tuple[int, Tuple[int,int]]] = [(-cnt, pair) for pair, cnt in pair_counts.items()]
    heapq.heapify(heap)

    merges_needed = target_vocab_size - current_vocab_size
    pbar = tqdm.tqdm(total=merges_needed, desc="BPE merges", disable=not show_progress)

    def _repr_struct(cid):
        # Human-readable nested representation of a token id (verbose output).
        rep = id_to_token.get(cid)

        if rep is None:
            return str(cid)

        def _fmt(r):
            if r[0] == 'orig':
                return str(r[1])

            return "(" + _fmt(r[1]) + "," + _fmt(r[2]) + ")"

        return _fmt(rep)

    def _dec_count(pair: Tuple[int,int], node: Optional[_Node] = None):
        """Decrement count for pair and remove node from occurrences if provided."""
        c = pair_counts.get(pair, 0)

        # Count would reach zero: drop the pair from both structures.
        if c <= 1:
            pair_counts.pop(pair, None)

            if pair in occurrences:
                if node is None:
                    occurrences.pop(pair, None)

                else:
                    occ = occurrences.get(pair)
                    if occ:
                        occ.discard(node)
                        if not occ:
                            occurrences.pop(pair, None)

        else:
            pair_counts[pair] = c - 1
            if node is not None:
                occ = occurrences.get(pair)
                if occ:
                    occ.discard(node)
                    if not occ:
                        occurrences.pop(pair, None)

    def _inc_count(pair: Tuple[int,int], node: Optional[_Node] = None):
        """Increment count for pair and add node to occurrences if provided."""
        pair_counts[pair] += 1

        if node is not None:
            occurrences[pair].add(node)

        # Push a fresh heap entry; stale ones are filtered out on pop.
        heapq.heappush(heap, (-pair_counts[pair], pair))

    merges_done = 0

    while current_vocab_size < target_vocab_size and heap:
        # Discard stale heap entries until the top matches the live count.
        while heap:
            negcnt, pair = heap[0]
            cnt = -negcnt

            if pair not in pair_counts or pair_counts[pair] != cnt:
                heapq.heappop(heap)
                continue

            break

        if not heap:
            break

        negcnt, pair = heapq.heappop(heap)
        freq = -negcnt

        # Remaining pairs are all below the frequency floor; stop training.
        if freq < min_frequency:
            break

        a, b = pair
        new_id = next_id
        next_id += 1

        if verbose:
            print(f"Merging pair ({_repr_struct(a)},{_repr_struct(b)}) -> {new_id} (freq={freq})")

        merges.append((a, b, new_id))

        # Record the structured representation of the merged token.
        left_repr = id_to_token[a]
        right_repr = id_to_token[b]
        new_repr = ('pair', left_repr, right_repr)
        id_to_token[new_id] = new_repr
        token_to_id[new_repr] = new_id

        occ_set = occurrences.get(pair)

        if not occ_set:
            pair_counts.pop(pair, None)
            occurrences.pop(pair, None)
            continue

        # Snapshot: the occurrence set is mutated while we merge.
        affected_nodes = list(occ_set)

        occurrences.pop(pair, None)
        pair_counts.pop(pair, None)

        for left_node in affected_nodes:

            # Skip nodes already consumed by an earlier overlapping merge.
            if left_node.token != a:
                continue

            right = left_node.next

            if right is None or right.token != b:
                continue

            prev_node = left_node.prev
            next_node = right.next

            # Retire the two pairs that straddled the merge site...
            if prev_node is not None:
                _dec_count((prev_node.token, left_node.token), prev_node)

            if next_node is not None:
                _dec_count((right.token, next_node.token), right)

            # ...fuse the two nodes into one carrying the new id...
            left_node.token = new_id

            left_node.next = next_node
            if next_node is not None:
                next_node.prev = left_node

            # ...and register the pairs newly created around the merged node.
            if prev_node is not None:
                _inc_count((prev_node.token, left_node.token), prev_node)

            if next_node is not None:
                _inc_count((left_node.token, next_node.token), left_node)

        current_vocab_size += 1
        merges_done += 1
        pbar.update(1)

    pbar.close()

    return merges, token_to_id, id_to_token
|
| 18122 |
+
|
| 18123 |
+
###################################################################################
|
| 18124 |
+
|
| 18125 |
+
def bpe_encode(
    seq: List[int],
    merges: List[Tuple[int,int,int]],
    token_to_id: Optional[Dict[Any,int]] = None,
    show_progress: bool = True
) -> List[int]:

    """
    Encode one sequence by replaying the learned merges in order.

    Each merge is applied with a single left-to-right scan that fuses
    non-overlapping (left, right) pairs into the merged token id — the
    same semantics as the trainer's merge order.
    """

    # Map raw tokens to compact ids when a mapping is supplied; unknown
    # tokens pass through unchanged.
    if token_to_id is None:
        tokens = list(seq)
    else:
        tokens = [token_to_id.get(t, t) for t in seq]

    if not merges or not tokens:
        return tokens

    for left, right, merged in tqdm.tqdm(merges, disable=not show_progress):
        # A single remaining token can never form a pair.
        if len(tokens) < 2:
            break

        fused = []
        pos = 0
        limit = len(tokens)

        while pos < limit:
            if pos + 1 < limit and tokens[pos] == left and tokens[pos+1] == right:
                fused.append(merged)
                pos += 2
            else:
                fused.append(tokens[pos])
                pos += 1

        tokens = fused

    return tokens
|
| 18166 |
+
|
| 18167 |
+
###################################################################################
|
| 18168 |
+
|
| 18169 |
+
def encode_bpe_corpus(
    corpus: Iterable[List[int]],
    merges: List[Tuple[int,int,int]],
    token_to_id: Optional[Dict[Any,int]] = None,
    show_corpus_progress: bool = False,
    show_seq_progress: bool = True
) -> List[List[int]]:

    """Apply bpe_encode to every sequence in *corpus* and collect the results."""

    return [
        bpe_encode(seq, merges, token_to_id=token_to_id, show_progress=show_seq_progress)
        for seq in tqdm.tqdm(corpus, disable=not show_corpus_progress)
    ]
|
| 18184 |
+
|
| 18185 |
+
###################################################################################
|
| 18186 |
+
|
| 18187 |
+
def bpe_decode(encoded_seq: List[int], id_to_token: Dict[int, Any]) -> List[int]:
    """
    Expand BPE token ids back into the original token stream.

    Uses an explicit work stack instead of recursion so deeply nested
    'pair' structures cannot exhaust the recursion limit. Ids missing
    from *id_to_token* are passed through unchanged.
    """
    result: List[int] = []
    pending = list(reversed(encoded_seq))

    while pending:
        item = pending.pop()

        # Resolve the item to a structured representation, if any.
        if isinstance(item, int) and item in id_to_token:
            node = id_to_token[item]
        elif isinstance(item, tuple):
            node = item
        else:
            # Unknown id or raw token: emit it as-is.
            result.append(item)
            continue

        tag = node[0]

        if tag == 'pair':
            # Push right child first so the left child is expanded first.
            pending.append(node[2])
            pending.append(node[1])
        elif tag == 'orig':
            result.append(node[1])
        else:
            result.append(node)

    return result
|
| 18219 |
+
|
| 18220 |
+
###################################################################################
|
| 18221 |
+
|
| 18222 |
+
def decode_bpe_corpus(encoded_corpus: Iterable[List[int]], id_to_token: Dict[int, Any]) -> List[List[int]]:
    """Decode every encoded sequence in *encoded_corpus* via bpe_decode."""
    decoded: List[List[int]] = []
    for encoded_seq in encoded_corpus:
        decoded.append(bpe_decode(encoded_seq, id_to_token))
    return decoded
|
| 18224 |
+
|
| 18225 |
+
###################################################################################
|
| 18226 |
+
|
| 18227 |
print('Module loaded!')
|
| 18228 |
print('=' * 70)
|
| 18229 |
print('Enjoy! :)')
|