import regex as re

# NOTE(review): all three special tokens are empty strings in this copy of the
# file -- confirm that is intended and not the result of lost markup.
PROGRAM_SPECIAL_TOKEN = ""
UTTERANCES_SPECIAL_TOKEN = ""
GT_PROGRAM_SPECIAL_TOKEN = ""


def consistent(rx, spec):
    """Check whether the regex ``rx`` agrees with every labelled example.

    Args:
        rx: Regular expression pattern, applied with ``re.fullmatch``.
        spec: Iterable of ``(string, label)`` pairs where ``label`` is ``'+'``
            (the string must fully match) or ``'-'`` (it must not).

    Returns:
        ``True`` if every example agrees with ``rx``, ``False`` as soon as one
        example contradicts it, and ``None`` when the answer cannot be
        determined: a label is neither ``'+'`` nor ``'-'``, the pattern raises
        ``re.error``, or matching raises ``TimeoutError``.
    """
    for s, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            # The ``timeout`` keyword is passed through to the ``regex``
            # module imported above as ``re``.
            matched = re.fullmatch(rx, s, timeout=1)
        except re.error:
            return None
        except TimeoutError:
            return None
        if matched:
            if label == '-':
                return False
        elif label == '+':
            return False
    return True


def decode(c):
    """Turn a single token id into its textual form.

    Ids below 3 are rendered as ``"<id>"``, ids from 3 to 257 become the
    character ``chr(id - 3)``, and every id from 258 upwards becomes the
    empty string.
    """
    if c < 3:
        return f"<{c}>"
    if c < 258:
        return chr(c - 3)
    # NOTE(review): larger ids decode to nothing in this copy of the file --
    # confirm no textual rendering was intended for them.
    return ""


def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode a nested batch of token-id sequences into strings.

    Args:
        outputs: Iterable of beams, each an iterable of token-id sequences.
        skip_special_tokens: When true, drop ids below 3 before decoding.
        skip_position_token: When true, drop ids above 258 before decoding.

    Returns:
        A list (one entry per beam) of lists of decoded strings.
    """
    def keep(token):
        # NOTE(review): id 258 passes the ``<= 258`` filter, yet ``decode``
        # only maps ids below 258 to characters -- confirm the boundary.
        return ((not skip_special_tokens or token >= 3)
                and (not skip_position_token or token <= 258))

    return [
        [''.join(decode(t) for t in x if keep(t)) for x in beam]
        for beam in outputs
    ]


def get_preprocess_function(tokenizer):
    """Build a batch preprocessing callable around ``tokenizer``.

    The returned function reads ``examples["context"]`` and
    ``examples["target"]``, substitutes a single space for any ``None``
    context, and returns whatever ``tokenizer`` produces when called with
    ``text_target`` set to the targets and ``truncation=True``.
    """
    def preprocess_function(examples):
        contexts = [' ' if x is None else x for x in examples["context"]]
        return tokenizer(
            contexts,
            text_target=examples["target"],
            truncation=True,
        )

    return preprocess_function


def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build the pair of converters between a spec and its flat string form.

    Args:
        label_pos: ``"suffix"`` places the one-character label after each
            utterance; any other value places it before the utterance.
        idx: When truthy, utterances are concatenated with no delimiter and
            the parser pre-processes its input with ``re.sub``; when falsy,
            ``separator`` is used for both joining and splitting.
        separator: Delimiter for the non-``idx`` variants (and for splitting
            in the prefix + ``idx`` variant).

    Returns:
        ``(utterances_to_string, string_to_utterances)``. The first maps an
        iterable of ``(string, label)`` pairs to one string; the second maps
        such a string back to a list of ``(string, label)`` tuples, skipping
        empty pieces.
    """
    label_is_suffix = label_pos == "suffix"
    use_idx = bool(idx)
    joiner = '' if use_idx else separator

    def utterances_to_string(spec):
        if label_is_suffix:
            return joiner.join(f"{s}{label}" for s, label in spec)
        return joiner.join(f"{label}{s}" for s, label in spec)

    def string_to_utterances(string):
        delimiter = separator
        if use_idx:
            # NOTE(review): both substitution patterns are empty strings in
            # this copy of the file, and an empty pattern matches at every
            # position. Presumably an index-marker pattern was intended --
            # confirm against the upstream source before relying on the
            # ``idx`` variants.
            if label_is_suffix:
                string = re.sub(r'', ' ', string)
                # This variant always splits on a literal space, matching
                # the replacement text above rather than ``separator``.
                delimiter = ' '
            else:
                string = re.sub(r'', '', string)
        pieces = [s for s in string.split(delimiter) if s]
        if label_is_suffix:
            return [(s[:-1], s[-1]) for s in pieces]
        return [(s[1:], s[0]) for s in pieces]

    return utterances_to_string, string_to_utterances