saujasv commited on
Commit
003a9ca
1 Parent(s): 8803e86

Add listener class

Browse files
Files changed (1) hide show
  1. listener.py +225 -0
listener.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, List
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
4
+ import regex as re
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ PROGRAM_SPECIAL_TOKEN="<extra_id_124>"
9
+ UTTERANCES_SPECIAL_TOKEN="<extra_id_123>"
10
+ GT_PROGRAM_SPECIAL_TOKEN="<extra_id_122>"
11
+
12
+ def consistent(rx, spec):
13
+ # spec is in the form of (string, '+'/'-') pairs
14
+ for s, label in spec:
15
+ if not label in ['+', '-']:
16
+ return None
17
+ try:
18
+ if re.fullmatch(rx, s, timeout=1):
19
+ if label == '-':
20
+ return False
21
+ else:
22
+ if label == '+':
23
+ return False
24
+ except re.error:
25
+ return None
26
+ except TimeoutError:
27
+ return None
28
+
29
+ return True
30
+
31
+ def get_utterance_processing_functions(label_pos, idx, separator=' '):
32
+ if label_pos == "suffix":
33
+ if idx:
34
+ def utterances_to_string(spec):
35
+ return ''.join([f"<extra_id_{i}>{s}{label}" for i, (s, label) in enumerate(spec)])
36
+ else:
37
+ def utterances_to_string(spec):
38
+ return separator.join([f"{s}{label}" for s, label in spec])
39
+ else:
40
+ if idx:
41
+ def utterances_to_string(spec):
42
+ return ''.join([f"<extra_id_{i}>{label}{s}" for i, (s, label) in enumerate(spec)])
43
+ else:
44
+ def utterances_to_string(spec):
45
+ return separator.join([f"{label}{s}" for s, label in spec])
46
+
47
+ if label_pos == "suffix":
48
+ if idx:
49
+ def string_to_utterances(string):
50
+ string = re.sub(r'<extra_id_\d+>', ' ', string)
51
+ return [(s[:-1], s[-1]) for s in string.split(' ') if len(s) > 0]
52
+ else:
53
+ def string_to_utterances(string):
54
+ return [(s[:-1], s[-1]) for s in string.split(separator) if len(s) > 0]
55
+ else:
56
+ if idx:
57
+ def string_to_utterances(string):
58
+ string = re.sub(r'<extra_id_\d+>', '', string)
59
+ return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]
60
+ else:
61
+ def string_to_utterances(string):
62
+ return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]
63
+
64
+ return utterances_to_string, string_to_utterances
65
+
66
+ def decode(c):
67
+ if c < 3:
68
+ return f"<{c}>"
69
+ elif c < 258:
70
+ return chr(c - 3)
71
+ else:
72
+ return f"<extra_id_{c - 259}>"
73
+
74
+ def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
75
+ skipped_tokens = outputs
76
+ if skip_special_tokens:
77
+ skipped_tokens = [
78
+ [[t for t in x if t >= 3] for x in beam]
79
+ for beam in skipped_tokens
80
+ ]
81
+
82
+ if skip_position_token:
83
+ skipped_tokens = [
84
+ [[t for t in x if t <= 258] for x in beam]
85
+ for beam in skipped_tokens
86
+ ]
87
+
88
+ return [
89
+ [''.join([decode(t) for t in x]) for x in beam]
90
+ for beam in skipped_tokens
91
+ ]
92
+
93
+ class Agent:
94
+ def __init__(self,
95
+ model_path: str,
96
+ gen_config: dict,
97
+ device: str = "cuda",
98
+ ):
99
+ self.device = device
100
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
101
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
102
+ self.gen_config = GenerationConfig(**gen_config)
103
+
104
+ @dataclass
105
+ class ListenerOutput:
106
+ programs: List[List[str]]
107
+ idx: Optional[List[List[int]]] = None
108
+ decoded: Optional[List[List[str]]] = None
109
+ decoded_scores: Optional[List[List[float]]] = None
110
+ pruned: Optional[List[List[str]]] = None
111
+
112
+
113
+ class Listener(Agent):
114
+ def __init__(self,
115
+ model_path,
116
+ gen_config,
117
+ device="cuda",
118
+ label_pos="suffix",
119
+ idx: bool=True,
120
+ program_special_token=PROGRAM_SPECIAL_TOKEN,
121
+ utterances_special_token=UTTERANCES_SPECIAL_TOKEN
122
+ ):
123
+ super().__init__(
124
+ model_path,
125
+ gen_config,
126
+ device=device
127
+ )
128
+ self.label_pos = label_pos
129
+ self.idx = idx
130
+ self.program_special_token = program_special_token
131
+ self.utterances_special_token = utterances_special_token
132
+ self.utterances_to_string, self.string_to_utterances = (
133
+ get_utterance_processing_functions(
134
+ label_pos, idx, separator=utterances_special_token
135
+ )
136
+ )
137
+
138
+ def synthesize(self, context, return_scores=False, enforce_consistency=True):
139
+ # If context is a list of utterances, convert to string
140
+ if isinstance(context[0], list):
141
+ context_str = list(map(self.utterances_to_string, context))
142
+ else:
143
+ context_str = context
144
+
145
+ context_tokens = self.tokenizer(
146
+ [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
147
+ for c in context_str],
148
+ return_tensors="pt",
149
+ padding=True
150
+ ).to(self.device)
151
+
152
+ decoder_inputs = self.tokenizer(
153
+ [self.program_special_token for _ in context], return_tensors="pt",
154
+ add_special_tokens=False
155
+ ).to(self.device)
156
+
157
+ outputs = self.model.generate(**context_tokens,
158
+ decoder_input_ids=decoder_inputs.input_ids,
159
+ generation_config=self.gen_config,
160
+ return_dict_in_generate=True,
161
+ output_scores=True
162
+ )
163
+
164
+ decoded_batch = byt5_decode_batch(outputs.sequences.reshape((len(context), -1, outputs.sequences.shape[-1])).tolist(), skip_position_token=True, skip_special_tokens=True)
165
+
166
+ consistent_programs = []
167
+ idxs = []
168
+ for decoded, ctx in zip(decoded_batch, context):
169
+ cp = []
170
+ idx = []
171
+ for i, p in enumerate(decoded):
172
+ if enforce_consistency:
173
+ if consistent(p, ctx):
174
+ cp.append(p)
175
+ idx.append(i)
176
+ else:
177
+ cp.append(p)
178
+ idx.append(i)
179
+
180
+ consistent_programs.append(cp)
181
+ idxs.append(idx)
182
+
183
+ logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
184
+ gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(-1)
185
+ gen_probs.masked_fill_(gen_probs.isinf(), 0)
186
+ scores = gen_probs.sum(-1)
187
+ n_decoded = scores.shape[0]
188
+ n_seq = n_decoded // len(context)
189
+ scores = scores.reshape((len(context), n_seq))
190
+ scores_list = scores.tolist()
191
+
192
+ if return_scores:
193
+ return ListenerOutput(
194
+ consistent_programs,
195
+ idxs,
196
+ decoded_batch,
197
+ scores_list
198
+ )
199
+ else:
200
+ return ListenerOutput(consistent_programs)
201
+
202
+
203
+ def score_program(self, contexts, programs):
204
+ if isinstance(contexts[0], list):
205
+ context_str = list(map(self.utterances_to_string, contexts))
206
+ else:
207
+ context_str = contexts
208
+
209
+ context_tokens = self.tokenizer(
210
+ [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
211
+ for c in context_str],
212
+ return_tensors="pt",
213
+ padding=True
214
+ ).to(self.device)
215
+
216
+ program_tokens = self.tokenizer([f"{self.program_special_token}{p}" for p in programs], return_tensors="pt").to(self.device)
217
+ outputs = self.model(input_ids=context_tokens.input_ids, decoder_input_ids=program_tokens.input_ids, return_dict=True)
218
+
219
+ logprobs = torch.gather(F.log_softmax(outputs.logits, dim=-1), 2, program_tokens.input_ids[:, 1:, None]).squeeze(-1)
220
+
221
+ logprobs.masked_fill_(program_tokens.input_ids[:, 1:] == 0, 0)
222
+
223
+ scores = logprobs.sum(-1)
224
+
225
+ return scores.tolist()