nirraviv89 committed on
Commit
fbc3442
•
1 Parent(s): 6f8733e

rename and documentation

Browse files
Files changed (3) hide show
  1. src/config.py +9 -11
  2. src/inference.py +26 -40
  3. src/models.py +3 -3
src/config.py CHANGED
@@ -1,9 +1,9 @@
1
  from transformers import BertConfig
2
 
3
 
4
- class CustomBertConfig(BertConfig):
5
  r"""
6
- This is the configuration class to store the configuration of a [`BertForPunctuation`]. It is based on BERT config
7
  to the specified arguments, defining the model architecture.
8
  Args:
9
  backward_context (`int`, *optional*, defaults to 15):
@@ -20,7 +20,7 @@ class CustomBertConfig(BertConfig):
20
  >>> from transformers import BertConfig, BertModel
21
 
22
  >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
23
- >>> configuration = CustomBertConfig()
24
 
25
  >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
26
  >>> model = BertForPunctuation(configuration)
@@ -29,15 +29,13 @@ class CustomBertConfig(BertConfig):
29
  >>> configuration = model.config
30
  ```"""
31
 
32
- model_type = "custom_bert"
33
-
34
  def __init__(
35
- self,
36
- backward_context=15,
37
- forward_context=16,
38
- output_size=4,
39
- dropout=0.3,
40
- **kwargs,
41
  ):
42
  super().__init__(**kwargs)
43
  self.backward_context = backward_context
 
1
  from transformers import BertConfig
2
 
3
 
4
+ class PunctuationBertConfig(BertConfig):
5
  r"""
6
+ This is the configuration class to store the configuration of a [`BertForPunctuation`]. It is based on BERT config
7
  to the specified arguments, defining the model architecture.
8
  Args:
9
  backward_context (`int`, *optional*, defaults to 15):
 
20
  >>> from transformers import BertConfig, BertModel
21
 
22
  >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
23
+ >>> configuration = PunctuationBertConfig()
24
 
25
  >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
26
  >>> model = BertForPunctuation(configuration)
 
29
  >>> configuration = model.config
30
  ```"""
31
 
 
 
32
  def __init__(
33
+ self,
34
+ backward_context=15,
35
+ forward_context=16,
36
+ output_size=4,
37
+ dropout=0.3,
38
+ **kwargs,
39
  ):
40
  super().__init__(**kwargs)
41
  self.backward_context = backward_context
src/inference.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Optional, Tuple
2
 
3
  import numpy as np
4
  import torch
@@ -8,10 +8,12 @@ from transformers import BertTokenizer
8
  from src.models import BertForPunctuation
9
 
10
  PUNCTUATION_SIGNS = ['', ',', '.', '?']
 
 
11
 
12
 
13
  def tokenize_text(
14
- word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
15
  ) -> Tuple[List[int], List[int], List[float]]:
16
  """
17
  Tokenizes text and generates pause list for each word
@@ -47,11 +49,10 @@ def tokenize_text(
47
 
48
 
49
  def gen_model_inputs(
50
- x: List[int],
51
- pause: List[float],
52
- forward_context: int,
53
- backward_context: int,
54
- pause_tokens: Optional[Dict[Tuple, int]] = None,
55
  ) -> torch.Tensor:
56
  """
57
  Generates inputs for model out of list of indexed words.
@@ -60,23 +61,17 @@ def gen_model_inputs(
60
  x: list of indexed words
61
  pause: list of corresponding pauses
62
  forward_context: size of the forward context window
63
- backward_context: size of the backward context window (without the pivot token)`
64
- pause_tokens: dictionary of pause ranges and corresponding tokens from bert tokenizer
65
 
66
  Returns:
67
  A tensor of model inputs for each indexed word in x
68
  """
69
- if pause_tokens is None:
70
- pause_tokens = {(-1000, 1000): 0}
71
  model_input = []
72
- tokenized_pause = []
73
  x_pad = [0] * backward_context + x + [0] * forward_context
74
 
75
- for i, p in enumerate(pause):
76
- tokenized_pause.append(next(value for key, value in pause_tokens.items() if key[0] < p <= key[1]))
77
-
78
  for i in range(len(x)):
79
- segment = x_pad[i:i + backward_context + forward_context + 1]
80
  segment.insert(backward_context + 1, tokenized_pause[i])
81
  model_input.append(segment)
82
  return torch.tensor(model_input)
@@ -109,16 +104,15 @@ def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
109
 
110
 
111
  def get_prediction(
112
- model: BertForPunctuation,
113
- text: str,
114
- tokenizer: BertTokenizer,
115
- batch_size: int = 16,
116
- backward_context: int = 15,
117
- forward_context: int = 16,
118
- pause_list: Optional[List[float]] = None,
119
- device: str = 'cpu',
120
- return_prob: bool = False,
121
- ):
122
  """
123
  Generates predictions for given list of words.
124
  Args:
@@ -130,18 +124,15 @@ def get_prediction(
130
  forward_context: size of the forward context window
131
  pause_list: list of pauses after each word in seconds
132
  device: device to run model on
133
- return_prob: if True returns probabilities, if False returns text with punctuation
134
 
135
  Returns:
136
- matrix of probabilities for each punctuation class or text with punctuation
137
  """
138
  word_list = text.split()
139
  if not pause_list:
140
  # make default pauses if pauses are not provided
141
  pause_list = [0.0] * len(word_list)
142
 
143
- # prepare text
144
- # we need original word idx since after tokenize number of tokens might not be equal to number of words
145
  word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)
146
 
147
  model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
@@ -151,35 +142,30 @@ def get_prediction(
151
  output = []
152
  with torch.no_grad():
153
  for ndx in range(0, inputs_length, batch_size):
154
- o = model(model_inputs[ndx: min(ndx + batch_size, inputs_length)])
155
  o = F.softmax(o, dim=1)
156
  output.append(o.cpu().data.numpy())
157
 
158
  punct_probabilities_matrix = np.concatenate(output, axis=0)
159
 
160
- if return_prob:
161
- return punct_probabilities_matrix
162
-
163
  punct_text = add_punctuation_to_text(text, punct_probabilities_matrix)
164
 
165
  return punct_text
166
 
167
 
168
  def main():
169
- model = BertForPunctuation.from_pretrained("verbit/hebrew_punctuation")
170
- tokenizer = BertTokenizer.from_pretrained("verbit/hebrew_punctuation")
171
  model.eval()
172
 
173
- text = ("חברת ורביט פיתחה מערכת לתמלול המבוססת על בינה מלאכותית וגורם אנושי ושוקדת על תמלול עדויות ניצולי שואה את "
174
- "התוצאות אפשר לראות כבר ברשת בהן חלקים מעדותו של טוביה ביילסקי שהיה מפקד גדוד הפרטיזנים היהודים "
175
- "בביילורוסיה")
176
  punct_text = get_prediction(
177
  model=model,
178
  text=text,
179
  tokenizer=tokenizer,
180
  backward_context=model.config.backward_context,
181
  forward_context=model.config.forward_context,
182
- return_prob=False
183
  )
184
  print(punct_text)
185
 
 
1
+ from typing import List, Optional, Tuple
2
 
3
  import numpy as np
4
  import torch
 
8
  from src.models import BertForPunctuation
9
 
10
  PUNCTUATION_SIGNS = ['', ',', '.', '?']
11
+ PAUSE_TOKEN = 0
12
+ MODEL_NAME = "verbit/hebrew_punctuation"
13
 
14
 
15
  def tokenize_text(
16
+ word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
17
  ) -> Tuple[List[int], List[int], List[float]]:
18
  """
19
  Tokenizes text and generates pause list for each word
 
49
 
50
 
51
  def gen_model_inputs(
52
+ x: List[int],
53
+ pause: List[float],
54
+ forward_context: int,
55
+ backward_context: int,
 
56
  ) -> torch.Tensor:
57
  """
58
  Generates inputs for model out of list of indexed words.
 
61
  x: list of indexed words
62
  pause: list of corresponding pauses
63
  forward_context: size of the forward context window
64
+ backward_context: size of the backward context window (without the predicted token)
 
65
 
66
  Returns:
67
  A tensor of model inputs for each indexed word in x
68
  """
 
 
69
  model_input = []
70
+ tokenized_pause = [PAUSE_TOKEN] * len(pause)
71
  x_pad = [0] * backward_context + x + [0] * forward_context
72
 
 
 
 
73
  for i in range(len(x)):
74
+ segment = x_pad[i : i + backward_context + forward_context + 1]
75
  segment.insert(backward_context + 1, tokenized_pause[i])
76
  model_input.append(segment)
77
  return torch.tensor(model_input)
 
104
 
105
 
106
  def get_prediction(
107
+ model: BertForPunctuation,
108
+ text: str,
109
+ tokenizer: BertTokenizer,
110
+ batch_size: int = 16,
111
+ backward_context: int = 15,
112
+ forward_context: int = 16,
113
+ pause_list: Optional[List[float]] = None,
114
+ device: str = 'cpu',
115
+ ) -> str:
 
116
  """
117
  Generates predictions for given list of words.
118
  Args:
 
124
  forward_context: size of the forward context window
125
  pause_list: list of pauses after each word in seconds
126
  device: device to run model on
 
127
 
128
  Returns:
129
+ text with punctuation
130
  """
131
  word_list = text.split()
132
  if not pause_list:
133
  # make default pauses if pauses are not provided
134
  pause_list = [0.0] * len(word_list)
135
 
 
 
136
  word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)
137
 
138
  model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
 
142
  output = []
143
  with torch.no_grad():
144
  for ndx in range(0, inputs_length, batch_size):
145
+ o = model(model_inputs[ndx : min(ndx + batch_size, inputs_length)])
146
  o = F.softmax(o, dim=1)
147
  output.append(o.cpu().data.numpy())
148
 
149
  punct_probabilities_matrix = np.concatenate(output, axis=0)
150
 
 
 
 
151
  punct_text = add_punctuation_to_text(text, punct_probabilities_matrix)
152
 
153
  return punct_text
154
 
155
 
156
  def main():
157
+ model = BertForPunctuation.from_pretrained(MODEL_NAME)
158
+ tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
159
  model.eval()
160
 
161
+ text = """חברת ורביט פיתחה מערכת לתמלול המבוססת על בינה מלאכותית וגורם אנושי ושוקדת על תמלול עדויות ניצולי שואה
162
+ את התוצאות אפשר לראות כבר ברשת בהן חלקים מעדותו של טוביה ביילסקי שהיה מפקד גדוד הפרטיזנים היהודים בביילורוסיה"""
 
163
  punct_text = get_prediction(
164
  model=model,
165
  text=text,
166
  tokenizer=tokenizer,
167
  backward_context=model.config.backward_context,
168
  forward_context=model.config.forward_context,
 
169
  )
170
  print(punct_text)
171
 
src/models.py CHANGED
@@ -1,15 +1,15 @@
1
  from torch import nn
2
  from transformers import BertForMaskedLM, PreTrainedModel
3
 
4
- from src.config import CustomBertConfig
5
 
6
 
7
  class BertForPunctuation(PreTrainedModel):
8
- config_class = CustomBertConfig
9
 
10
  def __init__(self, config):
11
  super().__init__(config)
12
- # backward_context + forward_context + pivot token + pause token
13
  segment_size = config.backward_context + config.forward_context + 2
14
  bert_vocab_size = config.vocab_size
15
  self.bert = BertForMaskedLM(config)
 
1
  from torch import nn
2
  from transformers import BertForMaskedLM, PreTrainedModel
3
 
4
+ from src.config import PunctuationBertConfig
5
 
6
 
7
  class BertForPunctuation(PreTrainedModel):
8
+ config_class = PunctuationBertConfig
9
 
10
  def __init__(self, config):
11
  super().__init__(config)
12
+ # segment_size equals backward_context + forward_context + predicted token + pause token
13
  segment_size = config.backward_context + config.forward_context + 2
14
  bert_vocab_size = config.vocab_size
15
  self.bert = BertForMaskedLM(config)