anuragshas commited on
Commit
f2874d4
·
1 Parent(s): da25b85

Initial Commit

Browse files
app.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from xlit_src import XlitEngine
3
+
4
+
5
+ def transliterate(input_text):
6
+ engine = XlitEngine()
7
+ result = engine.translit_sentence(input_text)
8
+ return result
9
+
10
+
11
+ input_box = gr.inputs.Textbox(type="str", label="Input Text")
12
+ target = gr.outputs.Textbox()
13
+
14
+ iface = gr.Interface(
15
+ transliterate,
16
+ input_box,
17
+ target,
18
+ title="English to Hindi Transliteration",
19
+ description='Model for Translating English to Hindi using a Character-level recurrent sequence-to-sequence trained with <a href="http://workshop.colips.org/news2018/dataset.html">NEWS2018 DATASET_04</a>',
20
+ article='Author: <a href="https://huggingface.co/anuragshas">Anurag Singh</a> . Using training and inference script from <a href="https://github.com/AI4Bharat/IndianNLP-Transliteration.git">AI4Bharat/IndianNLP-Transliteration</a>.',
21
+ examples=["Hi.", "Wait!", "Namaste"],
22
+ )
23
+
24
+ iface.launch(enable_queue=True)
models/default_lineup.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hi": {
3
+ "name" : "Hindi - हिंदी",
4
+ "eng_name": "hindi",
5
+ "script" : "hindi/hi_scripts.json",
6
+ "weight" : "hindi/hi_v1_model.pth"
7
+ }
8
+ }
models/hindi/hi_scripts.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "WARNING" : " !!! Do not modify the Order of Glyph List !!!",
3
+ "UNICODE" : {"name": "devanagari", "begin":2304, "end":2431},
4
+ "LANGUAGE": "hindi",
5
+
6
+ "glyphs" : [
7
+
8
+ "ऄ", "अ", "आ", "इ", "ई", "उ", "ऊ","ऍ", "ऎ", "ए", "ऐ",
9
+ "ऑ", "ऒ", "ओ", "औ","ऋ","ॠ","ऌ","ॡ","ॲ", "ॐ",
10
+ "क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ", "ट", "ठ", "ड", "ढ", "ण",
11
+ "त", "थ", "द", "ध", "न", "ऩ", "प", "फ", "ब", "भ", "म", "य", "र", "ऱ", "ल",
12
+ "ळ", "ऴ", "व", "श", "ष", "स", "ह", "क़", "ख़", "ग़", "ज़", "ड़", "ढ़", "फ़", "य़",
13
+ "्", "ा", "ि", "ी", "ु", "ू", "ॅ", "ॆ", "े", "ै", "ॉ", "ॊ", "ो", "ौ",
14
+ "ृ", "ॄ", "ॢ", "ॣ", "ँ", "ं", "ः", "़", "॑", "ऽ", "॥",
15
+ "\u200c", "\u200d"
16
+
17
+ ],
18
+
19
+ "numsym_map" : {
20
+ "0" : ["०"], "1" : ["१"], "2" : ["२"], "3" : ["३"], "4" : ["४"],
21
+ "5" : ["५"], "6" : ["६"], "7" : ["७"], "8" : ["८"], "9" : ["९"],
22
+ "." : ["।", "॰"]
23
+ }
24
+
25
+ }
models/hindi/hi_v1_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cca1ea5d19fd507934e175eba7868f02a71826a046345fa6f4fccc3058424881
3
+ size 40927419
models/hindi/hi_v2_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89d3dd4e5fa7ea355c194fce3ecce1fd5e953e08784db26cacbe5993d1cd4eae
3
+ size 40927419
xlit_src.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import random
5
+ import enum
6
+ import traceback
7
+
8
+ import os
9
+ import sys
10
+ import json
11
+
12
+ F_DIR = os.path.dirname(os.path.realpath(__file__))
13
+
14
+
15
+ class XlitError(enum.Enum):
16
+ lang_err = "Unsupported langauge ID requested ;( Please check available languages."
17
+ string_err = "String passed is incompatable ;("
18
+ internal_err = "Internal crash ;("
19
+ unknown_err = "Unknown Failure"
20
+ loading_err = "Loading failed ;( Check if metadata/paths are correctly configured."
21
+
22
+
23
+ class Encoder(nn.Module):
24
+ """
25
+ Simple RNN based encoder network
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ input_dim,
31
+ embed_dim,
32
+ hidden_dim,
33
+ rnn_type="gru",
34
+ layers=1,
35
+ bidirectional=False,
36
+ dropout=0,
37
+ device="cpu",
38
+ ):
39
+ super(Encoder, self).__init__()
40
+
41
+ self.input_dim = input_dim # src_vocab_sz
42
+ self.enc_embed_dim = embed_dim
43
+ self.enc_hidden_dim = hidden_dim
44
+ self.enc_rnn_type = rnn_type
45
+ self.enc_layers = layers
46
+ self.enc_directions = 2 if bidirectional else 1
47
+ self.device = device
48
+
49
+ self.embedding = nn.Embedding(self.input_dim, self.enc_embed_dim)
50
+
51
+ if self.enc_rnn_type == "gru":
52
+ self.enc_rnn = nn.GRU(
53
+ input_size=self.enc_embed_dim,
54
+ hidden_size=self.enc_hidden_dim,
55
+ num_layers=self.enc_layers,
56
+ bidirectional=bidirectional,
57
+ )
58
+ elif self.enc_rnn_type == "lstm":
59
+ self.enc_rnn = nn.LSTM(
60
+ input_size=self.enc_embed_dim,
61
+ hidden_size=self.enc_hidden_dim,
62
+ num_layers=self.enc_layers,
63
+ bidirectional=bidirectional,
64
+ )
65
+ else:
66
+ raise Exception("unknown RNN type mentioned")
67
+
68
+ def forward(self, x, x_sz, hidden=None):
69
+ """
70
+ x_sz: (batch_size, 1) - Unpadded sequence lengths used for pack_pad
71
+
72
+ Return:
73
+ output: (batch_size, max_length, hidden_dim)
74
+ hidden: (n_layer*num_directions, batch_size, hidden_dim) | if LSTM tuple -(h_n, c_n)
75
+
76
+ """
77
+ batch_sz = x.shape[0]
78
+ # x: batch_size, max_length, enc_embed_dim
79
+ x = self.embedding(x)
80
+
81
+ ## pack the padded data
82
+ # x: max_length, batch_size, enc_embed_dim -> for pack_pad
83
+ x = x.permute(1, 0, 2)
84
+ x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False) # unpad
85
+
86
+ # output: packed_size, batch_size, enc_embed_dim --> hidden from all timesteps
87
+ # hidden: n_layer**num_directions, batch_size, hidden_dim | if LSTM (h_n, c_n)
88
+ output, hidden = self.enc_rnn(x)
89
+
90
+ ## pad the sequence to the max length in the batch
91
+ # output: max_length, batch_size, enc_emb_dim*directions)
92
+ output, _ = nn.utils.rnn.pad_packed_sequence(output)
93
+
94
+ # output: batch_size, max_length, hidden_dim
95
+ output = output.permute(1, 0, 2)
96
+
97
+ return output, hidden
98
+
99
+
100
+ class Decoder(nn.Module):
101
+ """
102
+ Used as decoder stage
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ output_dim,
108
+ embed_dim,
109
+ hidden_dim,
110
+ rnn_type="gru",
111
+ layers=1,
112
+ use_attention=True,
113
+ enc_outstate_dim=None, # enc_directions * enc_hidden_dim
114
+ dropout=0,
115
+ device="cpu",
116
+ ):
117
+ super(Decoder, self).__init__()
118
+
119
+ self.output_dim = output_dim # tgt_vocab_sz
120
+ self.dec_hidden_dim = hidden_dim
121
+ self.dec_embed_dim = embed_dim
122
+ self.dec_rnn_type = rnn_type
123
+ self.dec_layers = layers
124
+ self.use_attention = use_attention
125
+ self.device = device
126
+ if self.use_attention:
127
+ self.enc_outstate_dim = enc_outstate_dim if enc_outstate_dim else hidden_dim
128
+ else:
129
+ self.enc_outstate_dim = 0
130
+
131
+ self.embedding = nn.Embedding(self.output_dim, self.dec_embed_dim)
132
+
133
+ if self.dec_rnn_type == "gru":
134
+ self.dec_rnn = nn.GRU(
135
+ input_size=self.dec_embed_dim
136
+ + self.enc_outstate_dim, # to concat attention_output
137
+ hidden_size=self.dec_hidden_dim, # previous Hidden
138
+ num_layers=self.dec_layers,
139
+ batch_first=True,
140
+ )
141
+ elif self.dec_rnn_type == "lstm":
142
+ self.dec_rnn = nn.LSTM(
143
+ input_size=self.dec_embed_dim
144
+ + self.enc_outstate_dim, # to concat attention_output
145
+ hidden_size=self.dec_hidden_dim, # previous Hidden
146
+ num_layers=self.dec_layers,
147
+ batch_first=True,
148
+ )
149
+ else:
150
+ raise Exception("unknown RNN type mentioned")
151
+
152
+ self.fc = nn.Sequential(
153
+ nn.Linear(self.dec_hidden_dim, self.dec_embed_dim),
154
+ nn.LeakyReLU(),
155
+ # nn.Linear(self.dec_embed_dim, self.dec_embed_dim), nn.LeakyReLU(), # removing to reduce size
156
+ nn.Linear(self.dec_embed_dim, self.output_dim),
157
+ )
158
+
159
+ ##----- Attention ----------
160
+ if self.use_attention:
161
+ self.W1 = nn.Linear(self.enc_outstate_dim, self.dec_hidden_dim)
162
+ self.W2 = nn.Linear(self.dec_hidden_dim, self.dec_hidden_dim)
163
+ self.V = nn.Linear(self.dec_hidden_dim, 1)
164
+
165
+ def attention(self, x, hidden, enc_output):
166
+ """
167
+ x: (batch_size, 1, dec_embed_dim) -> after Embedding
168
+ enc_output: batch_size, max_length, enc_hidden_dim *num_directions
169
+ hidden: n_layers, batch_size, hidden_size | if LSTM (h_n, c_n)
170
+ """
171
+
172
+ ## perform addition to calculate the score
173
+
174
+ # hidden_with_time_axis: batch_size, 1, hidden_dim
175
+ ## hidden_with_time_axis = hidden.permute(1, 0, 2) ## replaced with below 2lines
176
+ hidden_with_time_axis = torch.sum(hidden, axis=0)
177
+
178
+ hidden_with_time_axis = hidden_with_time_axis.unsqueeze(1)
179
+
180
+ # score: batch_size, max_length, hidden_dim
181
+ score = torch.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))
182
+
183
+ # attention_weights: batch_size, max_length, 1
184
+ # we get 1 at the last axis because we are applying score to self.V
185
+ attention_weights = torch.softmax(self.V(score), dim=1)
186
+
187
+ # context_vector shape after sum == (batch_size, hidden_dim)
188
+ context_vector = attention_weights * enc_output
189
+ context_vector = torch.sum(context_vector, dim=1)
190
+ # context_vector: batch_size, 1, hidden_dim
191
+ context_vector = context_vector.unsqueeze(1)
192
+
193
+ # attend_out (batch_size, 1, dec_embed_dim + hidden_size)
194
+ attend_out = torch.cat((context_vector, x), -1)
195
+
196
+ return attend_out, attention_weights
197
+
198
+ def forward(self, x, hidden, enc_output):
199
+ """
200
+ x: (batch_size, 1)
201
+ enc_output: batch_size, max_length, dec_embed_dim
202
+ hidden: n_layer, batch_size, hidden_size | lstm: (h_n, c_n)
203
+ """
204
+ if (hidden is None) and (self.use_attention is False):
205
+ raise Exception("No use of a decoder with No attention and No Hidden")
206
+
207
+ batch_sz = x.shape[0]
208
+
209
+ if hidden is None:
210
+ # hidden: n_layers, batch_size, hidden_dim
211
+ hid_for_att = torch.zeros(
212
+ (self.dec_layers, batch_sz, self.dec_hidden_dim)
213
+ ).to(self.device)
214
+ elif self.dec_rnn_type == "lstm":
215
+ hid_for_att = hidden[0] # h_n
216
+ else:
217
+ hid_for_att = hidden
218
+
219
+ # x (batch_size, 1, dec_embed_dim) -> after embedding
220
+ x = self.embedding(x)
221
+
222
+ if self.use_attention:
223
+ # x (batch_size, 1, dec_embed_dim + hidden_size) -> after attention
224
+ # aw: (batch_size, max_length, 1)
225
+ x, aw = self.attention(x, hid_for_att, enc_output)
226
+ else:
227
+ x, aw = x, 0
228
+
229
+ # passing the concatenated vector to the GRU
230
+ # output: (batch_size, n_layers, hidden_size)
231
+ # hidden: n_layers, batch_size, hidden_size | if LSTM (h_n, c_n)
232
+ output, hidden = (
233
+ self.dec_rnn(x, hidden) if hidden is not None else self.dec_rnn(x)
234
+ )
235
+
236
+ # output :shp: (batch_size * 1, hidden_size)
237
+ output = output.view(-1, output.size(2))
238
+
239
+ # output :shp: (batch_size * 1, output_dim)
240
+ output = self.fc(output)
241
+
242
+ return output, hidden, aw
243
+
244
+
245
+ class Seq2Seq(nn.Module):
246
+ """
247
+ Used to construct seq2seq architecture with encoder decoder objects
248
+ """
249
+
250
+ def __init__(
251
+ self, encoder, decoder, pass_enc2dec_hid=False, dropout=0, device="cpu"
252
+ ):
253
+ super(Seq2Seq, self).__init__()
254
+
255
+ self.encoder = encoder
256
+ self.decoder = decoder
257
+ self.device = device
258
+ self.pass_enc2dec_hid = pass_enc2dec_hid
259
+
260
+ if self.pass_enc2dec_hid:
261
+ assert (
262
+ decoder.dec_hidden_dim == encoder.enc_hidden_dim
263
+ ), "Hidden Dimension of encoder and decoder must be same, or unset `pass_enc2dec_hid`"
264
+ if decoder.use_attention:
265
+ assert (
266
+ decoder.enc_outstate_dim
267
+ == encoder.enc_directions * encoder.enc_hidden_dim
268
+ ), "Set `enc_out_dim` correctly in decoder"
269
+ assert (
270
+ self.pass_enc2dec_hid or decoder.use_attention
271
+ ), "No use of a decoder with No attention and No Hidden from Encoder"
272
+
273
+ def forward(self, src, tgt, src_sz, teacher_forcing_ratio=0):
274
+ """
275
+ src: (batch_size, sequence_len.padded)
276
+ tgt: (batch_size, sequence_len.padded)
277
+ src_sz: [batch_size, 1] - Unpadded sequence lengths
278
+ """
279
+ batch_size = tgt.shape[0]
280
+
281
+ # enc_output: (batch_size, padded_seq_length, enc_hidden_dim*num_direction)
282
+ # enc_hidden: (enc_layers*num_direction, batch_size, hidden_dim)
283
+ enc_output, enc_hidden = self.encoder(src, src_sz)
284
+
285
+ if self.pass_enc2dec_hid:
286
+ # dec_hidden: dec_layers, batch_size , dec_hidden_dim
287
+ dec_hidden = enc_hidden
288
+ else:
289
+ # dec_hidden -> Will be initialized to zeros internally
290
+ dec_hidden = None
291
+
292
+ # pred_vecs: (batch_size, output_dim, sequence_sz) -> shape required for CELoss
293
+ pred_vecs = torch.zeros(batch_size, self.decoder.output_dim, tgt.size(1)).to(
294
+ self.device
295
+ )
296
+
297
+ # dec_input: (batch_size, 1)
298
+ dec_input = tgt[:, 0].unsqueeze(1) # initialize to start token
299
+ pred_vecs[:, 1, 0] = 1 # Initialize to start tokens all batches
300
+ for t in range(1, tgt.size(1)):
301
+ # dec_hidden: dec_layers, batch_size , dec_hidden_dim
302
+ # dec_output: batch_size, output_dim
303
+ # dec_input: (batch_size, 1)
304
+ dec_output, dec_hidden, _ = self.decoder(
305
+ dec_input,
306
+ dec_hidden,
307
+ enc_output,
308
+ )
309
+ pred_vecs[:, :, t] = dec_output
310
+
311
+ # # prediction: batch_size
312
+ prediction = torch.argmax(dec_output, dim=1)
313
+
314
+ # Teacher Forcing
315
+ if random.random() < teacher_forcing_ratio:
316
+ dec_input = tgt[:, t].unsqueeze(1)
317
+ else:
318
+ dec_input = prediction.unsqueeze(1)
319
+
320
+ return pred_vecs # (batch_size, output_dim, sequence_sz)
321
+
322
+ def inference(self, src, max_tgt_sz=50, debug=0):
323
+ """
324
+ single input only, No batch Inferencing
325
+ src: (sequence_len)
326
+ debug: if True will return attention weights also
327
+ """
328
+ batch_size = 1
329
+ start_tok = src[0]
330
+ end_tok = src[-1]
331
+ src_sz = torch.tensor([len(src)])
332
+ src_ = src.unsqueeze(0)
333
+
334
+ # enc_output: (batch_size, padded_seq_length, enc_hidden_dim*num_direction)
335
+ # enc_hidden: (enc_layers*num_direction, batch_size, hidden_dim)
336
+ enc_output, enc_hidden = self.encoder(src_, src_sz)
337
+
338
+ if self.pass_enc2dec_hid:
339
+ # dec_hidden: dec_layers, batch_size , dec_hidden_dim
340
+ dec_hidden = enc_hidden
341
+ else:
342
+ # dec_hidden -> Will be initialized to zeros internally
343
+ dec_hidden = None
344
+
345
+ # pred_arr: (sequence_sz, 1) -> shape required for CELoss
346
+ pred_arr = torch.zeros(max_tgt_sz, 1).to(self.device)
347
+ if debug:
348
+ attend_weight_arr = torch.zeros(max_tgt_sz, len(src)).to(self.device)
349
+
350
+ # dec_input: (batch_size, 1)
351
+ dec_input = start_tok.view(1, 1) # initialize to start token
352
+ pred_arr[0] = start_tok.view(1, 1) # initialize to start token
353
+ for t in range(max_tgt_sz):
354
+ # dec_hidden: dec_layers, batch_size , dec_hidden_dim
355
+ # dec_output: batch_size, output_dim
356
+ # dec_input: (batch_size, 1)
357
+ dec_output, dec_hidden, aw = self.decoder(
358
+ dec_input,
359
+ dec_hidden,
360
+ enc_output,
361
+ )
362
+ # prediction :shp: (1,1)
363
+ prediction = torch.argmax(dec_output, dim=1)
364
+ dec_input = prediction.unsqueeze(1)
365
+ pred_arr[t] = prediction
366
+ if debug:
367
+ attend_weight_arr[t] = aw.squeeze(-1)
368
+
369
+ if torch.eq(prediction, end_tok):
370
+ break
371
+
372
+ if debug:
373
+ return pred_arr.squeeze(), attend_weight_arr
374
+ # pred_arr :shp: (sequence_len)
375
+ return pred_arr.squeeze().to(dtype=torch.long)
376
+
377
+ def active_beam_inference(self, src, beam_width=3, max_tgt_sz=50):
378
+ """Active beam Search based decoding
379
+ src: (sequence_len)
380
+ """
381
+
382
+ def _avg_score(p_tup):
383
+ """Used for Sorting
384
+ TODO: Dividing by length of sequence power alpha as hyperparam
385
+ """
386
+ return p_tup[0]
387
+
388
+ batch_size = 1
389
+ start_tok = src[0]
390
+ end_tok = src[-1]
391
+ src_sz = torch.tensor([len(src)])
392
+ src_ = src.unsqueeze(0)
393
+
394
+ # enc_output: (batch_size, padded_seq_length, enc_hidden_dim*num_direction)
395
+ # enc_hidden: (enc_layers*num_direction, batch_size, hidden_dim)
396
+ enc_output, enc_hidden = self.encoder(src_, src_sz)
397
+
398
+ if self.pass_enc2dec_hid:
399
+ # dec_hidden: dec_layers, batch_size , dec_hidden_dim
400
+ init_dec_hidden = enc_hidden
401
+ else:
402
+ # dec_hidden -> Will be initialized to zeros internally
403
+ init_dec_hidden = None
404
+
405
+ # top_pred[][0] = Σ-log_softmax
406
+ # top_pred[][1] = sequence torch.tensor shape: (1)
407
+ # top_pred[][2] = dec_hidden
408
+ top_pred_list = [(0, start_tok.unsqueeze(0), init_dec_hidden)]
409
+
410
+ for t in range(max_tgt_sz):
411
+ cur_pred_list = []
412
+
413
+ for p_tup in top_pred_list:
414
+ if p_tup[1][-1] == end_tok:
415
+ cur_pred_list.append(p_tup)
416
+ continue
417
+
418
+ # dec_hidden: dec_layers, 1, hidden_dim
419
+ # dec_output: 1, output_dim
420
+ dec_output, dec_hidden, _ = self.decoder(
421
+ x=p_tup[1][-1].view(1, 1), # dec_input: (1,1)
422
+ hidden=p_tup[2],
423
+ enc_output=enc_output,
424
+ )
425
+
426
+ ## π{prob} = Σ{log(prob)} -> to prevent diminishing
427
+ # dec_output: (1, output_dim)
428
+ dec_output = nn.functional.log_softmax(dec_output, dim=1)
429
+ # pred_topk.values & pred_topk.indices: (1, beam_width)
430
+ pred_topk = torch.topk(dec_output, k=beam_width, dim=1)
431
+
432
+ for i in range(beam_width):
433
+ sig_logsmx_ = p_tup[0] + pred_topk.values[0][i]
434
+ # seq_tensor_ : (seq_len)
435
+ seq_tensor_ = torch.cat((p_tup[1], pred_topk.indices[0][i].view(1)))
436
+
437
+ cur_pred_list.append((sig_logsmx_, seq_tensor_, dec_hidden))
438
+
439
+ cur_pred_list.sort(key=_avg_score, reverse=True) # Maximized order
440
+ top_pred_list = cur_pred_list[:beam_width]
441
+
442
+ # check if end_tok of all topk
443
+ end_flags_ = [1 if t[1][-1] == end_tok else 0 for t in top_pred_list]
444
+ if beam_width == sum(end_flags_):
445
+ break
446
+
447
+ pred_tnsr_list = [t[1] for t in top_pred_list]
448
+
449
+ return pred_tnsr_list
450
+
451
+ def passive_beam_inference(self, src, beam_width=7, max_tgt_sz=50):
452
+ """
453
+ Passive Beam search based inference
454
+ src: (sequence_len)
455
+ """
456
+
457
+ def _avg_score(p_tup):
458
+ """Used for Sorting
459
+ TODO: Dividing by length of sequence power alpha as hyperparam
460
+ """
461
+ return p_tup[0]
462
+
463
+ def _beam_search_topk(topk_obj, start_tok, beam_width):
464
+ """search for sequence with maxim prob
465
+ topk_obj[x]: .values & .indices shape:(1, beam_width)
466
+ """
467
+ # top_pred_list[x]: tuple(prob, seq_tensor)
468
+ top_pred_list = [
469
+ (0, start_tok.unsqueeze(0)),
470
+ ]
471
+
472
+ for obj in topk_obj:
473
+ new_lst_ = list()
474
+ for itm in top_pred_list:
475
+ for i in range(beam_width):
476
+ sig_logsmx_ = itm[0] + obj.values[0][i]
477
+ seq_tensor_ = torch.cat((itm[1], obj.indices[0][i].view(1)))
478
+ new_lst_.append((sig_logsmx_, seq_tensor_))
479
+
480
+ new_lst_.sort(key=_avg_score, reverse=True)
481
+ top_pred_list = new_lst_[:beam_width]
482
+ return top_pred_list
483
+
484
+ batch_size = 1
485
+ start_tok = src[0]
486
+ end_tok = src[-1]
487
+ src_sz = torch.tensor([len(src)])
488
+ src_ = src.unsqueeze(0)
489
+
490
+ enc_output, enc_hidden = self.encoder(src_, src_sz)
491
+
492
+ if self.pass_enc2dec_hid:
493
+ # dec_hidden: dec_layers, batch_size , dec_hidden_dim
494
+ dec_hidden = enc_hidden
495
+ else:
496
+ # dec_hidden -> Will be initialized to zeros internally
497
+ dec_hidden = None
498
+
499
+ # dec_input: (1, 1)
500
+ dec_input = start_tok.view(1, 1) # initialize to start token
501
+
502
+ topk_obj = []
503
+ for t in range(max_tgt_sz):
504
+ dec_output, dec_hidden, aw = self.decoder(
505
+ dec_input,
506
+ dec_hidden,
507
+ enc_output,
508
+ )
509
+
510
+ ## π{prob} = Σ{log(prob)} -> to prevent diminishing
511
+ # dec_output: (1, output_dim)
512
+ dec_output = nn.functional.log_softmax(dec_output, dim=1)
513
+ # pred_topk.values & pred_topk.indices: (1, beam_width)
514
+ pred_topk = torch.topk(dec_output, k=beam_width, dim=1)
515
+
516
+ topk_obj.append(pred_topk)
517
+
518
+ # dec_input: (1, 1)
519
+ dec_input = pred_topk.indices[0][0].view(1, 1)
520
+ if torch.eq(dec_input, end_tok):
521
+ break
522
+
523
+ top_pred_list = _beam_search_topk(topk_obj, start_tok, beam_width)
524
+ pred_tnsr_list = [t[1] for t in top_pred_list]
525
+
526
+ return pred_tnsr_list
527
+
528
+
529
+ class GlyphStrawboss:
530
+ def __init__(self, glyphs="en"):
531
+ """list of letters in a language in unicode
532
+ lang: List with unicodes
533
+ """
534
+ if glyphs == "en":
535
+ # Smallcase alone
536
+ self.glyphs = [chr(alpha) for alpha in range(97, 123)] + ["é", "è", "á"]
537
+ else:
538
+ self.dossier = json.load(open(glyphs, encoding="utf-8"))
539
+ self.numsym_map = self.dossier["numsym_map"]
540
+ self.glyphs = self.dossier["glyphs"]
541
+
542
+ self.indoarab_num = [chr(alpha) for alpha in range(48, 58)]
543
+
544
+ self.char2idx = {}
545
+ self.idx2char = {}
546
+ self._create_index()
547
+
548
+ def _create_index(self):
549
+
550
+ self.char2idx["_"] = 0 # pad
551
+ self.char2idx["$"] = 1 # start
552
+ self.char2idx["#"] = 2 # end
553
+ self.char2idx["*"] = 3 # Mask
554
+ self.char2idx["'"] = 4 # apostrophe U+0027
555
+ self.char2idx["%"] = 5 # unused
556
+ self.char2idx["!"] = 6 # unused
557
+ self.char2idx["?"] = 7
558
+ self.char2idx[":"] = 8
559
+ self.char2idx[" "] = 9
560
+ self.char2idx["-"] = 10
561
+ self.char2idx[","] = 11
562
+ self.char2idx["."] = 12
563
+ self.char2idx["("] = 13
564
+ self.char2idx[")"] = 14
565
+ self.char2idx["/"] = 15
566
+ self.char2idx["^"] = 16
567
+
568
+ for idx, char in enumerate(self.indoarab_num):
569
+ self.char2idx[char] = idx + 17
570
+ # letter to index mapping
571
+ for idx, char in enumerate(self.glyphs):
572
+ self.char2idx[char] = idx + 27 # +20 token initially
573
+
574
+ # index to letter mapping
575
+ for char, idx in self.char2idx.items():
576
+ self.idx2char[idx] = char
577
+
578
+ def size(self):
579
+ return len(self.char2idx)
580
+
581
+ def word2xlitvec(self, word):
582
+ """Converts given string of gyphs(word) to vector(numpy)
583
+ Also adds tokens for start and end
584
+ """
585
+ try:
586
+ vec = [self.char2idx["$"]] # start token
587
+ for i in list(word):
588
+ vec.append(self.char2idx[i])
589
+ vec.append(self.char2idx["#"]) # end token
590
+
591
+ vec = np.asarray(vec, dtype=np.int64)
592
+ return vec
593
+
594
+ except Exception as error:
595
+ print("Error In word:", word, "Error Char not in Token:", error)
596
+ sys.exit()
597
+
598
+ def xlitvec2word(self, vector):
599
+ """Converts vector(numpy) to string of glyphs(word)"""
600
+ char_list = []
601
+ for i in vector:
602
+ char_list.append(self.idx2char[i])
603
+
604
+ word = "".join(char_list).replace("$", "").replace("#", "") # remove tokens
605
+ word = word.replace("_", "").replace("*", "") # remove tokens
606
+ return word
607
+
608
+
609
+ class XlitPiston:
610
+ """
611
+ For handling prediction & post-processing of transliteration for a single language
612
+ Class dependency: Seq2Seq, GlyphStrawboss
613
+ Global Variables: F_DIR
614
+ """
615
+
616
+ def __init__(
617
+ self, weight_path, tglyph_cfg_file, iglyph_cfg_file="en", device="cpu"
618
+ ):
619
+
620
+ self.device = device
621
+ self.in_glyph_obj = GlyphStrawboss(iglyph_cfg_file)
622
+ self.tgt_glyph_obj = GlyphStrawboss(glyphs=tglyph_cfg_file)
623
+
624
+ self._numsym_set = set(
625
+ json.load(open(tglyph_cfg_file, encoding="utf-8"))["numsym_map"].keys()
626
+ )
627
+ self._inchar_set = set("abcdefghijklmnopqrstuvwxyzéèá")
628
+ self._natscr_set = set().union(
629
+ self.tgt_glyph_obj.glyphs, sum(self.tgt_glyph_obj.numsym_map.values(), [])
630
+ )
631
+
632
+ ## Model Config Static TODO: add defining in json support
633
+ input_dim = self.in_glyph_obj.size()
634
+ output_dim = self.tgt_glyph_obj.size()
635
+ enc_emb_dim = 300
636
+ dec_emb_dim = 300
637
+ enc_hidden_dim = 512
638
+ dec_hidden_dim = 512
639
+ rnn_type = "lstm"
640
+ enc2dec_hid = True
641
+ attention = True
642
+ enc_layers = 1
643
+ dec_layers = 2
644
+ m_dropout = 0
645
+ enc_bidirect = True
646
+ enc_outstate_dim = enc_hidden_dim * (2 if enc_bidirect else 1)
647
+
648
+ enc = Encoder(
649
+ input_dim=input_dim,
650
+ embed_dim=enc_emb_dim,
651
+ hidden_dim=enc_hidden_dim,
652
+ rnn_type=rnn_type,
653
+ layers=enc_layers,
654
+ dropout=m_dropout,
655
+ device=self.device,
656
+ bidirectional=enc_bidirect,
657
+ )
658
+ dec = Decoder(
659
+ output_dim=output_dim,
660
+ embed_dim=dec_emb_dim,
661
+ hidden_dim=dec_hidden_dim,
662
+ rnn_type=rnn_type,
663
+ layers=dec_layers,
664
+ dropout=m_dropout,
665
+ use_attention=attention,
666
+ enc_outstate_dim=enc_outstate_dim,
667
+ device=self.device,
668
+ )
669
+ self.model = Seq2Seq(enc, dec, pass_enc2dec_hid=enc2dec_hid, device=self.device)
670
+ self.model = self.model.to(self.device)
671
+ weights = torch.load(weight_path, map_location=torch.device(self.device))
672
+
673
+ self.model.load_state_dict(weights)
674
+ self.model.eval()
675
+
676
+ def character_model(self, word, beam_width=1):
677
+ in_vec = torch.from_numpy(self.in_glyph_obj.word2xlitvec(word)).to(self.device)
678
+ ## change to active or passive beam
679
+ p_out_list = self.model.active_beam_inference(in_vec, beam_width=beam_width)
680
+ result = [
681
+ self.tgt_glyph_obj.xlitvec2word(out.cpu().numpy()) for out in p_out_list
682
+ ]
683
+
684
+ # List type
685
+ return result
686
+
687
+ def numsym_model(self, seg):
688
+ """tgt_glyph_obj.numsym_map[x] returns a list object"""
689
+ if len(seg) == 1:
690
+ return [seg] + self.tgt_glyph_obj.numsym_map[seg]
691
+
692
+ a = [self.tgt_glyph_obj.numsym_map[n][0] for n in seg]
693
+ return [seg] + ["".join(a)]
694
+
695
+ def _word_segementer(self, sequence):
696
+
697
+ sequence = sequence.lower()
698
+ accepted = set().union(self._numsym_set, self._inchar_set, self._natscr_set)
699
+ # sequence = ''.join([i for i in sequence if i in accepted])
700
+
701
+ segment = []
702
+ idx = 0
703
+ seq_ = list(sequence)
704
+ while len(seq_):
705
+ # for Number-Symbol
706
+ temp = ""
707
+ while len(seq_) and seq_[0] in self._numsym_set:
708
+ temp += seq_[0]
709
+ seq_.pop(0)
710
+ if temp != "":
711
+ segment.append(temp)
712
+
713
+ # for Target Chars
714
+ temp = ""
715
+ while len(seq_) and seq_[0] in self._natscr_set:
716
+ temp += seq_[0]
717
+ seq_.pop(0)
718
+ if temp != "":
719
+ segment.append(temp)
720
+
721
+ # for Input-Roman Chars
722
+ temp = ""
723
+ while len(seq_) and seq_[0] in self._inchar_set:
724
+ temp += seq_[0]
725
+ seq_.pop(0)
726
+ if temp != "":
727
+ segment.append(temp)
728
+
729
+ temp = ""
730
+ while len(seq_) and seq_[0] not in accepted:
731
+ temp += seq_[0]
732
+ seq_.pop(0)
733
+ if temp != "":
734
+ segment.append(temp)
735
+
736
+ return segment
737
+
738
+ def inferencer(self, sequence, beam_width=10):
739
+
740
+ seg = self._word_segementer(sequence[:120])
741
+ lit_seg = []
742
+
743
+ p = 0
744
+ while p < len(seg):
745
+ if seg[p][0] in self._natscr_set:
746
+ lit_seg.append([seg[p]])
747
+ p += 1
748
+
749
+ elif seg[p][0] in self._inchar_set:
750
+ lit_seg.append(self.character_model(seg[p], beam_width=beam_width))
751
+ p += 1
752
+
753
+ elif seg[p][0] in self._numsym_set: # num & punc
754
+ lit_seg.append(self.numsym_model(seg[p]))
755
+ p += 1
756
+ else:
757
+ lit_seg.append([seg[p]])
758
+ p += 1
759
+
760
+ ## IF segment less/equal to 2 then return combinotorial,
761
+ ## ELSE only return top1 of each result concatenated
762
+ if len(lit_seg) == 1:
763
+ final_result = lit_seg[0]
764
+
765
+ elif len(lit_seg) == 2:
766
+ final_result = [""]
767
+ for seg in lit_seg:
768
+ new_result = []
769
+ for s in seg:
770
+ for f in final_result:
771
+ new_result.append(f + s)
772
+ final_result = new_result
773
+
774
+ else:
775
+ new_result = []
776
+ for seg in lit_seg:
777
+ new_result.append(seg[0])
778
+ final_result = ["".join(new_result)]
779
+
780
+ return final_result
781
+
782
+
783
+ class XlitEngine:
784
+ """
785
+ For Managing the top level tasks and applications of transliteration
786
+ Global Variables: F_DIR
787
+ """
788
+
789
+ def __init__(self, lang2use="hi", config_path="models/default_lineup.json"):
790
+ lineup = json.load(open(os.path.join(F_DIR, config_path), encoding="utf-8"))
791
+ models_path = os.path.join(F_DIR, "models")
792
+ self.lang_config = {}
793
+ if lang2use in lineup:
794
+ self.lang_config[lang2use] = lineup[lang2use]
795
+ else:
796
+ raise Exception(
797
+ "XlitError: The entered Langauge code not found. Available are {}".format(
798
+ lineup.keys()
799
+ )
800
+ )
801
+ self.langs = {}
802
+ self.lang_model = {}
803
+ for la in self.lang_config:
804
+ try:
805
+ print("Loading {}...".format(la))
806
+ self.lang_model[la] = XlitPiston(
807
+ weight_path=os.path.join(
808
+ models_path, self.lang_config[la]["weight"]
809
+ ),
810
+ tglyph_cfg_file=os.path.join(
811
+ models_path, self.lang_config[la]["script"]
812
+ ),
813
+ iglyph_cfg_file="en",
814
+ )
815
+ self.langs[la] = self.lang_config[la]["name"]
816
+ except Exception as error:
817
+ print("XlitError: Failure in loading {} \n".format(la), error)
818
+ print(XlitError.loading_err.value)
819
+
820
+ def translit_word(self, eng_word, lang_code="hi", topk=7, beam_width=10):
821
+ if eng_word == "":
822
+ return []
823
+ if lang_code in self.langs:
824
+ try:
825
+ res_list = self.lang_model[lang_code].inferencer(
826
+ eng_word, beam_width=beam_width
827
+ )
828
+ return res_list[:topk]
829
+
830
+ except Exception as error:
831
+ print("XlitError:", traceback.format_exc())
832
+ print(XlitError.internal_err.value)
833
+ return XlitError.internal_err
834
+ else:
835
+ print("XlitError: Unknown Langauge requested", lang_code)
836
+ print(XlitError.lang_err.value)
837
+ return XlitError.lang_err
838
+
839
+ def translit_sentence(self, eng_sentence, lang_code="hi", beam_width=10):
840
+ if eng_sentence == "":
841
+ return []
842
+
843
+ if lang_code in self.langs:
844
+ try:
845
+ out_str = ""
846
+ for word in eng_sentence.split():
847
+ res_ = self.lang_model[lang_code].inferencer(
848
+ word, beam_width=beam_width
849
+ )
850
+ out_str = out_str + res_[0] + " "
851
+ return out_str[:-1]
852
+
853
+ except Exception as error:
854
+ print("XlitError:", traceback.format_exc())
855
+ print(XlitError.internal_err.value)
856
+ return XlitError.internal_err
857
+
858
+ else:
859
+ print("XlitError: Unknown Langauge requested", lang_code)
860
+ print(XlitError.lang_err.value)
861
+ return XlitError.lang_err
862
+
863
+
864
+ if __name__ == "__main__":
865
+
866
+ engine = XlitEngine()
867
+ y = engine.translit_sentence("Hello World !")
868
+ print(y)