# human-data-ft-listener / listener.py
from dataclasses import dataclass
from typing import Optional, List
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
import regex as re
import torch
import torch.nn.functional as F

# ByT5 sentinel tokens used to mark the program and utterance segments of the input
PROGRAM_SPECIAL_TOKEN = "<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN = "<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN = "<extra_id_122>"

def consistent(rx, spec):
    # spec is a list of (string, '+'/'-') pairs: '+' strings must match rx,
    # '-' strings must not. Returns None on invalid labels, bad patterns, or timeout.
    for s, label in spec:
        if label not in ['+', '-']:
            return None
        try:
            if re.fullmatch(rx, s, timeout=1):
                if label == '-':
                    return False
            else:
                if label == '+':
                    return False
        except re.error:
            return None
        except TimeoutError:
            return None
    return True
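
# A hedged usage sketch of consistent() (illustrative values, not from the repo):
#   consistent(r"[ab]+", [("ab", "+"), ("c", "-")])  # True: 'ab' matches, 'c' does not
#   consistent(r"a+", [("b", "+")])                  # False: a '+' example fails to match
#   consistent(r"a(", [("a", "+")])                  # None: the pattern itself is malformed
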
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    if label_pos == "suffix":
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{s}{label}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{s}{label}" for s, label in spec])
    else:
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{label}{s}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{label}{s}" for s, label in spec])

    if label_pos == "suffix":
        if idx:
            def string_to_utterances(string):
                string = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(s[:-1], s[-1]) for s in string.split(' ') if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[:-1], s[-1]) for s in string.split(separator) if len(s) > 0]
    else:
        if idx:
            def string_to_utterances(string):
                # Mirror the suffix case: the indexed format joins with '' and uses
                # <extra_id_*> tokens as boundaries, so replace them with spaces and
                # split on spaces (splitting on `separator` cannot recover boundaries)
                string = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(s[1:], s[0]) for s in string.split(' ') if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]

    return utterances_to_string, string_to_utterances
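
# Illustrative round trip (assumed "suffix" + idx configuration; values are hypothetical):
#   to_str, to_utts = get_utterance_processing_functions("suffix", True)
#   to_str([("ab", "+"), ("c", "-")])          # '<extra_id_0>ab+<extra_id_1>c-'
#   to_utts('<extra_id_0>ab+<extra_id_1>c-')   # [('ab', '+'), ('c', '-')]
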
def decode(c):
    # ByT5 id layout: 0-2 are special tokens, 3-258 are the 256 byte values
    # (so the byte boundary is c < 259, not c < 258), and 259+ are <extra_id_*> sentinels
    if c < 3:
        return f"<{c}>"
    elif c < 259:
        return chr(c - 3)
    else:
        return f"<extra_id_{c - 259}>"
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    # outputs: a batch of beams, each beam a list of token-id sequences
    skipped_tokens = outputs
    if skip_special_tokens:
        # Drop ids 0-2 (pad/eos/unk)
        skipped_tokens = [
            [[t for t in x if t >= 3] for x in beam]
            for beam in skipped_tokens
        ]

    if skip_position_token:
        # Drop the <extra_id_*> position markers (ids 259 and above)
        skipped_tokens = [
            [[t for t in x if t <= 258] for x in beam]
            for beam in skipped_tokens
        ]

    return [
        [''.join([decode(t) for t in x]) for x in beam]
        for beam in skipped_tokens
    ]
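
# Sketch: decoding one batch with a single beam holding the ids for 'ab+'
# ('a' -> 100, 'b' -> 101, '+' -> 46 after the +3 byte offset):
#   byt5_decode_batch([[[100, 101, 46]]])  # [['ab+']]
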
class Agent:
    def __init__(self,
                 model_path: str,
                 gen_config: dict,
                 device: str = "cuda",
                 ):
        self.device = device
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)

@dataclass
class ListenerOutput:
    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None

class Listener(Agent):
    def __init__(self,
                 model_path,
                 gen_config,
                 device="cuda",
                 label_pos="suffix",
                 idx: bool = True,
                 program_special_token=PROGRAM_SPECIAL_TOKEN,
                 utterances_special_token=UTTERANCES_SPECIAL_TOKEN
                 ):
        super().__init__(
            model_path,
            gen_config,
            device=device
        )
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        # If context is a list of utterances, convert to string
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context], return_tensors="pt",
            add_special_tokens=False
        ).to(self.device)

        outputs = self.model.generate(**context_tokens,
                                      decoder_input_ids=decoder_inputs.input_ids,
                                      generation_config=self.gen_config,
                                      return_dict_in_generate=True,
                                      output_scores=True
                                      )

        # Group the flat (batch * beams) sequences back by context before decoding
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape((len(context), -1, outputs.sequences.shape[-1])).tolist(),
            skip_position_token=True,
            skip_special_tokens=True
        )

        # Keep only candidates consistent with their context, unless consistency is not enforced
        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if not enforce_consistency or consistent(p, ctx):
                    cp.append(p)
                    idx.append(i)
            consistent_programs.append(cp)
            idxs.append(idx)

        # Sum per-token log-probabilities into a sequence-level score;
        # -inf entries (padding steps) are zeroed out before summing
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(-1)
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)

        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()

        if return_scores:
            return ListenerOutput(
                consistent_programs,
                idxs,
                decoded_batch,
                scores_list
            )
        else:
            return ListenerOutput(consistent_programs)
    def score_program(self, contexts, programs):
        if isinstance(contexts[0], list):
            context_str = list(map(self.utterances_to_string, contexts))
        else:
            context_str = contexts

        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # padding=True is needed so variable-length programs batch into one tensor;
        # the pad positions are masked out of the score below
        program_tokens = self.tokenizer(
            [f"{self.program_special_token}{p}" for p in programs],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        outputs = self.model(
            input_ids=context_tokens.input_ids,
            attention_mask=context_tokens.attention_mask,
            decoder_input_ids=program_tokens.input_ids,
            return_dict=True
        )

        # logits at position t predict token t+1, so gather each program token's
        # log-probability from the preceding position, then zero out padding
        logprobs = torch.gather(F.log_softmax(outputs.logits, dim=-1), 2, program_tokens.input_ids[:, 1:, None]).squeeze(-1)
        logprobs.masked_fill_(program_tokens.input_ids[:, 1:] == 0, 0)
        scores = logprobs.sum(-1)
        return scores.tolist()
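
# Minimal end-to-end sketch. The checkpoint path and generation settings below are
# hypothetical placeholders, not values shipped with this repo.
if __name__ == "__main__":
    listener = Listener(
        "path/to/byt5-listener-checkpoint",  # hypothetical fine-tuned ByT5 checkpoint
        {"max_new_tokens": 128, "num_beams": 4, "num_return_sequences": 4},
        device="cuda",
    )
    spec = [[("ab", "+"), ("ba", "+"), ("cc", "-")]]  # one context of labeled examples
    output = listener.synthesize(spec, return_scores=True)
    print(output.programs)        # regexes consistent with the spec, one list per context
    print(output.decoded_scores)  # sequence-level log-probabilities per candidate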