# Spaces: Runtime error  (hosting-page status banner captured with this file; not part of the module)
"""
Originally by https://worksheets.codalab.org/worksheets/0x8fc01c7fc2b742fdb29c05669f0ad7d2
"""
import json | |
import os, sys | |
import re | |
import random | |
import numpy as np | |
from random import sample | |
from tqdm import tqdm | |
from collections import Counter | |
from critic.edit_dist_utils import get_all_edit_dist_one, sample_random_internal_permutations | |
# Resolve the directory containing this module so the bundled typo table is
# found regardless of the current working directory.
try:
    dir_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    # __file__ is not defined in interactive sessions; fall back to the CWD.
    dir_path = '.'

# Lookup table indexed by word; each value is turned into a set of typos by
# RandomPerturbationAttack (presumably a list of misspellings -- confirm
# against common_typo.json).
with open(f"{dir_path}/common_typo.json", encoding="utf-8") as _typo_file:
    common_typo = json.load(_typo_file)

# Fixed seeds so that sampled perturbations are repeatable between runs.
random.seed(1234)
np.random.seed(1234)
class RandomPerturbationAttack(object):
    """Sample character-level perturbations (typos) of individual words.

    Perturbation types are identified by small integers:
    0 = swap, 1 = substitute, 2 = delete, 3 = insert (all obtained from
    ``get_all_edit_dist_one``) and 4 = entries of the ``common_typo`` table.
    Candidates are computed once per word and memoised in ``self.cache``.
    """

    def __init__(self, attack_type='ed1'):
        # {word: {type_id: collection_of_perturbed_words, ...}, ...}
        self.cache = {}
        self.n_types = 5
        self.attack_type = attack_type

    def sample_perturbations(self, word, n_samples, types):
        """Draw up to ``n_samples`` cached perturbations of ``word``.

        Args:
            word: a word that is already a key of ``self.cache``.
            n_samples: number of sampling slots to distribute over the types.
            types: iterable of type ids to draw from, or ``None`` to spread
                the slots over all ``self.n_types`` types.

        Returns:
            A set of perturbed strings. It can hold fewer than ``n_samples``
            items when a chosen type has no (or too few) cached candidates.

        Raises:
            KeyError: if ``word`` has not been cached yet.
        """
        if types is None:
            # Every type receives an equal share of the slots; the remainder
            # is handed to distinct, randomly chosen types. Both parts use
            # self.n_types so that exactly n_samples slots are produced.
            n_types = self.n_types
            type_list = list(range(n_types)) * (n_samples // n_types)
            type_list += list(
                np.random.choice(n_types, n_samples % n_types, replace=False)
            )
        else:
            # random.sample() needs a sequence; accept any iterable of ids.
            type_choices = list(types)
            type_list = [sample(type_choices, 1)[0] for _ in range(n_samples)]

        type_count = Counter(type_list)
        candidates_by_type = self.cache[word]
        perturbations = set()
        for ptype, quota in type_count.items():
            candidates = candidates_by_type.get(ptype)
            if candidates is None:
                continue
            if len(candidates) >= quota:
                # random.sample() rejects sets on Python >= 3.11, and set
                # order varies with hash randomisation; sorting first keeps
                # the draw valid and reproducible under a fixed seed.
                perturbations.update(sample(sorted(candidates), quota))
            else:
                perturbations.update(candidates)
        return perturbations

    def _collect_candidates(self, word):
        """Build the ``{type_id: candidates}`` cache entry for ``word``."""
        candidates = {}
        if not word:
            # Nothing to perturb; also avoids indexing into an empty string.
            return candidates
        first = word[0]
        if first.islower():
            # Edit-distance-one candidates are only generated for words that
            # start with a lowercase letter. The second argument (1, 10, 100,
            # 1000) selects the edit kind -- see get_all_edit_dist_one.
            for ptype in range(4):
                candidates[ptype] = get_all_edit_dist_one(word, 10 ** ptype)
        if (first.islower() or first.isupper()) and word in common_typo:
            candidates[4] = set(common_typo[word])
        return candidates

    def get_perturbations(self, word, n_samples, types=None):
        """Return a set of sampled perturbations of ``word``.

        The candidates for ``word`` are computed and cached on first use.

        Args:
            word: the token to perturb.
            n_samples: number of sampling slots, see ``sample_perturbations``.
            types: optional iterable restricting the perturbation type ids.

        Raises:
            NotImplementedError: if ``self.attack_type`` is not ``'ed1'``.
        """
        if word not in self.cache:
            self.cache[word] = self._collect_candidates(word)
        if self.attack_type != 'ed1':
            raise NotImplementedError("Attack type: {} not implemented yet".format(self.attack_type))
        return self.sample_perturbations(word, n_samples, types)

    def name(self):
        """Return the identifier of this attack."""
        return 'RandomPerturbationAttack'


# Module-level attack instance; its cache is shared by every call that uses it.
word_attack = RandomPerturbationAttack()
def _tokenize(sent):
    """Split ``sent`` into tokens and report which tokens may be perturbed.

    A token is a maximal run of ASCII letters, a maximal run of digits, or
    any other single character (whitespace and line breaks included), so
    ``''.join(toks)`` reproduces ``sent`` exactly.

    Returns:
        ``(toks, word_idxs)`` where ``toks`` is the list of token strings and
        ``word_idxs`` lists the positions in ``toks`` of alphabetic tokens
        longer than two characters that start with a lowercase letter.
    """
    toks = []
    word_idxs = []
    # re.DOTALL lets the catch-all "." branch match "\n" as well; without it
    # line breaks are silently dropped and the tokens no longer join back
    # into the original sentence.
    pattern = r'([a-zA-Z]+)|([0-9]+)|.'
    for idx, match in enumerate(re.finditer(pattern, sent, flags=re.DOTALL)):
        tok = match.group(0)
        toks.append(tok)
        if len(tok) > 2 and tok.isalpha() and tok[0].islower():
            word_idxs.append(idx)
    return toks, word_idxs
def _detokenize(toks):
    """Concatenate the tokens into one string; no separator is inserted."""
    sentence = ''.join(toks)
    return sentence
def get_local_neighbors_char_level(sent, max_n_samples=500, max_attempts=500):
    """Generate character-level perturbed variants of ``sent``.

    Args:
        sent: the sentence (a string) to perturb.
        max_n_samples: upper bound on the number of returned variants.
        max_attempts: number of random single-word perturbation attempts
            (defaults to 500).

    Returns:
        A set of perturbed sentences, none of them equal to ``sent``. The set
        is empty when ``sent`` contains no perturbable word.
    """
    words, word_idxs = _tokenize(sent)
    sent_perturbations = set()
    if not word_idxs:
        return sent_perturbations
    # Aim for 20 variants per perturbable word, capped by max_n_samples.
    n_samples = min(len(word_idxs) * 20, max_n_samples)

    # Stage 1: repeatedly perturb one randomly chosen word.
    for _ in range(max_attempts):
        word_idx = sample(word_idxs, 1)[0]
        words_cp = words[:]
        word_perturbations = list(
            word_attack.get_perturbations(words_cp[word_idx], n_samples=1)
        )
        if word_perturbations:
            words_cp[word_idx] = word_perturbations[0]
            sent_perturbed = _detokenize(words_cp)
            if sent_perturbed != sent:
                sent_perturbations.add(sent_perturbed)
        # ">=" rather than "==" so the loop also stops when the target is
        # zero or negative.
        if len(sent_perturbations) >= n_samples:
            break

    # Stage 2: common typos -- append an 's' or drop the last character.
    # Every index in word_idxs already refers to a lowercase-initial word of
    # more than two characters, so no further filtering is required.
    for word_idx in word_idxs:
        word = words[word_idx]
        for variant in (word + 's', word[:-1]):
            words_cp = words[:]
            words_cp[word_idx] = variant
            sent_perturbed = _detokenize(words_cp)
            if sent_perturbed != sent:
                sent_perturbations.add(sent_perturbed)

    # Stage 3: enforce the cap. Sorting before shuffling makes the selection
    # depend only on the NumPy seed, not on set iteration order.
    if len(sent_perturbations) > max_n_samples:
        ordered = sorted(sent_perturbations)
        np.random.shuffle(ordered)
        sent_perturbations = set(ordered[:max(max_n_samples, 0)])
    return sent_perturbations
from critic.PIE.word_level_perturb import WordLevelPerturber_all, WordLevelPerturber_refine | |
from utils.text_utils import detokenize_sent | |
def get_local_neighbors_word_level(sent_toked, max_n_samples=500, mode='refine'):
    """Generate word-level perturbed variants of a tokenized sentence.

    Args:
        sent_toked: list of tokens (tokenized by spacy).
        max_n_samples: cap used when computing the target number of variants.
        mode: ``'refine'`` selects ``WordLevelPerturber_refine``; any other
            value selects ``WordLevelPerturber_all``.

    Returns:
        ``(sent_perturbations, orig_sent_detok)``: the set of detokenized
        perturbed sentences and the detokenized original sentence.
    """
    # Target: 20 variants per token, capped by max_n_samples.
    n_samples = min(len(sent_toked) * 20, max_n_samples)
    orig_sent = ' '.join(sent_toked)
    orig_sent_detok = detokenize_sent(orig_sent)

    perturber_cls = WordLevelPerturber_refine if mode == 'refine' else WordLevelPerturber_all
    ptb = perturber_cls(orig_sent)

    sent_perturbations = set()
    for _attempt in range(500):
        candidate = ptb.perturb()
        if candidate != orig_sent:
            # Only genuinely changed sentences are kept, in detokenized form.
            sent_perturbations.add(detokenize_sent(candidate))
        if len(sent_perturbations) == n_samples:
            break

    assert len(sent_perturbations) <= max_n_samples
    return sent_perturbations, orig_sent_detok