import re
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from fuzzysearch import find_near_matches
from pyarabic import araby
from torch import nn
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, pipeline
from transformers.modeling_outputs import SequenceClassifierOutput

from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex

multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)
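# Illustration of the pattern above (the input word is just an example):
# any character repeated three or more times is collapsed to exactly two, e.g.
#   multiple_char_pattern.sub(r"\1\1", "جمييييييل")  # -> "جمييل"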

# ASAD-NEW_AraBERT_PREP-Balanced
class NewArabicPreprocessorBalanced(ArabertPreprocessor):
    def __init__(
        self,
        model_name: str,
        keep_emojis: bool = False,
        remove_html_markup: bool = True,
        replace_urls_emails_mentions: bool = True,
        strip_tashkeel: bool = True,
        strip_tatweel: bool = True,
        insert_white_spaces: bool = True,
        remove_non_digit_repetition: bool = True,
        replace_slash_with_dash: bool = None,
        map_hindi_numbers_to_arabic: bool = None,
        apply_farasa_segmentation: bool = None,
    ):
        if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
            keep_emojis = True
            remove_non_digit_repetition = True
        super().__init__(
            model_name=model_name,
            keep_emojis=keep_emojis,
            remove_html_markup=remove_html_markup,
            replace_urls_emails_mentions=replace_urls_emails_mentions,
            strip_tashkeel=strip_tashkeel,
            strip_tatweel=strip_tatweel,
            insert_white_spaces=insert_white_spaces,
            remove_non_digit_repetition=remove_non_digit_repetition,
            replace_slash_with_dash=replace_slash_with_dash,
            map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
            apply_farasa_segmentation=apply_farasa_segmentation,
        )
        self.true_model_name = model_name

    def preprocess(self, text):
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        # For other checkpoints, fall back to the stock AraBERT preprocessing.
        return super().preprocess(text)

    def ubc_prep(self, text):
        text = re.sub(r"\s", " ", text)
        text = text.replace("\\n", " ")
        text = text.replace("\\r", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # replace all possible URLs
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub(r"(URL\s*)+", " URL ", text)
        # replace mentions with USER
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub(r"(USER\s*)+", " USER ", text)
        # replace hashtags with HASHTAG
        # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        # text = re.sub("\B\\[Uu]\w+", "", text)
text = text.replace("\\U0001f97a", "🥺") | |
text = text.replace("\\U0001f928", "🤨") | |
text = text.replace("\\U0001f9d8", "😀") | |
text = text.replace("\\U0001f975", "😥") | |
text = text.replace("\\U0001f92f", "😲") | |
text = text.replace("\\U0001f92d", "🤭") | |
text = text.replace("\\U0001f9d1", "😐") | |
text = text.replace("\\U000e0067", "") | |
text = text.replace("\\U000e006e", "") | |
text = text.replace("\\U0001f90d", "♥") | |
text = text.replace("\\U0001f973", "🎉") | |
text = text.replace("\\U0001fa79", "") | |
text = text.replace("\\U0001f92b", "🤐") | |
text = text.replace("\\U0001f9da", "🦋") | |
text = text.replace("\\U0001f90e", "♥") | |
text = text.replace("\\U0001f9d0", "🧐") | |
text = text.replace("\\U0001f9cf", "") | |
text = text.replace("\\U0001f92c", "😠") | |
text = text.replace("\\U0001f9f8", "😸") | |
text = text.replace("\\U0001f9b6", "💩") | |
text = text.replace("\\U0001f932", "🤲") | |
text = text.replace("\\U0001f9e1", "🧡") | |
text = text.replace("\\U0001f974", "☹") | |
text = text.replace("\\U0001f91f", "") | |
text = text.replace("\\U0001f9fb", "💩") | |
text = text.replace("\\U0001f92a", "🤪") | |
text = text.replace("\\U0001f9fc", "") | |
text = text.replace("\\U000e0065", "") | |
text = text.replace("\\U0001f92e", "💩") | |
text = text.replace("\\U000e007f", "") | |
text = text.replace("\\U0001f970", "🥰") | |
text = text.replace("\\U0001f929", "🤩") | |
text = text.replace("\\U0001f6f9", "") | |
text = text.replace("🤍", "♥") | |
text = text.replace("🦠", "😷") | |
text = text.replace("🤢", "مقرف") | |
text = text.replace("🤮", "مقرف") | |
text = text.replace("🕠", "⌚") | |
text = text.replace("🤬", "😠") | |
text = text.replace("🤧", "😷") | |
text = text.replace("🥳", "🎉") | |
text = text.replace("🥵", "🔥") | |
text = text.replace("🥴", "☹") | |
text = text.replace("🤫", "🤐") | |
text = text.replace("🤥", "كذاب") | |
text = text.replace("\\u200d", " ") | |
text = text.replace("u200d", " ") | |
text = text.replace("\\u200c", " ") | |
text = text.replace("u200c", " ") | |
text = text.replace('"', "'") | |
text = text.replace("\\xa0", "") | |
text = text.replace("\\u2066", " ") | |
text = re.sub("\B\\\[Uu]\w+", "", text) | |
text = super(NewArabicPreprocessorBalanced, self).preprocess(text) | |
text = " ".join(text.split()) | |
return text | |
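
# Minimal usage sketch for the preprocessor above ("UBC-NLP/MARBERT" is one
# checkpoint name that routes through ubc_prep(); the tweet is illustrative):
#
#   prep = NewArabicPreprocessorBalanced(model_name="UBC-NLP/MARBERT")
#   prep.preprocess("@user شوف  https://t.co/xyz #وسم_جديد")
#   # -> roughly "USER شوف URL HASH وسم جديد" after mention/URL replacement,
#   #    hashtag splitting, and the AraBERT base preprocessing.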
"""CNNMarbertArabicPreprocessor""" | |
# ASAD-CNN_MARBERT | |
class CNNMarbertArabicPreprocessor(ArabertPreprocessor): | |
def __init__( | |
self, | |
model_name, | |
keep_emojis=False, | |
remove_html_markup=True, | |
replace_urls_emails_mentions=True, | |
remove_elongations=True, | |
): | |
if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name: | |
keep_emojis = True | |
remove_elongations = False | |
super().__init__( | |
model_name, | |
keep_emojis, | |
remove_html_markup, | |
replace_urls_emails_mentions, | |
remove_elongations, | |
) | |
self.true_model_name = model_name | |

    def preprocess(self, text):
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        # For other checkpoints, fall back to the stock AraBERT preprocessing.
        return super().preprocess(text)

    def ubc_prep(self, text):
        text = re.sub(r"\s", " ", text)
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # replace all possible URLs
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub(r"(URL\s*)+", " URL ", text)
        # replace mentions with USER
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub(r"(USER\s*)+", " USER ", text)
        # replace hashtags with HASHTAG
        # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        text = super().preprocess(text)
        # Strip zero-width joiners/non-joiners, real or mangled.
        text = text.replace("\u200d", " ")
        text = text.replace("u200d", " ")
        text = text.replace("\u200c", " ")
        text = text.replace("u200c", " ")
        text = text.replace('"', "'")
        # text = re.sub('[\d\.]+', ' NUM ', text)
        # text = re.sub('(NUM\s*)+', ' NUM ', text)
        text = multiple_char_pattern.sub(r"\1\1", text)
        text = " ".join(text.split())
        return text
"""Trial5ArabicPreprocessor""" | |
class Trial5ArabicPreprocessor(ArabertPreprocessor): | |
def __init__( | |
self, | |
model_name, | |
keep_emojis=False, | |
remove_html_markup=True, | |
replace_urls_emails_mentions=True, | |
): | |
if "UBC-NLP" in model_name: | |
keep_emojis = True | |
super().__init__( | |
model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions | |
) | |
self.true_model_name = model_name | |

    def preprocess(self, text):
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        # For other checkpoints, fall back to the stock AraBERT preprocessing.
        return super().preprocess(text)

    def ubc_prep(self, text):
        text = re.sub(r"\s", " ", text)
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # replace all possible URLs
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        # replace mentions with USER
        text = re.sub(user_mention_regex, " USER ", text)
        # replace hashtags with HASHTAG
        # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
        text = text.replace("#", " HASH TAG ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        text = super().preprocess(text)
        # text = text.replace("السلام عليكم"," ")
        # text = text.replace(find_near_matches("السلام عليكم",text,max_deletions=3,max_l_dist=3)[0].matched," ")
        return text
"""SarcasmArabicPreprocessor""" | |
class SarcasmArabicPreprocessor(ArabertPreprocessor): | |
def __init__( | |
self, | |
model_name, | |
keep_emojis=False, | |
remove_html_markup=True, | |
replace_urls_emails_mentions=True, | |
): | |
if "UBC-NLP" in model_name: | |
keep_emojis = True | |
super().__init__( | |
model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions | |
) | |
self.true_model_name = model_name | |
def preprocess(self, text): | |
if "UBC-NLP" in self.true_model_name: | |
return self.ubc_prep(text) | |
else: | |
return super(SarcasmArabicPreprocessor, self).preprocess(text) | |
def ubc_prep(self, text): | |
text = re.sub("\s", " ", text) | |
text = text.replace("\\n", " ") | |
text = araby.strip_tashkeel(text) | |
text = araby.strip_tatweel(text) | |
# replace all possible URLs | |
for reg in url_regexes: | |
text = re.sub(reg, " URL ", text) | |
# replace mentions with USER | |
text = re.sub(user_mention_regex, " USER ", text) | |
# replace hashtags with HASHTAG | |
# text = re.sub(r"#[\w\d]+", " HASH TAG ", text) | |
text = text.replace("#", " HASH TAG ") | |
text = text.replace("_", " ") | |
text = text.replace('"', " ") | |
text = " ".join(text.split()) | |
text = super(SarcasmArabicPreprocessor, self).preprocess(text) | |
return text | |
"""NoAOAArabicPreprocessor""" | |
class NoAOAArabicPreprocessor(ArabertPreprocessor): | |
def __init__( | |
self, | |
model_name, | |
keep_emojis=False, | |
remove_html_markup=True, | |
replace_urls_emails_mentions=True, | |
): | |
if "UBC-NLP" in model_name: | |
keep_emojis = True | |
super().__init__( | |
model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions | |
) | |
self.true_model_name = model_name | |
def preprocess(self, text): | |
if "UBC-NLP" in self.true_model_name: | |
return self.ubc_prep(text) | |
else: | |
return super(NoAOAArabicPreprocessor, self).preprocess(text) | |
def ubc_prep(self, text): | |
text = re.sub("\s", " ", text) | |
text = text.replace("\\n", " ") | |
text = araby.strip_tashkeel(text) | |
text = araby.strip_tatweel(text) | |
# replace all possible URLs | |
for reg in url_regexes: | |
text = re.sub(reg, " URL ", text) | |
# replace mentions with USER | |
text = re.sub(user_mention_regex, " USER ", text) | |
# replace hashtags with HASHTAG | |
# text = re.sub(r"#[\w\d]+", " HASH TAG ", text) | |
text = text.replace("#", " HASH TAG ") | |
text = text.replace("_", " ") | |
text = " ".join(text.split()) | |
text = super(NoAOAArabicPreprocessor, self).preprocess(text) | |
text = text.replace("السلام عليكم", " ") | |
text = text.replace("ورحمة الله وبركاته", " ") | |
matched = find_near_matches("السلام عليكم", text, max_deletions=3, max_l_dist=3) | |
if len(matched) > 0: | |
text = text.replace(matched[0].matched, " ") | |
matched = find_near_matches( | |
"ورحمة الله وبركاته", text, max_deletions=3, max_l_dist=3 | |
) | |
if len(matched) > 0: | |
text = text.replace(matched[0].matched, " ") | |
return text | |
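
# Sketch of how the fuzzy matching above behaves (the strings are
# illustrative, and the exact Match repr depends on the fuzzysearch version):
#
#   find_near_matches("السلام عليكم", "السلا عليكم يا شباب",
#                     max_deletions=3, max_l_dist=3)
#   # -> something like [Match(start=0, end=11, dist=1, matched='السلا عليكم')]
#
# so greetings survive minor misspellings and still get blanked out.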

class CnnBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        # One Conv2d per n-gram width; each convolves over the token axis of
        # the last four hidden layers (treated as 4 input channels), with the
        # kernel spanning the full hidden dimension.
        filter_sizes = [1, 2, 3, 4, 5]
        num_filters = 32
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(4, num_filters, (K, config.hidden_size)) for K in filter_sizes]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            # The CNN head always needs the per-layer hidden states.
            output_hidden_states=True,
            return_dict=return_dict,
        )
        # Stack the last four encoder layers: (batch, 4, seq_len, hidden).
        x = outputs[2][-4:]
        x = torch.stack(x, dim=1)
        # Convolve, squeeze the hidden axis, then max-pool over time; each
        # entry ends up as (batch, num_filters).
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=outputs.attentions,
        )
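
# Loading sketch for the classifier above ("path/to/checkpoint" is a
# placeholder, not a published model):
#
#   model = CnnBertForSequenceClassification.from_pretrained(
#       "path/to/checkpoint", num_labels=3
#   )
#
# forward() always requests the encoder's hidden states, since the CNN head
# consumes the last four layers rather than the pooled output.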

class CNNTextClassificationPipeline:
    def __init__(self, model_path, device, return_all_scores=False):
        self.model_path = model_path
        self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
        # Special handling: device < 0 means CPU, otherwise the CUDA index.
        self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
        if self.device.type == "cuda":
            self.model = self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.return_all_scores = return_all_scores

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in a framework-agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework-specific tensor allocation will be done on the requested device
                output = pipe(...)
        """
        if self.device.type == "cuda":
            torch.cuda.set_device(self.device)
        yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {
            name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
            for name, tensor in inputs.items()
        }

    def __call__(self, text):
        """
        Classify the text(s) given as inputs.

        Args:
            text (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) to classify.

        Return:
            A list (or a list of lists) of :obj:`dict`: Each result comes as a list of dictionaries with the following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.

            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        # outputs = super().__call__(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(
            text,
            add_special_tokens=True,
            max_length=64,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
        )
        with torch.no_grad():
            inputs = self.ensure_tensor_on_device(**inputs)
            predictions = self.model(**inputs)[0].cpu()
        predictions = predictions.numpy()
        # Sigmoid for single-logit models, softmax otherwise.
        if self.model.config.num_labels == 1:
            scores = 1.0 / (1.0 + np.exp(-predictions))
        else:
            scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [
                    {"label": self.model.config.id2label[i], "score": score.item()}
                    for i, score in enumerate(item)
                ]
                for item in scores
            ]
        else:
            return [
                {
                    "label": self.model.config.id2label[int(item.argmax())],
                    "score": item.max().item(),
                }
                for item in scores
            ]
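
# End-to-end usage sketch ("path/to/checkpoint" is a placeholder for a locally
# trained checkpoint, and the input strings are illustrative):
#
#   pipe = CNNTextClassificationPipeline("path/to/checkpoint", device=-1)
#   pipe(["النص الأول", "النص الثاني"])
#   # -> [{"label": "...", "score": 0.97}, {"label": "...", "score": 0.88}]
#
# device=-1 runs on CPU; device=0 targets cuda:0, mirroring the convention
# used by transformers.pipeline.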