saujasv committed
Commit 1a8e5ac
1 Parent(s): 6b9bc55

make demo gpu compatible

Files changed (1)
1. listener.py (+66 -44)
listener.py CHANGED
@@ -1,24 +1,37 @@
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
 from dataclasses import dataclass
 from typing import List, Optional
-from utils import get_preprocess_function, get_utterance_processing_functions, byt5_decode_batch, consistent
-from utils import PROGRAM_SPECIAL_TOKEN, UTTERANCES_SPECIAL_TOKEN, GT_PROGRAM_SPECIAL_TOKEN
+from utils import (
+    get_preprocess_function,
+    get_utterance_processing_functions,
+    byt5_decode_batch,
+    consistent,
+)
+from utils import (
+    PROGRAM_SPECIAL_TOKEN,
+    UTTERANCES_SPECIAL_TOKEN,
+    GT_PROGRAM_SPECIAL_TOKEN,
+)
 from greenery import parse
 from greenery.parse import NoMatch
 import numpy as np
 import torch
 
+
 class Agent:
-    def __init__(self,
-                 model_path: str,
-                 gen_config: dict,
-                 inference_batch_size: int = 1,
-                 ):
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+    def __init__(
+        self,
+        model_path: str,
+        gen_config: dict,
+        inference_batch_size: int = 1,
+        device=None,
+    ):
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.gen_config = GenerationConfig(**gen_config)
         self.inference_batch_size = inference_batch_size
 
+
 @dataclass
 class ListenerOutput:
     programs: List[List[str]]
@@ -27,21 +40,20 @@ class ListenerOutput:
     decoded_scores: Optional[List[List[float]]] = None
     pruned: Optional[List[List[str]]] = None
 
+
 class Listener(Agent):
-    def __init__(self,
+    def __init__(
+        self,
         model_path,
-        gen_config,
+        gen_config,
         inference_batch_size=4,
         label_pos="suffix",
-        idx: bool=True,
+        idx: bool = True,
         program_special_token=PROGRAM_SPECIAL_TOKEN,
-        utterances_special_token=UTTERANCES_SPECIAL_TOKEN
+        utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
+        device=None,
     ):
-        super().__init__(
-            model_path,
-            gen_config,
-            inference_batch_size,
-        )
+        super().__init__(model_path, gen_config, inference_batch_size, device)
         self.label_pos = label_pos
         self.idx = idx
         self.program_special_token = program_special_token
@@ -49,10 +61,10 @@ class Listener(Agent):
         self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
-                )
            )
+        )
         self.device = self.model.device
-
+
    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        # If context is a list of utterances, convert to string
        if isinstance(context[0], list):
@@ -61,25 +73,39 @@ class Listener(Agent):
            context_str = context
 
        context_tokens = self.tokenizer(
-            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
-             for c in context_str],
+            [
+                (
+                    f"{self.utterances_special_token}{c}"
+                    if not c.startswith(self.utterances_special_token)
+                    else c
+                )
+                for c in context_str
+            ],
            return_tensors="pt",
-            padding=True
-        ).to(self.device)
-
+            padding=True,
+        ).to(self.device)
+
        decoder_inputs = self.tokenizer(
-            [self.program_special_token for _ in context], return_tensors="pt",
-            add_special_tokens=False
-        ).to(self.device)
+            [self.program_special_token for _ in context],
+            return_tensors="pt",
+            add_special_tokens=False,
+        ).to(self.device)
 
-        outputs = self.model.generate(**context_tokens,
-                                      decoder_input_ids=decoder_inputs.input_ids,
-                                      generation_config=self.gen_config,
-                                      return_dict_in_generate=True,
-                                      output_scores=True
-                                      )
+        outputs = self.model.generate(
+            **context_tokens,
+            decoder_input_ids=decoder_inputs.input_ids,
+            generation_config=self.gen_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+        )
 
-        decoded_batch = byt5_decode_batch(outputs.sequences.reshape((len(context), -1, outputs.sequences.shape[-1])).tolist(), skip_position_token=True, skip_special_tokens=True)
+        decoded_batch = byt5_decode_batch(
+            outputs.sequences.reshape(
+                (len(context), -1, outputs.sequences.shape[-1])
+            ).tolist(),
+            skip_position_token=True,
+            skip_special_tokens=True,
+        )
 
        consistent_programs = []
        idxs = []
@@ -94,12 +120,14 @@ class Listener(Agent):
                else:
                    cp.append(p)
                    idx.append(i)
-
+
            consistent_programs.append(cp)
            idxs.append(idx)
-
+
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
-        gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(-1)
+        gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(
+            -1
+        )
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)
        n_decoded = scores.shape[0]
@@ -108,12 +136,6 @@ class Listener(Agent):
        scores_list = scores.tolist()
 
        if return_scores:
-            return ListenerOutput(
-                consistent_programs,
-                idxs,
-                decoded_batch,
-                scores_list
-            )
+            return ListenerOutput(consistent_programs, idxs, decoded_batch, scores_list)
        else:
            return ListenerOutput(consistent_programs)
-
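
For reference, a minimal sketch of how the new device argument might be used by the demo. The checkpoint path, generation settings, and utterance strings below are illustrative placeholders, not values taken from this repository:

import torch

from listener import Listener

# Hypothetical checkpoint path and generation settings, chosen only to
# illustrate the call; the real demo supplies its own.
listener = Listener(
    model_path="path/to/byt5-listener-checkpoint",
    gen_config={"max_new_tokens": 128, "num_beams": 4, "num_return_sequences": 4},
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# synthesize() tokenizes the context and moves it to self.device, so
# generation runs on the GPU whenever one was passed in above. The
# utterance strings are placeholders for whatever format the demo uses.
output = listener.synthesize(
    [["example utterance 1", "example utterance 2"]], return_scores=True
)
print(output.programs, output.decoded_scores)

With device=None (the default), model.to(None) is a no-op, so CPU-only use of the class is unchanged.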