saujasv committed on
Commit
2869f1d
1 Parent(s): 980c3e1

make barebones gradio interface

Browse files
Files changed (4) hide show
  1. app.py +12 -0
  2. listener.py +119 -0
  3. requirements.txt +0 -0
  4. utils.py +97 -0
app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from listener import Listener
3
+
4
# Shared listener instance used by every request handled by this app.
# The second argument is the generation configuration: sampling is enabled,
# beam search is effectively off (a single beam), and 100 sequences are
# returned per context.
listener = Listener(
    "pragmatic-programs/pragmatic-ft-listener",
    {
        "do_sample": True,
        "num_return_sequences": 100,
        "num_beams": 1,
        "temperature": 1,
        "top_p": 0.9,
    },
)
5
+
6
def synthesize(context):
    """Return one program consistent with the labelled examples in ``context``.

    Args:
        context: Whitespace-separated examples, each written as the example
            string immediately followed by its label character
            (``'+'`` or ``'-'``), e.g. ``"ab+ c-"``.

    Returns:
        The first program returned by the listener for this specification.

    Raises:
        gr.Error: If no example is given, if an example does not end in
            ``'+'`` or ``'-'``, or if the listener returns no program for
            the specification.
    """
    # split() without an argument discards the empty tokens produced by
    # leading, trailing or repeated whitespace; indexing s[-1] on such an
    # empty token would raise IndexError.
    tokens = context.split()
    if not tokens:
        raise gr.Error("Enter at least one example, e.g. 'ab+ c-'.")

    malformed = [s for s in tokens if s[-1] not in ('+', '-')]
    if malformed:
        raise gr.Error(
            f"Each example must end with '+' or '-': {' '.join(malformed)}"
        )

    # The listener takes a batch of specifications; this app sends exactly one.
    spec = [[[s[:-1], s[-1]] for s in tokens]]
    programs = listener.synthesize(spec).programs[0]
    if not programs:
        # With consistency enforcement every sampled program may be rejected.
        raise gr.Error("No sampled program is consistent with the given examples.")
    return programs[0]
9
+
10
+
11
# Plain text in, plain text out: the text box content is handed to
# synthesize() and its return value is displayed.
iface = gr.Interface(
    fn=synthesize,
    inputs="text",
    outputs="text",
)
iface.launch()
listener.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+ from utils import get_preprocess_function, get_utterance_processing_functions, byt5_decode_batch, consistent
5
+ from utils import PROGRAM_SPECIAL_TOKEN, UTTERANCES_SPECIAL_TOKEN, GT_PROGRAM_SPECIAL_TOKEN
6
+ from greenery import parse
7
+ from greenery.parse import NoMatch
8
+ import numpy as np
9
+ import torch
10
+
11
class Agent:
    """Bundle a pretrained seq2seq model with its tokenizer and generation settings."""

    def __init__(
        self,
        model_path: str,
        gen_config: dict,
        inference_batch_size: int = 1,
    ):
        """Load the model and tokenizer found at ``model_path``.

        Args:
            model_path: Identifier or path passed to ``from_pretrained`` for
                both the model and the tokenizer.
            gen_config: Keyword arguments used to build a ``GenerationConfig``.
            inference_batch_size: Stored on the instance as
                ``inference_batch_size``.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
        self.inference_batch_size = inference_batch_size
21
+
22
@dataclass
class ListenerOutput:
    """Result of ``Listener.synthesize``; outer lists have one entry per context."""

    # Programs kept for each context.
    programs: List[List[str]]
    # Positions of the kept programs within ``decoded``.
    idx: Optional[List[List[int]]] = None
    # Every decoded sequence for each context.
    decoded: Optional[List[List[str]]] = None
    # One score per decoded sequence.
    decoded_scores: Optional[List[List[float]]] = None
    # Not populated by the code in this file.
    pruned: Optional[List[List[str]]] = None
29
+
30
class Listener(Agent):
    """Agent that generates candidate programs for batches of labelled utterances."""

    def __init__(self,
                 model_path,
                 gen_config,
                 inference_batch_size=4,
                 label_pos="suffix",
                 idx: bool = True,
                 program_special_token=PROGRAM_SPECIAL_TOKEN,
                 utterances_special_token=UTTERANCES_SPECIAL_TOKEN
                 ):
        """Load the model and build the utterance (de)serialisation helpers.

        Args:
            model_path: Forwarded to ``Agent.__init__``.
            gen_config: Forwarded to ``Agent.__init__``.
            inference_batch_size: Forwarded to ``Agent.__init__``.
            label_pos: Passed to ``get_utterance_processing_functions``;
                ``"suffix"`` places the label after the example string.
            idx: Passed to ``get_utterance_processing_functions``; whether
                each utterance is preceded by an index token.
            program_special_token: Token used as the decoder prompt.
            utterances_special_token: Token prepended to every encoder input
                and used as the utterance separator.
        """
        super().__init__(
            model_path,
            gen_config,
            inference_batch_size,
        )
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
        self.device = self.model.device

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        """Generate programs for every context in the batch.

        Args:
            context: A list with one entry per problem. Each entry is either
                a sequence of ``(string, label)`` utterances or an
                already-serialised context string.
            return_scores: When true, also return the kept indices, every
                decoded sequence and one score per decoded sequence.
            enforce_consistency: When true, keep only the programs for which
                ``consistent(program, utterances)`` is truthy.

        Returns:
            A ``ListenerOutput``. Only ``programs`` is populated unless
            ``return_scores`` is true.
        """
        if len(context) == 0:
            # Nothing to generate; indexing context[0] below would fail.
            if return_scores:
                return ListenerOutput([], [], [], [])
            return ListenerOutput([])

        if isinstance(context[0], str):
            context_str = context
            # The consistency check needs (string, label) pairs, so parse
            # serialised contexts back into utterances before filtering.
            if enforce_consistency:
                specs = [self.string_to_utterances(c) for c in context]
            else:
                specs = context
        else:
            specs = context
            context_str = [self.utterances_to_string(spec) for spec in context]

        prefix = self.utterances_special_token
        encoder_texts = [
            c if c.startswith(prefix) else f"{prefix}{c}"
            for c in context_str
        ]
        context_tokens = self.tokenizer(
            encoder_texts,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Every context gets the same decoder prompt: the program token.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context], return_tensors="pt",
            add_special_tokens=False
        ).to(self.device)

        outputs = self.model.generate(**context_tokens,
                                      decoder_input_ids=decoder_inputs.input_ids,
                                      generation_config=self.gen_config,
                                      return_dict_in_generate=True,
                                      output_scores=True
                                      )

        # Regroup the flat batch of sequences into one list per context.
        grouped = outputs.sequences.reshape(
            (len(context), -1, outputs.sequences.shape[-1])
        )
        decoded_batch = byt5_decode_batch(
            grouped.tolist(), skip_position_token=True, skip_special_tokens=True
        )

        consistent_programs = []
        idxs = []
        for decoded, spec in zip(decoded_batch, specs):
            kept = [
                (i, p) for i, p in enumerate(decoded)
                if not enforce_consistency or consistent(p, spec)
            ]
            consistent_programs.append([p for _, p in kept])
            idxs.append([i for i, _ in kept])

        if not return_scores:
            # Scores are only needed by callers that ask for them.
            return ListenerOutput(consistent_programs)

        return ListenerOutput(
            consistent_programs,
            idxs,
            decoded_batch,
            self._sequence_scores(outputs, len(context))
        )

    def _sequence_scores(self, outputs, n_contexts):
        """Sum the per-step log-probabilities of each generated sequence."""
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        # There is one score tensor per generated step, so the generated
        # tokens are the trailing positions of each sequence. Deriving the
        # offset from the lengths keeps the two aligned however many prompt
        # tokens precede them.
        prompt_len = outputs.sequences.shape[1] - logprobs.shape[1]
        generated = outputs.sequences[:, prompt_len:, None]
        gen_probs = torch.gather(logprobs, 2, generated).squeeze(-1)
        # Infinite log-probabilities are zeroed so they do not dominate the sum.
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)
        n_seq = scores.shape[0] // n_contexts
        return scores.reshape((n_contexts, n_seq)).tolist()
119
+
requirements.txt ADDED
File without changes
utils.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex as re
2
+
3
+ PROGRAM_SPECIAL_TOKEN="<extra_id_124>"
4
+ UTTERANCES_SPECIAL_TOKEN="<extra_id_123>"
5
+ GT_PROGRAM_SPECIAL_TOKEN="<extra_id_122>"
6
+
7
def consistent(rx, spec):
    """Check whether the pattern ``rx`` agrees with every labelled example.

    Args:
        rx: Regular-expression pattern to test.
        spec: Iterable of ``(string, label)`` pairs. A ``'+'`` label means
            the string must fully match ``rx``; ``'-'`` means it must not.

    Returns:
        ``True`` if every example agrees with its label (also for an empty
        ``spec``), ``False`` at the first disagreement, and ``None`` if a
        label is neither ``'+'`` nor ``'-'``, the pattern raises
        ``re.error``, or matching raises ``TimeoutError``.
    """
    for s, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            # The timeout bounds the time spent on pathological patterns.
            matched = bool(re.fullmatch(rx, s, timeout=1))
        except (re.error, TimeoutError):
            return None
        if matched != (label == '+'):
            return False

    return True
25
+
26
def decode(c):
    """Render a single token id as text.

    Args:
        c: Integer token id.

    Returns:
        ``"<c>"`` for ids 0-2, the character ``chr(c - 3)`` for ids 3-258
        (256 values, one per byte value), and ``"<extra_id_k>"`` with
        ``k = c - 259`` for larger ids.
    """
    if c < 3:
        return f"<{c}>"
    # Id 258 is byte value 255; it must not fall through to the sentinel
    # branch, where it would be rendered as "<extra_id_-1>".
    if c < 259:
        return chr(c - 3)
    return f"<extra_id_{c - 259}>"
33
+
34
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode a nested batch of token-id sequences into strings.

    Args:
        outputs: List of groups, each group a list of token-id sequences.
        skip_special_tokens: Drop ids below 3 before decoding.
        skip_position_token: Drop ids above 258 before decoding.

    Returns:
        A list of the same nesting with every id sequence replaced by the
        concatenation of ``decode`` applied to its remaining ids.
    """
    def keep(token_id):
        # A token survives only if no enabled filter rejects it.
        if skip_special_tokens and token_id < 3:
            return False
        if skip_position_token and token_id > 258:
            return False
        return True

    decoded = []
    for group in outputs:
        texts = []
        for ids in group:
            texts.append(''.join(decode(t) for t in ids if keep(t)))
        decoded.append(texts)
    return decoded
52
+
53
def get_preprocess_function(tokenizer):
    """Build a batch-preprocessing callable bound to ``tokenizer``.

    The returned function reads the ``"context"`` and ``"target"`` entries of
    ``examples``, substitutes a single space for every ``None`` context, and
    returns the result of calling ``tokenizer`` on the contexts with the
    targets as ``text_target`` and truncation enabled.
    """
    def preprocess_function(examples):
        contexts = []
        for x in examples["context"]:
            contexts.append(' ' if x is None else x)
        return tokenizer(
            contexts,
            text_target=examples["target"],
            truncation=True,
        )

    return preprocess_function
63
+
64
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build a matching serialiser/parser pair for labelled utterances.

    Args:
        label_pos: ``"suffix"`` writes the label after the string; any other
            value writes it before the string.
        idx: When true, each utterance is preceded by ``<extra_id_i>`` (``i``
            being its position) and nothing else separates utterances. When
            false, utterances are joined with ``separator``.
        separator: String placed between utterances when ``idx`` is false.

    Returns:
        ``(utterances_to_string, string_to_utterances)``. The first maps a
        sequence of ``(string, label)`` pairs to one string; the second maps
        such a string back to a list of ``(string, label)`` tuples, treating
        the first or last character of each chunk as the label.
    """
    label_last = label_pos == "suffix"

    def render(s, label):
        return f"{s}{label}" if label_last else f"{label}{s}"

    def split_label(chunk):
        if label_last:
            return (chunk[:-1], chunk[-1])
        return (chunk[1:], chunk[0])

    if idx:
        def utterances_to_string(spec):
            return ''.join(
                f"<extra_id_{i}>{render(s, label)}"
                for i, (s, label) in enumerate(spec)
            )

        def string_to_utterances(string):
            # Index tokens are the only delimiters in this format, so turn
            # each into a space and split on it. Replacing them with the
            # empty string would merge all utterances into a single chunk.
            string = re.sub(r'<extra_id_\d+>', ' ', string)
            return [split_label(chunk) for chunk in string.split(' ') if len(chunk) > 0]
    else:
        def utterances_to_string(spec):
            return separator.join(render(s, label) for s, label in spec)

        def string_to_utterances(string):
            return [split_label(chunk) for chunk in string.split(separator) if len(chunk) > 0]

    return utterances_to_string, string_to_utterances