|
from dataclasses import dataclass
from typing import List, Optional

import regex as re
import torch
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Sentinel tokens from the ByT5 vocabulary, repurposed as structural markers
# for programs, utterance contexts, and ground-truth programs.
PROGRAM_SPECIAL_TOKEN = "<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN = "<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN = "<extra_id_122>"
|
|
|
def consistent(rx, spec):
    """Return True if regex `rx` satisfies every (string, label) example in `spec`.

    Labels are '+' (the regex must fully match the string) and '-' (it must
    not). Returns False on the first violated example, and None if a label is
    malformed, the regex fails to compile, or matching times out.
    """
    for s, label in spec:
        if label not in ['+', '-']:
            return None
        try:
            if re.fullmatch(rx, s, timeout=1):
                # The regex accepts this string; fail if it was negative.
                if label == '-':
                    return False
            else:
                # The regex rejects this string; fail if it was positive.
                if label == '+':
                    return False
        except re.error:
            return None
        except TimeoutError:
            return None

    return True
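
# Example (hedged): a quick sanity check of `consistent`; the spec format,
# a list of (string, '+'/'-' label) pairs, follows the functions below.
# >>> consistent(r"a+b", [("aab", '+'), ("b", '-')])
# True
# >>> consistent(r"a+b", [("b", '+')])
# False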
|
|
|
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build the pair of functions that serialize a spec to a string and back.

    `label_pos` controls whether each label follows ("suffix") or precedes
    ("prefix") its string; `idx` toggles indexed serialization, where each
    example is prefixed with a sentinel token <extra_id_i> instead of being
    joined with `separator`.
    """
    if label_pos == "suffix":
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{s}{label}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{s}{label}" for s, label in spec])
    else:
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{label}{s}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{label}{s}" for s, label in spec])

    if label_pos == "suffix":
        if idx:
            def string_to_utterances(string):
                # The sentinel tokens act as delimiters: turn them into
                # spaces, then split.
                string = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(s[:-1], s[-1]) for s in string.split(' ') if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[:-1], s[-1]) for s in string.split(separator) if len(s) > 0]
    else:
        if idx:
            def string_to_utterances(string):
                string = re.sub(r'<extra_id_\d+>', '', string)
                return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]

    return utterances_to_string, string_to_utterances
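
# Example (hedged): round-tripping a spec with the suffix/indexed
# serialization that Listener below uses by default.
# >>> to_str, to_spec = get_utterance_processing_functions("suffix", True)
# >>> to_str([("ab", "+"), ("cd", "-")])
# '<extra_id_0>ab+<extra_id_1>cd-'
# >>> to_spec('<extra_id_0>ab+<extra_id_1>cd-')
# [('ab', '+'), ('cd', '-')]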
|
|
|
def decode(c):
    """Map a raw ByT5 token id back to text: ids 0-2 are special tokens,
    3-258 are the 256 byte values, and 259+ are sentinel (extra_id) tokens."""
    if c < 3:
        return f"<{c}>"
    elif c < 259:
        return chr(c - 3)
    else:
        return f"<extra_id_{c - 259}>"
|
|
|
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode a batch of ByT5 token-id sequences (batch x beams x tokens) to strings."""
    skipped_tokens = outputs
    if skip_special_tokens:
        # Drop pad/eos/unk (ids 0-2).
        skipped_tokens = [
            [[t for t in x if t >= 3] for x in beam]
            for beam in skipped_tokens
        ]

    if skip_position_token:
        # Drop sentinel (extra_id) tokens, ids 259 and above.
        skipped_tokens = [
            [[t for t in x if t <= 258] for x in beam]
            for beam in skipped_tokens
        ]

    return [
        [''.join([decode(t) for t in x]) for x in beam]
        for beam in skipped_tokens
    ]
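
# Example (hedged): one batch element with two candidate sequences; special
# and sentinel tokens are stripped before decoding.
# >>> byt5_decode_batch([[[0, ord('a') + 3, ord('b') + 3, 259], [ord('x') + 3, 1]]], skip_position_token=True)
# [['ab', 'x']]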
|
|
|
class Agent:
    """Bundles a seq2seq model, its tokenizer, and a generation config on one device."""

    def __init__(self,
                 model_path: str,
                 gen_config: dict,
                 device: str = "cuda",
                 ):
        self.device = device
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
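
# Note (hedged): `gen_config` is any dict of transformers.GenerationConfig
# fields, e.g. {"max_new_tokens": 128, "num_beams": 4, "num_return_sequences": 4};
# these particular values are illustrative, not prescribed by this module.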
|
|
|
@dataclass
class ListenerOutput:
    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None
|
|
|
|
|
class Listener(Agent):
    def __init__(self,
                 model_path,
                 gen_config,
                 device="cuda",
                 label_pos="suffix",
                 idx: bool = True,
                 program_special_token=PROGRAM_SPECIAL_TOKEN,
                 utterances_special_token=UTTERANCES_SPECIAL_TOKEN
                 ):
        super().__init__(
            model_path,
            gen_config,
            device=device
        )
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
|
|
|
    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        # Accept either raw spec lists or pre-serialized context strings.
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        # Prepend the utterances marker where it is not already present.
        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Force decoding to start from the program marker.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context], return_tensors="pt",
            add_special_tokens=False
        ).to(self.device)

        outputs = self.model.generate(**context_tokens,
                                      decoder_input_ids=decoder_inputs.input_ids,
                                      generation_config=self.gen_config,
                                      return_dict_in_generate=True,
                                      output_scores=True
                                      )

        # Reshape to (batch, num_return_sequences, seq_len) before decoding.
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape((len(context), -1, outputs.sequences.shape[-1])).tolist(),
            skip_position_token=True,
            skip_special_tokens=True
        )

        # Keep only candidates consistent with their spec (all candidates if
        # consistency is not enforced), remembering their within-beam indices.
        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if enforce_consistency:
                    if consistent(p, ctx):
                        cp.append(p)
                        idx.append(i)
                else:
                    cp.append(p)
                    idx.append(i)

            consistent_programs.append(cp)
            idxs.append(idx)

        # Sequence log-probability: sum of per-token log-probs, with -inf
        # entries from padded positions zeroed out first.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(-1)
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)
        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()

        if return_scores:
            return ListenerOutput(
                consistent_programs,
                idxs,
                decoded_batch,
                scores_list
            )
        else:
            return ListenerOutput(consistent_programs)
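
    # Example (hedged): with a trained checkpoint, a call like
    #     listener.synthesize([[("aab", "+"), ("b", "-")]], return_scores=True)
    # returns a ListenerOutput whose `programs` holds the spec-consistent
    # candidates per batch element and `decoded_scores` their sequence
    # log-probabilities.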
|
|
|
|
|
    def score_program(self, contexts, programs):
        if isinstance(contexts[0], list):
            context_str = list(map(self.utterances_to_string, contexts))
        else:
            context_str = contexts

        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        program_tokens = self.tokenizer(
            [f"{self.program_special_token}{p}" for p in programs],
            return_tensors="pt",
            padding=True
        ).to(self.device)
        outputs = self.model(input_ids=context_tokens.input_ids,
                             decoder_input_ids=program_tokens.input_ids,
                             return_dict=True)

        # Log-prob of each program token given its prefix: the logits at
        # position t score the token at position t + 1.
        logprobs = torch.gather(F.log_softmax(outputs.logits, dim=-1), 2,
                                program_tokens.input_ids[:, 1:, None]).squeeze(-1)

        # Zero out contributions from padding positions.
        logprobs.masked_fill_(program_tokens.input_ids[:, 1:] == 0, 0)

        scores = logprobs.sum(-1)

        return scores.tolist()
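

# Minimal usage sketch (hedged): "path/to/listener-checkpoint" is a
# hypothetical ByT5-style seq2seq checkpoint fine-tuned to map utterance
# specs to regex programs; the generation settings are illustrative.
if __name__ == "__main__":
    listener = Listener(
        "path/to/listener-checkpoint",  # hypothetical checkpoint path
        {"max_new_tokens": 128, "num_beams": 4, "num_return_sequences": 4},
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    spec = [("aab", "+"), ("b", "-")]
    out = listener.synthesize([spec], return_scores=True)
    print(out.programs, out.decoded_scores)
    # Score an externally proposed program against the same spec.
    print(listener.score_program([spec], ["a+b"]))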