import re
import string

import flair
import jieba
import pycld2 as cld2

from .importing import LazyLoader


def has_letter(word):
    """Returns true if `word` contains at least one character in [A-Za-z]."""
    return re.search("[A-Za-z]+", word) is not None


def is_one_word(word):
    """Returns true if `word` tokenizes to exactly one word via `words_from_text`."""
    return len(words_from_text(word)) == 1
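# Illustrative behavior of the two predicates above (hypothetical inputs,
# shown as doctest-style comments):
#     >>> has_letter("ab12")
#     True
#     >>> has_letter("1234")
#     False
#     >>> is_one_word("hello world")
#     False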


def add_indent(s_, numSpaces):
    """Indents every line of `s_` after the first by `numSpaces` spaces."""
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s
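# A quick sketch of `add_indent` (hypothetical input): only lines after the
# first are indented, which suits nested multi-line reprs.
#     >>> add_indent("a\nb\nc", 2)
#     'a\n  b\n  c'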


def words_from_text(s, words_to_ignore=None):
    """Splits a string into words.

    Chinese text is segmented with `jieba`, since it has no whitespace word
    boundaries; all other text is split on whitespace. Apostrophes, hyphens,
    underscores, asterisks, at signs, and a set of homoglyph characters are
    kept inside words as long as they don't begin the word.
    """
    if words_to_ignore is None:
        words_to_ignore = []
    try:
        # `pycld2` guesses the language of `s`.
        isReliable, textBytesFound, details = cld2.detect(s)
        if details[0][0] in ("Chinese", "ChineseT"):
            seg_list = jieba.cut(s, cut_all=False)
            s = " ".join(seg_list)
        else:
            s = " ".join(s.split())
    except Exception:
        s = " ".join(s.split())
    homos = """˗৭Ȣ𝟕бƼᏎƷᒿlO`ɑЬϲԁе𝚏ɡհіϳ𝒌ⅼmոорԛⲅѕ𝚝սѵԝ×уᴢ"""
    exceptions = """'-_*@"""
    filter_pattern = homos + """'\\-_\\*@"""
    # TODO: consider whether one should add "." to `exceptions` (and "\." to
    # `filter_pattern`), e.g. for "My email address is xxx@yyy.com"
    filter_pattern = f"[\\w{filter_pattern}]+"
    words = []
    for word in s.split():
        # Allow apostrophes, hyphens, underscores, asterisks, and at signs as
        # long as they don't begin the word.
        word = word.lstrip(exceptions)
        filt = [w.lstrip(exceptions) for w in re.findall(filter_pattern, word)]
        words.extend(filt)
    words = list(filter(lambda w: w not in words_to_ignore + [""], words))
    return words
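# An illustrative run (hypothetical input), matching the TODO above: "." is a
# separator, so the email address is split in two.
#     >>> words_from_text("My email address is xxx@yyy.com")
#     ['My', 'email', 'address', 'is', 'xxx@yyy', 'com']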


class TextAttackFlairTokenizer(flair.data.Tokenizer):
    """`flair` tokenizer that splits text with `words_from_text`."""

    def tokenize(self, text: str):
        return words_from_text(text)


def default_class_repr(self):
    if hasattr(self, "extra_repr_keys"):
        extra_params = []
        for key in self.extra_repr_keys():
            extra_params.append(" (" + key + ")" + ": {" + key + "}")
        if len(extra_params):
            extra_str = "\n" + "\n".join(extra_params) + "\n"
            extra_str = f"({extra_str})"
        else:
            extra_str = ""
        extra_str = extra_str.format(**self.__dict__)
    else:
        extra_str = ""
    return f"{self.__class__.__name__}{extra_str}"


class ReprMixin(object):
    """Mixin for enhanced __repr__ and __str__."""

    def __repr__(self):
        return default_class_repr(self)

    __str__ = __repr__

    def extra_repr_keys(self):
        """Extra fields to be included in the representation of a class."""
        return []
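# A minimal sketch of how `ReprMixin` and `default_class_repr` interact
# (hypothetical class, shown as comments only):
#     class MyTransform(ReprMixin):
#         def __init__(self):
#             self.max_len = 10
#         def extra_repr_keys(self):
#             return ["max_len"]
#
#     repr(MyTransform())  # -> 'MyTransform(\n (max_len): 10\n)'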


LABEL_COLORS = [
    "red",
    "green",
    "blue",
    "purple",
    "yellow",
    "orange",
    "pink",
    "cyan",
    "gray",
    "brown",
]


def process_label_name(label_name):
    """Normalizes a label name from a dataset for display: expands common
    abbreviations and capitalizes."""
    label_name = label_name.lower()
    if label_name == "neg":
        label_name = "negative"
    elif label_name == "pos":
        label_name = "positive"
    return label_name.capitalize()
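# Illustrative behavior (hypothetical inputs):
#     >>> process_label_name("neg")
#     'Negative'
#     >>> process_label_name("WORLD")
#     'World'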


def color_from_label(label_num):
    """Arbitrary colors for different labels."""
    try:
        label_num %= len(LABEL_COLORS)
        return LABEL_COLORS[label_num]
    except TypeError:
        # Non-integer labels fall back to a default color.
        return "blue"


def color_from_output(label_name, label):
    """Returns the correct color for a label name, like 'positive',
    'medicine', or 'entailment'."""
    label_name = label_name.lower()
    if label_name in {"entailment", "positive"}:
        return "green"
    elif label_name in {"contradiction", "negative"}:
        return "red"
    elif label_name in {"neutral"}:
        return "gray"
    else:
        # if no color pre-stored for label name, return color corresponding to
        # the label number (so, even for unknown datasets, we can give each
        # class a distinct color)
        return color_from_label(label)
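# Illustrative behavior (hypothetical inputs): known label names get fixed
# colors; anything else falls through to `color_from_label`.
#     >>> color_from_output("entailment", 0)
#     'green'
#     >>> color_from_output("sports", 3)
#     'purple'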


class ANSI_ESCAPE_CODES:
    """Escape codes for printing color to the terminal."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    PURPLE = "\033[35m"
    YELLOW = "\033[93m"
    ORANGE = "\033[38:5:208m"
    PINK = "\033[95m"
    CYAN = "\033[96m"
    GRAY = "\033[38:5:240m"
    BROWN = "\033[38:5:52m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    # Resets the current color/style sequence.
    STOP = "\033[0m"


def color_text(text, color=None, method=None):
    """Colors `text` for display. `color` may be a single color name or a
    tuple of names (applied nested, first name outermost); `method` is
    "html", "ansi", or "file"."""
    if not (isinstance(color, str) or isinstance(color, tuple)):
        raise TypeError(f"Cannot color text with provided color of type {type(color)}")
    if isinstance(color, tuple):
        if len(color) > 1:
            text = color_text(text, color[1:], method)
        color = color[0]
    if method is None:
        return text
    if method == "html":
        return f"<font color = {color}>{text}</font>"
    elif method == "ansi":
        if color == "green":
            color = ANSI_ESCAPE_CODES.OKGREEN
        elif color == "red":
            color = ANSI_ESCAPE_CODES.FAIL
        elif color == "blue":
            color = ANSI_ESCAPE_CODES.OKBLUE
        elif color == "purple":
            color = ANSI_ESCAPE_CODES.PURPLE
        elif color == "yellow":
            color = ANSI_ESCAPE_CODES.YELLOW
        elif color == "orange":
            color = ANSI_ESCAPE_CODES.ORANGE
        elif color == "pink":
            color = ANSI_ESCAPE_CODES.PINK
        elif color == "cyan":
            color = ANSI_ESCAPE_CODES.CYAN
        elif color == "gray":
            color = ANSI_ESCAPE_CODES.GRAY
        elif color == "brown":
            color = ANSI_ESCAPE_CODES.BROWN
        elif color == "bold":
            color = ANSI_ESCAPE_CODES.BOLD
        elif color == "underline":
            color = ANSI_ESCAPE_CODES.UNDERLINE
        elif color == "warning":
            color = ANSI_ESCAPE_CODES.WARNING
        else:
            raise ValueError(f"unknown text color {color}")
        return color + text + ANSI_ESCAPE_CODES.STOP
    elif method == "file":
        return "[[" + text + "]]"


# Lazily-loaded, module-level cache for the flair POS tagger.
_flair_pos_tagger = None


def flair_tag(sentence, tag_type="upos-fast"):
    """Tags a `Sentence` object using the `flair` part-of-speech tagger."""
    global _flair_pos_tagger
    if not _flair_pos_tagger:
        from flair.models import SequenceTagger

        _flair_pos_tagger = SequenceTagger.load(tag_type)
    _flair_pos_tagger.predict(sentence, force_token_predictions=True)


def zip_flair_result(pred, tag_type="upos-fast"):
    """Takes a sentence tagging from `flair` and returns two lists, of words
    and their corresponding parts-of-speech."""
    from flair.data import Sentence

    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")
    tokens = pred.tokens
    word_list = []
    pos_list = []
    for token in tokens:
        word_list.append(token.text)
        if "pos" in tag_type:
            pos_list.append(token.annotation_layers["pos"][0]._value)
        elif tag_type == "ner":
            pos_list.append(token.get_label("ner"))
    return word_list, pos_list
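# Usage sketch (hedged; downloads a flair model on first call, so shown only
# as comments):
#     from flair.data import Sentence
#     s = Sentence("A quick test", use_tokenizer=TextAttackFlairTokenizer())
#     flair_tag(s)                       # tags `s` in place with "upos-fast"
#     words, pos = zip_flair_result(s)   # parallel lists of tokens and POS tags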


stanza = LazyLoader("stanza", globals(), "stanza")


def zip_stanza_result(pred, tagset="universal"):
    """Takes a `stanza` document and returns two lists, one of words and the
    other of their corresponding parts-of-speech."""
    if not isinstance(pred, stanza.models.common.doc.Document):
        raise TypeError("Result from Stanza POS tagger must be a `Document` object.")
    word_list = []
    pos_list = []
    for sentence in pred.sentences:
        for word in sentence.words:
            word_list.append(word.text)
            if tagset == "universal":
                pos_list.append(word.upos)
            else:
                pos_list.append(word.xpos)
    return word_list, pos_list
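# Usage sketch (hedged; requires the stanza English models, so shown only as
# comments):
#     nlp = stanza.Pipeline("en", processors="tokenize,pos")
#     words, pos = zip_stanza_result(nlp("A quick test"), tagset="universal")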


def check_if_subword(token, model_type, starting=False):
    """Check if ``token`` is a subword token that is not a standalone word.

    Args:
        token (str): token to check.
        model_type (str): type of model (options: "bert", "roberta", "xlnet").
        starting (bool): Should be set ``True`` if this token is the starting
            token of the overall text. This matters because models like
            RoBERTa do not add "Ġ" to the beginning token.
    Returns:
        (bool): ``True`` if ``token`` is a subword token.
    """
    avail_models = [
        "bert",
        "gpt",
        "gpt2",
        "roberta",
        "bart",
        "electra",
        "longformer",
        "xlnet",
    ]
    if model_type not in avail_models:
        raise ValueError(
            f"Model type {model_type} is not available. Options are {avail_models}."
        )
    if model_type in ["bert", "electra"]:
        # WordPiece marks word-internal subwords with a "##" prefix.
        return "##" in token
    elif model_type in ["gpt", "gpt2", "roberta", "bart", "longformer"]:
        if starting:
            return False
        # Byte-level BPE marks the start of a new word with "Ġ".
        return token[0] != "Ġ"
    elif model_type == "xlnet":
        # SentencePiece marks the start of a new word with "▁" (U+2581),
        # not an ASCII underscore.
        return token[0] != "▁"
    else:
        return False
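# Illustrative behavior (hypothetical tokens):
#     >>> check_if_subword("##ing", "bert")
#     True
#     >>> check_if_subword("Ġhello", "gpt2")
#     False
#     >>> check_if_subword("hello", "gpt2", starting=True)
#     False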


def strip_BPE_artifacts(token, model_type):
    """Strip characters such as "Ġ" that are left over from BPE tokenization.

    Args:
        token (str)
        model_type (str): type of model (options: "bert", "roberta", "xlnet")
    """
    avail_models = [
        "bert",
        "gpt",
        "gpt2",
        "roberta",
        "bart",
        "electra",
        "longformer",
        "xlnet",
    ]
    if model_type not in avail_models:
        raise ValueError(
            f"Model type {model_type} is not available. Options are {avail_models}."
        )
    if model_type in ["bert", "electra"]:
        return token.replace("##", "")
    elif model_type in ["gpt", "gpt2", "roberta", "bart", "longformer"]:
        return token.replace("Ġ", "")
    elif model_type == "xlnet":
        # SentencePiece's word-boundary marker is "▁" (U+2581).
        if len(token) > 1 and token[0] == "▁":
            return token[1:]
        else:
            return token
    else:
        return token
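# Illustrative behavior (hypothetical tokens):
#     >>> strip_BPE_artifacts("##ing", "bert")
#     'ing'
#     >>> strip_BPE_artifacts("Ġhello", "gpt2")
#     'hello'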


def check_if_punctuations(word):
    """Returns ``True`` if ``word`` consists only of punctuation characters."""
    return all(c in string.punctuation for c in word)
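# Illustrative behavior (hypothetical inputs):
#     >>> check_if_punctuations("?!...")
#     True
#     >>> check_if_punctuations("a?")
#     False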