from dataclasses import dataclass
from typing import Optional, List

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
import regex as re
import torch
import torch.nn.functional as F

# NOTE(review): these special-token constants are empty strings as written.
# PROGRAM_SPECIAL_TOKEN is fed to the decoder as its prefix and
# UTTERANCES_SPECIAL_TOKEN is both prepended to encoder inputs and used as the
# utterance separator in Listener, so an empty value means "no marker".
# Confirm the intended token strings.
PROGRAM_SPECIAL_TOKEN = ""
UTTERANCES_SPECIAL_TOKEN = ""
GT_PROGRAM_SPECIAL_TOKEN = ""


def consistent(rx, spec):
    """Check whether regex ``rx`` agrees with every labelled example in ``spec``.

    Args:
        rx: Regular-expression source string, matched with the third-party
            ``regex`` module (imported here as ``re``).
        spec: Iterable of ``(string, label)`` pairs. ``label`` is ``'+'`` when
            ``rx`` must fully match ``string`` and ``'-'`` when it must not.

    Returns:
        ``True`` if every example is satisfied and ``False`` as soon as one is
        violated. Returns ``None`` if a label is neither ``'+'`` nor ``'-'``,
        if ``rx`` is not a valid pattern, or if a match exceeds the 1-second
        timeout.
    """
    # spec is in the form of (string, '+'/'-') pairs
    for s, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            matched = re.fullmatch(rx, s, timeout=1)
        except (re.error, TimeoutError):
            # Invalid pattern or catastrophic backtracking: undecidable here.
            return None
        if matched and label == '-':
            return False
        if not matched and label == '+':
            return False
    return True


def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build a matching pair of (serialize, parse) functions for utterance specs.

    Args:
        label_pos: ``"suffix"`` puts the label after each string (``"ab+"``).
            Any other value puts it before (``"+ab"``).
        idx: When truthy, serialized utterances are concatenated with no
            delimiter, and the parser applies a regex substitution before
            splitting.
        separator: Delimiter used between utterances when ``idx`` is falsy.
            The parser also uses it when ``label_pos`` is not ``"suffix"``.

    Returns:
        ``(utterances_to_string, string_to_utterances)``.

    NOTE(review): both ``re.sub`` patterns below are empty strings. An empty
    pattern matches between every character, so ``re.sub(r'', ' ', s)``
    interleaves spaces through ``s``, and ``re.sub(r'', '', s)`` is a no-op.
    The intended index-marker pattern appears to be missing; confirm.
    Also, ``str.split`` raises ``ValueError`` if ``separator`` is ``""``.
    """
    suffix = label_pos == "suffix"
    join_sep = '' if idx else separator

    def utterances_to_string(spec):
        # NOTE(review): the idx variant emits no per-utterance marker at all;
        # confirm that is intended.
        if suffix:
            return join_sep.join(f"{s}{label}" for s, label in spec)
        return join_sep.join(f"{label}{s}" for s, label in spec)

    def string_to_utterances(string):
        if suffix:
            if idx:
                string = re.sub(r'', ' ', string)
                pieces = string.split(' ')
            else:
                pieces = string.split(separator)
            # Label is the last character of each piece.
            return [(s[:-1], s[-1]) for s in pieces if s]
        if idx:
            string = re.sub(r'', '', string)
        # Label is the first character of each piece.
        return [(s[1:], s[0]) for s in string.split(separator) if s]

    return utterances_to_string, string_to_utterances


def decode(c):
    """Render a single ByT5-style token id as text.

    Ids 0-2 render as ``<0>``, ``<1>`` and ``<2>``. Ids 3-258 cover the 256
    byte values (offset by 3) and render via ``chr(c - 3)``. Each byte becomes
    the code point of the same value; this is not a UTF-8 decode, so
    multi-byte characters come out as several Latin-1 characters. Larger ids
    fall through to the last branch.
    """
    if c < 3:
        return f"<{c}>"
    # 259 == 3 + 256: keeps this boundary in step with the ``t <= 258``
    # filter in ``byt5_decode_batch``.
    if c < 259:
        return chr(c - 3)
    # NOTE(review): this f-string is empty, so ids >= 259 decode to "".
    return f""


def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode a two-level nested list of token-id sequences into strings.

    Args:
        outputs: ``outputs[group][candidate]`` is a list of int token ids.
        skip_special_tokens: Drop ids below 3 before decoding.
        skip_position_token: Drop ids above 258 before decoding.

    Returns:
        Nested list with the same two outer levels; each leaf is a string.
    """
    def keep(t):
        if skip_special_tokens and t < 3:
            return False
        if skip_position_token and t > 258:
            return False
        return True

    return [
        [''.join(decode(t) for t in seq if keep(t)) for seq in beam]
        for beam in outputs
    ]


class Agent:
    """Base wrapper holding a seq2seq model, its tokenizer and a generation config."""

    def __init__(self, model_path: str, gen_config: dict, device: str = "cuda"):
        """Load model and tokenizer from ``model_path`` and move the model to ``device``.

        ``gen_config`` is unpacked into ``transformers.GenerationConfig``.
        """
        self.device = device
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)


@dataclass
class ListenerOutput:
    """Result of ``Listener.synthesize``; one inner list per input context."""

    # Programs kept for each context (all, or only the consistent ones).
    programs: List[List[str]]
    # Candidate indices (into ``decoded``) of the kept programs.
    idx: Optional[List[List[int]]] = None
    # Every decoded candidate, before the consistency filter.
    decoded: Optional[List[List[str]]] = None
    # Summed per-step log-probabilities for every decoded candidate.
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None


class Listener(Agent):
    """Maps labelled example strings (a context) to candidate regex programs."""

    def __init__(self,
                 model_path,
                 gen_config,
                 device="cuda",
                 label_pos="suffix",
                 idx: bool = True,
                 program_special_token=PROGRAM_SPECIAL_TOKEN,
                 utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
                 ):
        super().__init__(model_path, gen_config, device=device)
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )

    def _encode_contexts(self, contexts):
        """Serialize (if needed) and tokenize a batch of contexts for the encoder.

        ``contexts`` is either a list of utterance lists, serialized with
        ``self.utterances_to_string``, or a list of already-serialized strings.
        Each string gets ``self.utterances_special_token`` prepended unless it
        already starts with it. The batch is padded.
        """
        if isinstance(contexts[0], list):
            context_str = [self.utterances_to_string(c) for c in contexts]
        else:
            context_str = contexts
        prefix = self.utterances_special_token
        return self.tokenizer(
            [c if c.startswith(prefix) else f"{prefix}{c}" for c in context_str],
            return_tensors="pt",
            padding=True,
        ).to(self.device)

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        """Generate candidate programs for each context.

        Args:
            context: List of utterance lists (``[(string, '+'/'-'), ...]``) or
                list of serialized context strings.
            return_scores: Also return every decoded candidate, its index, and
                its summed log-probability.
            enforce_consistency: Keep only candidates for which
                ``consistent(program, ctx)`` is truthy.

                NOTE(review): with string contexts, ``consistent`` iterates
                the characters of the string and fails to unpack
                ``(s, label)``. Presumably pass utterance lists, or set this
                to ``False``, in that case.

        Returns:
            ``ListenerOutput``. Only ``programs`` is populated unless
            ``return_scores`` is set.
        """
        n_contexts = len(context)
        context_tokens = self._encode_contexts(context)
        # Every decoder starts from the same program-marker prefix.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)
        outputs = self.model.generate(
            **context_tokens,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            output_scores=True,
        )
        sequences = outputs.sequences
        # Group the flat (n_contexts * n_candidates, len) output per context.
        decoded_batch = byt5_decode_batch(
            sequences.reshape((n_contexts, -1, sequences.shape[-1])).tolist(),
            skip_position_token=True,
            skip_special_tokens=True,
        )

        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            kept = [
                i for i, p in enumerate(decoded)
                if not enforce_consistency or consistent(p, ctx)
            ]
            idxs.append(kept)
            consistent_programs.append([decoded[i] for i in kept])

        if not return_scores:
            return ListenerOutput(consistent_programs)

        # Per-step log-probabilities of the tokens that were actually generated.
        # NOTE(review): ``sequences[:, 1:]`` assumes exactly one prefix token
        # precedes the generated steps in ``sequences``; confirm this holds
        # for the configured decoder prefix.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(logprobs, 2, sequences[:, 1:, None]).squeeze(-1)
        # Steps whose log-prob is -inf (e.g. forced padding) contribute nothing.
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1).reshape((n_contexts, -1))
        return ListenerOutput(
            consistent_programs, idxs, decoded_batch, scores.tolist()
        )

    def score_program(self, contexts, programs):
        """Return the summed log-probability of each program given its context.

        Args:
            contexts: Same forms accepted by ``synthesize``; one per program.
            programs: Program strings, scored with teacher forcing.

        Returns:
            List of floats, one per ``(context, program)`` pair.
        """
        context_tokens = self._encode_contexts(contexts)
        # Pad so that programs of different lengths can be batched.
        program_tokens = self.tokenizer(
            [f"{self.program_special_token}{p}" for p in programs],
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(
                input_ids=context_tokens.input_ids,
                # Without the mask the encoder would attend to padded positions.
                attention_mask=context_tokens.attention_mask,
                decoder_input_ids=program_tokens.input_ids,
                return_dict=True,
            )
        # Logits at position t predict token t+1.
        targets = program_tokens.input_ids[:, 1:]
        logprobs = torch.gather(
            F.log_softmax(outputs.logits, dim=-1), 2, targets[:, :, None]
        ).squeeze(-1)
        # Padded target positions do not contribute to the score.
        logprobs.masked_fill_(program_tokens.attention_mask[:, 1:] == 0, 0)
        return logprobs.sum(-1).tolist()