File size: 9,772 Bytes
2c9efe4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
import fasttext
import numpy as np
import re
import string
from copy import deepcopy
class MaskLID:
    """Code-switching language identification using iterative masking.

    The sentence is classified with a fastText model, the words that support
    the predicted label are masked out, and the remainder is classified
    again, so that several languages can be recovered from one sentence.
    """

    # Characters replaced by a single space during normalisation. Built once
    # here instead of on every `_normalize_text` call.
    _REPLACEMENT_MAP = {ord(c): " " for c in '_:' + '•#{|}' + string.digits}

    def __init__(self, model_path, languages=-1):
        """Initialize the MaskLID class.

        Args:
            model_path (str): The path to the fastText model.
            languages (int or list, optional): List of language labels to
                restrict predictions to. Any non-list value (the default -1)
                selects every label of the model.
        """
        self.model = fasttext.load_model(model_path)
        self.output_matrix = self.model.get_output_matrix()
        self.labels = self.model.get_labels()
        self.language_indices = self._compute_language_indices(languages)
        # From here on `self.labels` is aligned with `self.language_indices`.
        self.labels = [self.labels[i] for i in self.language_indices]

    def _compute_language_indices(self, languages):
        """Compute indices of selected languages.

        Args:
            languages (int or list): List of language labels, or any
                non-list value to select all labels.

        Returns:
            list: Indices into ``self.labels`` of the selected languages, in
            the order the caller listed them (duplicates and unknown labels
            are dropped).
        """
        if isinstance(languages, list):
            # First-occurrence position of every label (what list.index
            # would return), computed in one pass.
            position = {}
            for i, known in enumerate(self.labels):
                position.setdefault(known, i)
            # dict.fromkeys de-duplicates while keeping the caller's order;
            # iterating a set() here would make the label order depend on
            # string hash randomisation and vary between processes.
            return [position[lang] for lang in dict.fromkeys(languages) if lang in position]
        return list(range(len(self.labels)))

    def _softmax(self, x):
        """Compute softmax values for each score in array x.

        Args:
            x (numpy.ndarray): Input array.

        Returns:
            numpy.ndarray: Softmax output.
        """
        # Subtracting the maximum keeps the exponentials from overflowing.
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x)

    def _normalize_text(self, text):
        """Normalize input text.

        Newlines, digits and the characters ``_ : • # { | }`` are replaced by
        spaces, then whitespace runs are collapsed and the ends stripped.

        Args:
            text (str): Input text.

        Returns:
            str: Normalized text.
        """
        replace_by = " "
        text = text.replace('\n', replace_by)
        text = text.translate(self._REPLACEMENT_MAP)
        return re.sub(r'\s+', replace_by, text).strip()

    def predict(self, text, k=1):
        """Predict the language of the input text.

        The softmax is taken over all labels of the model and only then
        restricted to the selected languages, so the returned probabilities
        are not renormalised over the selection.

        Args:
            text (str): Input text.
            k (int, optional): Number of top predictions to retrieve. Defaults to 1.

        Returns:
            tuple: ``(labels, probabilities)`` for the top ``k`` predictions,
            best first; ``labels`` is a tuple, ``probabilities`` an array.
        """
        sentence_vector = self.model.get_sentence_vector(text)
        result_vector = np.dot(self.output_matrix, sentence_vector)
        softmax_result = self._softmax(result_vector)[self.language_indices]
        top_k_indices = np.argsort(softmax_result)[-k:][::-1]
        top_k_labels = [self.labels[i] for i in top_k_indices]
        top_k_probs = softmax_result[top_k_indices]
        return tuple(top_k_labels), top_k_probs

    def compute_v(self, sentence_vector):
        """Compute the per-language logits for a given input vector.

        Args:
            sentence_vector (numpy.ndarray): Sentence (or word) vector.

        Returns:
            list: ``(label, logit)`` pairs for the selected languages, sorted
            by logit in descending order.
        """
        result_vector = np.dot(self.output_matrix[self.language_indices, :], sentence_vector)
        return sorted(zip(self.labels, result_vector), key=lambda x: x[1], reverse=True)

    def compute_v_per_word(self, text):
        """Compute language logits for each word in the input text.

        Args:
            text (str): Input text (normalized internally).

        Returns:
            dict: Maps ``"<position>_<word>"`` to ``{'logits': [...]}`` where
            the list is the output of :meth:`compute_v` for that word.
        """
        text = self._normalize_text(text)
        words = self.model.get_line(text)[0]
        words = [w for w in words if w != '</s>']
        subword_ids = [self.model.get_subwords(word)[1] for word in words]
        # A word vector is the sum of the input vectors of its subword ids.
        word_vectors = [
            np.sum([self.model.get_input_vector(sub_id) for sub_id in ids], axis=0)
            for ids in subword_ids
        ]
        dict_text = {}
        for i, word in enumerate(words):
            # The position prefix keeps repeated words distinct and lets the
            # original word order be restored later.
            key = f"{i}_{word}"
            dict_text[key] = {'logits': self.compute_v(word_vectors[i])}
        return dict_text

    def mask_label_top_k(self, dict_text, label, top_keep, top_remove):
        """Split words according to how highly they rank a given label.

        Args:
            dict_text (dict): Per-word logits as built by
                :meth:`compute_v_per_word`. Not modified.
            label (str): Label to mask.
            top_keep (int): A word is collected into the masked dictionary
                when ``label`` is among its ``top_keep`` best labels.
            top_remove (int): A word is dropped from the remaining dictionary
                when ``label`` is among its ``top_remove`` best labels.

        Returns:
            tuple: ``(dict_remained, dict_deleted)`` - the words left after
            masking and the words attributed to ``label``.
        """
        # Deep copy so the returned dictionaries never alias the input.
        dict_remained = deepcopy(dict_text)
        dict_deleted = {}
        for key, value in dict_text.items():
            logits = value['logits']
            # Only the leading slice is inspected; no need to materialise the
            # full label list for every word.
            if any(entry[0] == label for entry in logits[:top_keep]):
                dict_deleted[key] = dict_remained[key]
            if any(entry[0] == label for entry in logits[:top_remove]):
                dict_remained.pop(key, None)
        return dict_remained, dict_deleted

    @staticmethod
    def get_sizeof(text):
        """Compute the size of text in bytes.

        Args:
            text (str): Input text.

        Returns:
            int: Size of the UTF-8 encoding of the text in bytes.
        """
        return len(text.encode('utf-8'))

    @staticmethod
    def custom_sort(word):
        """Sort key that orders ``"<position>_<word>"`` keys by position.

        Args:
            word (str): Input key.

        Returns:
            int or float: The leading integer, or infinity when the key has
            no ``<digits>_`` prefix (such keys sort last).
        """
        match = re.match(r'^(\d+)_', word)
        if match:
            return int(match.group(1))
        return float('inf')

    def sum_logits(self, dict_data, label):
        """Compute the sum of logits for a specific label across all words.

        Args:
            dict_data (dict): Dictionary containing per-word logits.
            label (str): Label to sum logits for.

        Returns:
            float: Total sum of logits for the given label (0 when no word
            carries the label).
        """
        total = 0
        for value in dict_data.values():
            for entry in value['logits']:
                if entry[0] == label:
                    total += entry[1]
                    break  # only the first entry for the label counts
        return total

    def predict_codeswitch(self, text, beta, alpha, min_prob, min_length, max_lambda=1, max_retry=3, alpha_step_increase=5, beta_step_increase=5):
        """Predict the languages present in a (possibly code-switched) text.

        Each round predicts the dominant label of the current text, masks the
        words supporting it and continues with the remaining words. A round
        whose masked part is too short or is not confidently classified as
        the same label is rolled back and retried with larger ``alpha`` and
        ``beta``. The first round is always accepted.

        Args:
            text (str): Input text.
            beta (int): A word is attributed to the predicted label when the
                label is among its top ``beta`` labels.
            alpha (int): A word is removed from the remaining text when the
                label is among its top ``alpha`` labels.
            min_prob (float): Minimum probability the masked text must reach
                for the predicted label.
            min_length (int): Minimum size in bytes; the masked text must be
                larger than this, and the loop stops once the remaining text
                is smaller.
            max_lambda (int, optional): Maximum number of accepted rounds. Defaults to 1.
            max_retry (int, optional): Maximum number of retries. Defaults to 3.
            alpha_step_increase (int, optional): Step increase for alpha on retry. Defaults to 5.
            beta_step_increase (int, optional): Step increase for beta on retry. Defaults to 5.

        Returns:
            dict: Maps each detected label to the space-joined words
            attributed to it, in their original order.
        """
        info = {}
        index = 0
        retry = 0

        # Per-word logits of the whole (normalized) text.
        dict_data = self.compute_v_per_word(text)

        while index < max_lambda and retry < max_retry:
            label = self.predict(text, k=1)[0][0]

            # Snapshot BOTH the text and the per-word data: a rejected round
            # must be undone completely, otherwise the words removed by the
            # failed attempt are missing when retrying with larger alpha/beta.
            prev_text = text
            prev_dict_data = dict_data

            dict_data, dict_masked = self.mask_label_top_k(dict_data, label, beta, alpha)
            masked_text = ' '.join(key.split('_', 1)[1] for key in dict_masked)
            text = ' '.join(key.split('_', 1)[1] for key in dict_data)
            masked_size = self.get_sizeof(masked_text)

            # The first round is accepted unconditionally; later rounds need
            # a long enough masked text that is itself classified as `label`.
            accepted = index == 0
            if not accepted and masked_size > min_length:
                masked_labels, masked_probs = self.predict(masked_text)
                accepted = masked_probs[0] > min_prob and masked_labels[0] == label

            if accepted:
                info[index] = {
                    'label': label,
                    'text': masked_text,
                    'text_keys': list(dict_masked),
                    'size': masked_size,
                    'sum_logit': self.sum_logits(dict_masked, label),
                }
                index += 1
            else:
                # Step back and widen the masking window.
                text = prev_text
                dict_data = prev_dict_data
                beta += beta_step_increase
                alpha += alpha_step_increase
                retry += 1

            if self.get_sizeof(text) < min_length:
                break

        # Merge rounds that produced the same label.
        post_info = {}
        for value in info.values():
            post_info.setdefault(value['label'], []).extend(value['text_keys'])

        # Restore the original word order and strip the position prefix.
        for key in post_info:
            ordered_keys = sorted(set(post_info[key]), key=self.custom_sort)
            post_info[key] = ' '.join(k.split('_', 1)[1] for k in ordered_keys)
        return post_info