kargaranamir commited on
Commit
2c9efe4
1 Parent(s): d8f31aa
Files changed (4) hide show
  1. README.md +5 -5
  2. app.py +78 -0
  3. masklid.py +268 -0
  4. requirements.txt +3 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: MaskLID
3
- emoji: 👁
4
- colorFrom: blue
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.36.1
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: MaskLID
3
+ emoji: 🐨
4
+ colorFrom: purple
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.36.1
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
13
+ This code applies [MaskLID](https://arxiv.org/abs/2406.06263) with [GlotLID](https://arxiv.org/abs/2310.16248), a fasttext-based language identification tool.
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Amir Hossein Kargaran
2
+ # Date: August, 2023
3
+
4
+ # Description: This code applies MaskLID (code-switch language identification) with GlotLID, a fastText-based language identification tool.
5
+
6
+ # MIT License
7
+
8
+ import gradio as gr
9
+ from masklid import MaskLID
10
+ from huggingface_hub import hf_hub_download
11
+ from fasttext.FastText import _FastText
12
+
13
+ def render_metadata():
14
+ """Renders the metadata."""
15
+ html_content = """
16
+ <p align="center">
17
+ <a href="https://github.com/cisnlp/MaskLID"><img alt="GitHub stars" src="https://img.shields.io/github/stars/cisnlp/MaskLID"></a>
18
+ This is the demo for <a href="https://arxiv.org/abs/2406.06263">MaskLID</a> paper (ACL 2024). You can see the whole code in our GitHub. Please also note that if you increase the number of languages, you also need larger alpha and beta values.
19
+ MaskLID does not add much overhead to language identification. You first fix the languages your model is limited to and then run the MaskLID code. However, in this demo, we load the model each time (that takes couple of seconds) you hit submit to ensure the results are not cached and to make it possible to change the set of languages each time. We may later change the demo code to resolve this.
20
+ </p>
21
+ """
22
+ return html_content
23
+
24
+
25
+ def get_model_path():
26
+ # Download GlotLID FastText language identification model from Hugging Face Hub
27
+ model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model_v3.bin")
28
+ return model_path
29
+
30
+
31
+ def get_masklid():
32
+ # load masklid model
33
+ masklid_model = MaskLID(get_model_path())
34
+
35
+ # get all the labels
36
+ labels = masklid_model.model.get_labels()
37
+ labels = [l for l in labels if not l.startswith('__label__und') and not l.startswith('__label__zxx')]
38
+
39
+ return masklid_model, labels
40
+
41
+ def predict_codeswitch(text, top_labels=200, beta=20, alpha=3, max_lambda=3, min_length=10, min_prob=0.90, max_retry=3, alpha_step_increase=3, beta_step_increase=5):
42
+
43
+ # constraints
44
+ beta = top_labels if beta > top_labels else beta
45
+ alpha = beta if alpha > beta else alpha
46
+
47
+ # override the masklid label set
48
+ masklid_model, labels = get_masklid()
49
+ masklid_model.language_indices = masklid_model._compute_language_indices(labels[:top_labels])
50
+ masklid_model.labels = [masklid_model.model.get_labels()[i] for i in masklid_model.language_indices]
51
+
52
+ ans = masklid_model.predict_codeswitch(text, beta=beta, alpha=alpha, max_lambda=max_lambda, min_length=min_length, min_prob=min_prob, max_retry=max_retry, alpha_step_increase=alpha_step_increase, beta_step_increase=beta_step_increase)
53
+
54
+ return ans
55
+
56
+ inputs = gr.Textbox(lines=2, label="Enter the text", value="bir kahve dükkanında geçen film tadında güzel bir şarkıya ayrılsın gece falling in love at a coffee shop")
57
+ parameters = {
58
+ "top_labels": gr.Slider(minimum=2, maximum=len(get_masklid()[1]), step=1, value=200, label="Limit LID to X Top Languages"),
59
+ "beta": gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Beta"),
60
+ "alpha": gr.Slider(minimum=1, maximum=30, value=3, step=1, label="Alpha"),
61
+ "max_lambda": gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Iteration"),
62
+ "min_length": gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Min Length"),
63
+ "min_prob": gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Min Probability"),
64
+ "max_retry": gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Retry In total"),
65
+ "alpha_step_increase": gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Alpha Step Increase"),
66
+ "beta_step_increase": gr.Slider(minimum=1, maximum=15, value=5, step=1, label="Beta Step Increase")
67
+ }
68
+
69
+ output = gr.JSON(label="Output")
70
+
71
+ gr.Interface(
72
+ fn=predict_codeswitch,
73
+ inputs=[inputs, *parameters.values()],
74
+ outputs=output,
75
+ title="MaskLID (Code-Switch Language Identification)",
76
+ description = render_metadata(),
77
+ cache_examples=False
78
+ ).launch()
masklid.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fasttext
2
+ import numpy as np
3
+ import re
4
+ import string
5
+ from copy import deepcopy
6
+
7
+ class MaskLID:
8
+ """A class for code-switching language identification using iterative masking."""
9
+
10
+ def __init__(self, model_path, languages=-1):
11
+ """Initialize the MaskLID class.
12
+
13
+ Args:
14
+ model_path (str): The path to the fastText model.
15
+ languages (int or list, optional): The indices or list of language labels to consider. Defaults to -1.
16
+ """
17
+ self.model = fasttext.load_model(model_path)
18
+ self.output_matrix = self.model.get_output_matrix()
19
+ self.labels = self.model.get_labels()
20
+ self.language_indices = self._compute_language_indices(languages)
21
+ self.labels = [self.labels[i] for i in self.language_indices]
22
+
23
+ def _compute_language_indices(self, languages):
24
+ """Compute indices of selected languages.
25
+
26
+ Args:
27
+ languages (int or list): The indices or list of language labels.
28
+
29
+ Returns:
30
+ list: Indices of selected languages.
31
+ """
32
+ if languages != -1 and isinstance(languages, list):
33
+ return [self.labels.index(l) for l in set(languages) if l in self.labels]
34
+ return list(range(len(self.labels)))
35
+
36
+ def _softmax(self, x):
37
+ """Compute softmax values for each score in array x.
38
+
39
+ Args:
40
+ x (numpy.ndarray): Input array.
41
+
42
+ Returns:
43
+ numpy.ndarray: Softmax output.
44
+ """
45
+ exp_x = np.exp(x - np.max(x))
46
+ return exp_x / np.sum(exp_x)
47
+
48
+ def _normalize_text(self, text):
49
+ """Normalize input text.
50
+
51
+ Args:
52
+ text (str): Input text.
53
+
54
+ Returns:
55
+ str: Normalized text.
56
+ """
57
+ replace_by = " "
58
+ replacement_map = {ord(c): replace_by for c in '_:' + '•#{|}' + string.digits}
59
+ text = text.replace('\n', replace_by)
60
+ text = text.translate(replacement_map)
61
+ return re.sub(r'\s+', replace_by, text).strip()
62
+
63
+ def predict(self, text, k=1):
64
+ """Predict the language of the input text.
65
+
66
+ Args:
67
+ text (str): Input text.
68
+ k (int, optional): Number of top predictions to retrieve. Defaults to 1.
69
+
70
+ Returns:
71
+ tuple: Top predicted labels and their probabilities.
72
+ """
73
+ sentence_vector = self.model.get_sentence_vector(text)
74
+ result_vector = np.dot(self.output_matrix, sentence_vector)
75
+ softmax_result = self._softmax(result_vector)[self.language_indices]
76
+ top_k_indices = np.argsort(softmax_result)[-k:][::-1]
77
+ top_k_labels = [self.labels[i] for i in top_k_indices]
78
+ top_k_probs = softmax_result[top_k_indices]
79
+ return tuple(top_k_labels), top_k_probs
80
+
81
+ def compute_v(self, sentence_vector):
82
+ """Compute the language vectors for a given sentence vector.
83
+
84
+ Args:
85
+ sentence_vector (numpy.ndarray): Sentence vector.
86
+
87
+ Returns:
88
+ list: Sorted list of labels and their associated vectors.
89
+ """
90
+ result_vector = np.dot(self.output_matrix[self.language_indices, :], sentence_vector)
91
+ return sorted(zip(self.labels, result_vector), key=lambda x: x[1], reverse=True)
92
+
93
+ def compute_v_per_word(self, text):
94
+ """Compute language vectors for each word in the input text.
95
+
96
+ Args:
97
+ text (str): Input text.
98
+
99
+ Returns:
100
+ dict: Dictionary containing language vectors for each word.
101
+ """
102
+ text = self._normalize_text(text)
103
+ words = self.model.get_line(text)[0]
104
+ words = [w for w in words if w not in ['</s>', '</s>']]
105
+ subword_ids = [self.model.get_subwords(sw)[1] for sw in words]
106
+ sentence_vector = [np.sum([self.model.get_input_vector(id) for id in sid], axis=0) for sid in subword_ids]
107
+
108
+ dict_text = {}
109
+ for i, word in enumerate(words):
110
+ key = f"{i}_{word}"
111
+ dict_text[key] = {'logits': self.compute_v(sentence_vector[i])}
112
+
113
+ return dict_text
114
+
115
+ def mask_label_top_k(self, dict_text, label, top_keep, top_remove):
116
+ """Mask top predictions for a given label.
117
+
118
+ Args:
119
+ dict_text (dict): Dictionary containing language vectors for each word.
120
+ label (str): Label to mask.
121
+ top_keep (int): Number of top predictions to keep.
122
+ top_remove (int): Number of top predictions to remove.
123
+
124
+ Returns:
125
+ tuple: Dictionaries of remaining and deleted words after masking.
126
+ """
127
+ dict_remained = deepcopy(dict_text)
128
+ dict_deleted = {}
129
+
130
+ for key, value in dict_text.items():
131
+ logits = value['logits']
132
+ labels = [t[0] for t in logits]
133
+
134
+ if label in labels[:top_keep]:
135
+ dict_deleted[key] = dict_remained[key]
136
+
137
+ if label in labels[:top_remove]:
138
+ dict_remained.pop(key, None)
139
+
140
+ return dict_remained, dict_deleted
141
+
142
+ @staticmethod
143
+ def get_sizeof(text):
144
+ """Compute the size of text in bytes.
145
+
146
+ Args:
147
+ text (str): Input text.
148
+
149
+ Returns:
150
+ int: Size of text in bytes.
151
+ """
152
+ return len(text.encode('utf-8'))
153
+
154
+ @staticmethod
155
+ def custom_sort(word):
156
+ """Custom sorting function for words.
157
+
158
+ Args:
159
+ word (str): Input word.
160
+
161
+ Returns:
162
+ int or float: Sorted value.
163
+ """
164
+ match = re.match(r'^(\d+)_', word)
165
+ if match:
166
+ return int(match.group(1))
167
+ else:
168
+ return float('inf') # Return infinity for words without numbers at the beginning
169
+
170
+ def sum_logits(self, dict_data, label):
171
+ """Compute the sum of logits for a specific label across all words.
172
+
173
+ Args:
174
+ dict_data (dict): Dictionary containing language vectors for each word.
175
+ label (str): Label to sum logits for.
176
+
177
+ Returns:
178
+ float: Total sum of logits for the given label.
179
+ """
180
+ total = 0
181
+ for value in dict_data.values():
182
+ logits = value['logits']
183
+ labels = [t[0] for t in logits]
184
+ if label in labels:
185
+ total += logits[labels.index(label)][1]
186
+ return total
187
+
188
+ def predict_codeswitch(self, text, beta, alpha, min_prob, min_length, max_lambda=1, max_retry=3, alpha_step_increase=5, beta_step_increase=5):
189
+ """Predict language switching points in the input text.
190
+
191
+ Args:
192
+ text (str): Input text.
193
+ beta (int): Number of top predictions to keep.
194
+ alpha (int): Number of top predictions to remove.
195
+ min_prob (float): Minimum probability threshold for language prediction.
196
+ min_length (int): Minimum length of text after masking.
197
+ max_lambda (int, optional): Maximum number of iterations. Defaults to 1.
198
+ max_retry (int, optional): Maximum number of retries. Defaults to 3.
199
+ alpha_step_increase (int, optional): Step increase for alpha. Defaults to 5.
200
+ beta_step_increase (int, optional): Step increase for beta. Defaults to 5.
201
+ Returns:
202
+ dict: Predicted language switching points and associated information.
203
+ """
204
+ info = {}
205
+ index = 0
206
+ retry = 0
207
+
208
+ # compute v
209
+ dict_data = self.compute_v_per_word(text)
210
+
211
+ while index < max_lambda and retry < max_retry:
212
+
213
+ # predict the text
214
+ pred = self.predict(text, k=1)
215
+ label = pred[0][0]
216
+
217
+ # save the current text in case of step back
218
+ prev_text = text
219
+ # mask
220
+ dict_data, dict_masked = self.mask_label_top_k(dict_data, label, beta, alpha)
221
+
222
+ # get the text from the masked text and remained text
223
+ masked_text = ' '.join(x.split('_', 1)[1] for x in dict_masked.keys())
224
+ text = ' '.join(x.split('_', 1)[1] for x in dict_data.keys())
225
+
226
+ # save info
227
+ if self.get_sizeof(masked_text) > min_length or index == 0:
228
+ temp_pred = self.predict(masked_text)
229
+
230
+ if (temp_pred[1][0] > min_prob and temp_pred[0][0] == label) or index == 0:
231
+ info[index] = {
232
+ 'label': label,
233
+ 'text': masked_text,
234
+ 'text_keys': dict_masked.keys(),
235
+ 'size': self.get_sizeof(masked_text),
236
+ 'sum_logit': self.sum_logits(dict_masked, label)
237
+ }
238
+ index += 1
239
+ else:
240
+ text = prev_text
241
+ beta += beta_step_increase
242
+ alpha += alpha_step_increase
243
+ retry += 1
244
+ else:
245
+ text = prev_text
246
+ beta += beta_step_increase
247
+ alpha += alpha_step_increase
248
+ retry += 1
249
+
250
+ if self.get_sizeof(text) < min_length:
251
+ break
252
+
253
+
254
+ # post-process
255
+ post_info = {}
256
+ for value in info.values():
257
+ key = value['label']
258
+ if key in post_info:
259
+ post_info[key].extend(value['text_keys'])
260
+ else:
261
+ post_info[key] = list(value['text_keys'])
262
+
263
+ # join sorted the text from list of keys
264
+ for key in post_info:
265
+ post_info[key] = ' '.join([x.split('_', 1)[1] for x in sorted(set(post_info[key]), key=self.custom_sort)])
266
+
267
+
268
+ return post_info
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fasttext>=0.9.2
2
+ huggingface-hub>=0.14.1
3
+ numpy>=1.24.3,<2