nirraviv89 committed on
Commit
6f8733e
1 Parent(s): 5d70d88

add model src code

Browse files
Files changed (4) hide show
  1. requirements.txt +3 -0
  2. src/config.py +46 -0
  3. src/inference.py +188 -0
  4. src/models.py +24 -0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy==1.23.5
2
+ torch==2.2.2
3
+ transformers==4.44.2
src/config.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig
2
+
3
+
4
class CustomBertConfig(BertConfig):
    r"""
    Configuration class for a [`BertForPunctuation`] model.

    It extends the regular BERT configuration with the punctuation-specific settings listed below; every other
    keyword argument is forwarded unchanged to [`BertConfig`].

    Args:
        backward_context (`int`, *optional*, defaults to 15):
            size of the backward context window
        forward_context (`int`, *optional*, defaults to 16):
            size of the forward context window
        output_size (`int`, *optional*, defaults to 4):
            number of punctuation classes
        dropout (`float`, *optional*, defaults to 0.3):
            dropout rate

    Examples:
    ```python
    >>> from src.config import CustomBertConfig
    >>> from src.models import BertForPunctuation

    >>> # Initializing a configuration with the default punctuation settings
    >>> configuration = CustomBertConfig()

    >>> # Initializing a model (with random weights) from that configuration
    >>> model = BertForPunctuation(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "custom_bert"

    def __init__(self, backward_context=15, forward_context=16, output_size=4, dropout=0.3, **kwargs):
        # Let the base BERT config consume all standard arguments first.
        super().__init__(**kwargs)
        # Context window sizes around the pivot token.
        self.backward_context, self.forward_context = backward_context, forward_context
        # Classifier settings: number of punctuation classes and dropout rate.
        self.output_size, self.dropout = output_size, dropout
src/inference.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import BertTokenizer
7
+
8
+ from src.models import BertForPunctuation
9
+
10
+ PUNCTUATION_SIGNS = ['', ',', '.', '?']  # position i is the sign appended for predicted class i; '' appends nothing
11
+
12
+
13
def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Converts words into token ids and spreads the per-word pauses over the resulting tokens.

    A word may be split into several tokens. In that case only its last token carries the word's pause
    (the preceding tokens get a pause of 0), and the position of that last token is recorded so that
    token-level results can be mapped back to words.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds
        tokenizer: tokenizer

    Returns:
        original_word_idx: for every word, the index (in ``x``) of its last token
        x: list of token ids
        pause: list of pauses, one per token
    """
    assert len(word_list) == len(pause_list), "word_list and pause_list should have the same length"

    token_ids = []
    token_pauses = []
    last_token_positions = []

    for word, word_pause in zip(word_list, pause_list):
        pieces = tokenizer.tokenize(word)
        # a word that produces no tokens is represented by a single id 0 so it still occupies one slot
        ids = tokenizer.convert_tokens_to_ids(pieces) if pieces else [0]

        token_ids.extend(ids)
        last_token_positions.append(len(token_ids) - 1)

        # zero pause for every leading sub-token, the real pause on the last one
        token_pauses.extend([0] * (len(ids) - 1))
        token_pauses.append(word_pause)

    return last_token_positions, token_ids, token_pauses
47
+
48
+
49
def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
    pause_tokens: Optional[Dict[Tuple, int]] = None,
) -> torch.Tensor:
    """
    Generates inputs for model out of list of indexed words.

    For every position ``i`` of ``x`` a segment is built from ``backward_context`` ids before the pivot,
    the pivot id itself and ``forward_context`` ids after it (positions outside ``x`` are filled with 0).
    The token that encodes the pause at position ``i`` is inserted right after the pivot.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses (at least one entry per element of ``x``)
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the pivot token)
        pause_tokens: dictionary mapping ``(low, high)`` pause ranges to a token id; a pause ``p`` belongs to
            a range when ``low < p <= high``. Defaults to a single range ``(-1000, 1000)`` mapped to 0.

    Returns:
        An integer tensor of shape ``(len(x), backward_context + forward_context + 2)``; for empty ``x`` the
        tensor has zero rows but keeps the same number of columns.

    Raises:
        ValueError: if ``pause`` has fewer entries than ``x`` or a pause is not covered by any range.
    """
    if len(pause) < len(x):
        raise ValueError(f"pause has {len(pause)} entries but x has {len(x)}; one pause per token is required")
    if pause_tokens is None:
        pause_tokens = {(-1000, 1000): 0}

    window = backward_context + forward_context + 1
    if not x:
        # torch.tensor([]) would be a 1-D float tensor; keep a well-formed 2-D integer tensor instead
        return torch.empty((0, window + 1), dtype=torch.long)

    x_pad = [0] * backward_context + list(x) + [0] * forward_context

    model_input = []
    for i in range(len(x)):
        p = pause[i]
        for key, value in pause_tokens.items():
            if key[0] < p <= key[1]:
                pause_token = value
                break
        else:
            # previously surfaced as a bare StopIteration from next()
            raise ValueError(f"pause {p!r} at position {i} is not covered by any range in pause_tokens")

        segment = x_pad[i:i + window]
        # the pivot sits at index `backward_context`; the pause token goes directly after it
        segment.insert(backward_context + 1, pause_token)
        model_input.append(segment)

    return torch.tensor(model_input, dtype=torch.long)
83
+
84
+
85
def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Appends the most probable punctuation sign to every word of the text.

    Args:
        text: text to insert punctuation to
        punct_prob: matrix of probabilities, one row per word and one column per punctuation class

    Returns:
        text with punctuation, words joined by single spaces
    """
    # most probable class per row, translated into its punctuation sign
    best_classes = np.argmax(punct_prob, axis=1)
    signs = [PUNCTUATION_SIGNS[class_idx] for class_idx in best_classes]

    # the "no punctuation" sign is the empty string, so plain concatenation leaves such words untouched
    return ' '.join(word + sign for word, sign in zip(text.split(), signs))
109
+
110
+
111
def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
    return_prob: bool = False,
):
    """
    Generates punctuation predictions for the words of a text.

    The text is split on whitespace, one model input segment is built per word (taken at the word's last
    token) and the model is run over these segments in batches.

    Note: only the inputs are moved to ``device``; the model is called as passed in, so the caller is
    responsible for placing it on the same device and for switching it to eval mode.

    Args:
        model: punctuation model
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: batch size (must be positive)
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds; zeros are used when omitted or empty
        device: device to run model on
        return_prob: if True returns probabilities, if False returns text with punctuation

    Returns:
        matrix of probabilities for each punctuation class (one row per word) or text with punctuation.
        For a text without words this is an empty matrix with zero rows, or an empty string.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    word_list = text.split()
    if not word_list:
        # nothing to predict; without this guard np.concatenate would fail on an empty list of batches
        if return_prob:
            return np.zeros((0, len(PUNCTUATION_SIGNS)), dtype=np.float32)
        return ''

    if not pause_list:
        # make default pauses if pauses are not provided
        pause_list = [0.0] * len(word_list)

    # prepare text
    # we need original word idx since after tokenize number of tokens might not be equal to number of words
    word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)

    model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
    # keep only the segment whose pivot is the last token of each word
    model_inputs = model_inputs.index_select(0, torch.tensor(word_idx, dtype=torch.long)).to(device)

    output = []
    with torch.no_grad():
        for start in range(0, len(model_inputs), batch_size):
            # slicing past the end is clamped, so the last batch may simply be shorter
            logits = model(model_inputs[start:start + batch_size])
            output.append(F.softmax(logits, dim=1).cpu().numpy())

    punct_probabilities_matrix = np.concatenate(output, axis=0)

    if return_prob:
        return punct_probabilities_matrix

    return add_punctuation_to_text(text, punct_probabilities_matrix)
166
+
167
+
168
def main():
    """Loads the pretrained Hebrew punctuation model and prints a punctuated version of a sample text."""
    model = BertForPunctuation.from_pretrained("verbit/hebrew_punctuation")
    tokenizer = BertTokenizer.from_pretrained("verbit/hebrew_punctuation")
    # inference only: switch the model to evaluation mode
    model.eval()

    # unpunctuated Hebrew sample text (restored from a mis-decoded, garbled copy of the literal)
    text = ("חברת ורביט פיתחה מערכת לתמלול המבוססת על בינה מלאכותית וגורם אנושי ושוקדת על תמלול עדויות ניצולי שואה את "
            "התוצאות אפשר לראות כבר ברשת בהן חלקים מעדותו של טוביה ביילסקי שהיה מפקד גדוד הפרטיזנים היהודים "
            "בביילורוסיה")
    punct_text = get_prediction(
        model=model,
        text=text,
        tokenizer=tokenizer,
        # use the context sizes the checkpoint was configured with
        backward_context=model.config.backward_context,
        forward_context=model.config.forward_context,
        return_prob=False
    )
    print(punct_text)


if __name__ == "__main__":
    main()
src/models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from transformers import BertForMaskedLM, PreTrainedModel
3
+
4
+ from src.config import CustomBertConfig
5
+
6
+
7
class BertForPunctuation(PreTrainedModel):
    """
    Punctuation classifier built on top of a masked-LM BERT.

    The first output of the BERT masked-LM model is flattened per sample, batch-normalized, passed through
    dropout and projected by a linear layer to ``config.output_size`` scores.
    """

    config_class = CustomBertConfig

    def __init__(self, config):
        super().__init__(config)
        # one position per backward-context token and forward-context token, plus the pivot and pause tokens
        tokens_per_segment = config.backward_context + config.forward_context + 2
        # the flattened BERT output is sized as one vocabulary-wide vector per segment position
        flat_features = tokens_per_segment * config.vocab_size

        self.bert = BertForMaskedLM(config)
        self.bn = nn.BatchNorm1d(flat_features)
        self.fc = nn.Linear(flat_features, config.output_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Returns the unnormalized class scores for a batch of input segments ``x``."""
        bert_output = self.bert(x)[0]
        # collapse everything except the batch dimension into a single feature vector
        flattened = bert_output.view(bert_output.shape[0], -1)
        normalized = self.bn(flattened)
        return self.fc(self.dropout(normalized))