"""
Originally by https://worksheets.codalab.org/worksheets/0x8fc01c7fc2b742fdb29c05669f0ad7d2
"""
import json
import os
import sys
import re
import random
import numpy as np
from random import sample
from tqdm import tqdm
from collections import Counter
from critic.edit_dist_utils import get_all_edit_dist_one, sample_random_internal_permutations
try:
    dir_path = os.path.dirname(os.path.realpath(__file__))
except NameError:  # __file__ is undefined in some interactive contexts
    dir_path = '.'
with open(f"{dir_path}/common_typo.json") as f:
    common_typo = json.load(f)
random.seed(1234)
np.random.seed(1234)
class RandomPerturbationAttack(object):
    def __init__(self, attack_type='ed1'):
        # cache: {word: {edit_type: set of perturbations, ...}, ...}
        # edit types: 0=swap, 1=substitute, 2=delete, 3=insert, 4=common typo
        self.cache = {}
        self.n_types = 5
        self.attack_type = attack_type
#
    def sample_perturbations(self, word, n_samples, types):
        if types is None:
            # Split n_samples evenly across the four edit-distance-one types, then draw
            # the remainder as distinct types from all n_types (4 = common typo).
            type_list = list(range(4)) * (n_samples // 4) + list(np.random.choice(self.n_types, n_samples % 4, replace=False))
        else:
            type_list = [sample(types, 1)[0] for _ in range(n_samples)]
type_count = Counter(type_list)
perturbations = set()
        for edit_type in type_count:
            if edit_type not in self.cache[word]:
                continue
            if len(self.cache[word][edit_type]) >= type_count[edit_type]:
                # random.sample rejects sets on Python 3.11+, so convert to a list first
                perturbations.update(sample(list(self.cache[word][edit_type]), type_count[edit_type]))
            else:
                perturbations.update(self.cache[word][edit_type])
return perturbations
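    # Illustrative (hypothetical values): with n_samples=3 and types=None, type_list
    # could come out as [2, 0, 4], giving type_count = Counter({2: 1, 0: 1, 4: 1});
    # any edit type missing from the word's cache is simply skipped above.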
#
def get_perturbations(self, word, n_samples, types=None):
if word not in self.cache:
self.cache[word] = {}
if word[0].islower():
                for edit_type in range(4):
                    # 10**edit_type presumably selects a single edit operation inside get_all_edit_dist_one
                    self.cache[word][edit_type] = get_all_edit_dist_one(word, 10**edit_type)
if word in common_typo:
self.cache[word][4] = set(common_typo[word])
elif word[0].isupper():
if word in common_typo:
self.cache[word][4] = set(common_typo[word])
if self.attack_type == 'ed1':
perturbations = self.sample_perturbations(word, n_samples, types)
else:
raise NotImplementedError("Attack type: {} not implemented yet".format(self.attack_type))
return perturbations
#
def name(self):
return 'RandomPerturbationAttack'
word_attack = RandomPerturbationAttack()
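# Usage sketch (illustrative only; the returned set below is hypothetical, showing one
# swap, one substitution, one deletion, and one insertion):
#   word_attack.get_perturbations('hello', n_samples=4)
#   -> e.g. {'helol', 'hallo', 'hllo', 'hhello'}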
def _tokenize(sent):
toks = []
word_idxs = []
for idx, match in enumerate(re.finditer(r'([a-zA-Z]+)|([0-9]+)|.', sent)):
tok = match.group(0)
toks.append(tok)
        if len(tok) > 2 and tok.isalpha() and tok[0].islower():
            word_idxs.append(idx)  # only lowercase alphabetic tokens longer than 2 chars get perturbed
return toks, word_idxs
def _detokenize(toks):
return ''.join(toks)
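# Worked example of the helpers above (input is hypothetical): _tokenize("It's 2 cats!")
# returns toks = ['It', "'", 's', ' ', '2', ' ', 'cats', '!'] and word_idxs = [6],
# since 'cats' is the only lowercase alphabetic token longer than 2 characters;
# _detokenize(toks) reassembles the original string exactly.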
def get_local_neighbors_char_level(sent, max_n_samples=500):
    """Return a set of character-level (typo-style) perturbations of `sent`."""
    words, word_idxs = _tokenize(sent)
    n_samples = min(len(word_idxs) * 20, max_n_samples)
sent_perturbations = set()
if len(word_idxs) == 0:
return sent_perturbations
for _ in range(500):
word_idx = sample(word_idxs, 1)[0]
words_cp = words[:]
word_perturbations = list(word_attack.get_perturbations(words_cp[word_idx], n_samples=1))
if len(word_perturbations) > 0:
words_cp[word_idx] = word_perturbations[0]
sent_perturbed = _detokenize(words_cp)
if sent_perturbed != sent:
sent_perturbations.add(sent_perturbed)
if len(sent_perturbations) == n_samples:
break
    # Also add simple suffix typos: append 's' to, or drop the last character of, each eligible word
for word_idx in word_idxs:
words_cp = words[:]
word = words_cp[word_idx]
if len(word) > 2 and word[0].islower():
words_cp[word_idx] = word +'s'
sent_perturbed = _detokenize(words_cp)
if sent_perturbed != sent:
sent_perturbations.add(sent_perturbed)
words_cp[word_idx] = word[:-1]
sent_perturbed = _detokenize(words_cp)
if sent_perturbed != sent:
sent_perturbations.add(sent_perturbed)
if len(sent_perturbations) > max_n_samples:
sent_perturbations = list(sent_perturbations)
np.random.shuffle(sent_perturbations)
sent_perturbations = set(sent_perturbations[:max_n_samples])
return sent_perturbations
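# Usage sketch (outputs are random, so the set below is only a hypothetical example):
#   get_local_neighbors_char_level("the cat sat", max_n_samples=3)
#   -> e.g. {'teh cat sat', 'the cats sat', 'the cat sa'}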
from critic.PIE.word_level_perturb import WordLevelPerturber_all, WordLevelPerturber_refine
from utils.text_utils import detokenize_sent
def get_local_neighbors_word_level(sent_toked, max_n_samples=500, mode='refine'):
    """Return word-level perturbations of a sentence. `sent_toked` is a list of tokens produced by spaCy."""
    n_samples = min(len(sent_toked) * 20, max_n_samples)
orig_sent = ' '.join(sent_toked)
orig_sent_detok = detokenize_sent(orig_sent)
if mode == 'refine':
ptb = WordLevelPerturber_refine(orig_sent)
else:
ptb = WordLevelPerturber_all(orig_sent)
sent_perturbations = set()
for _ in range(500):
sent_perturbed = ptb.perturb()
if sent_perturbed != orig_sent:
sent_perturbed_detok = detokenize_sent(sent_perturbed)
sent_perturbations.add(sent_perturbed_detok)
if len(sent_perturbations) == n_samples:
break
assert len(sent_perturbations) <= max_n_samples
return sent_perturbations, orig_sent_detok
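# Minimal smoke test (a sketch: assumes critic.PIE and utils.text_utils resolve on
# sys.path and that the input has already been tokenized spaCy-style):
if __name__ == '__main__':
    toks = ['She', 'looked', 'at', 'the', 'sky', '.']
    perturbed, orig_detok = get_local_neighbors_word_level(toks, max_n_samples=5)
    print(orig_detok)   # the detokenized original sentence
    print(perturbed)    # up to 5 word-level perturbations of it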