live-lm-critic / critic /edit_dist_utils.py
Olivia Figueira
Upload code with streamlit addition
b6e5241
"""
Edit distance utils...
Originally by https://worksheets.codalab.org/worksheets/0x8fc01c7fc2b742fdb29c05669f0ad7d2
"""
from collections import defaultdict
import numpy as np
import random
import string
from itertools import permutations
def process_filetype(filetype):
    """
    Decode a numeric ``filetype`` into four perturbation switches.

    The thousands, hundreds, tens and ones decimal digits of ``filetype``
    control insertion, deletion, substitution and swapping respectively;
    a switch is on when its digit is odd (so 1111 enables everything and
    1010 enables only insertion and substitution).

    Returns:
        The tuple ``(insert, delete, substitute, swap)``.
    """
    # The three higher digits are extracted by integer division.
    insert_on, delete_on, substitute_on = (
        (filetype // place_value) % 2 == 1 for place_value in (1000, 100, 10)
    )
    # The ones digit is read straight from the number itself.
    swap_on = filetype % 2 == 1
    return insert_on, delete_on, substitute_on, swap_on
def get_all_edit_dist_one(word, filetype = 1111, sub_restrict = None):
    """
    Return the set of strings obtainable from ``word`` with a single edit.

    Allowable edit_dist_one perturbations (the first character is never touched):
        1. Insert any lowercase letter at any position other than the start
           (including after the last character)
        2. Delete any character other than the first one
        3. Replace any character other than the first one with a lowercase letter
        4. Swap two adjacent characters, neither of which is the first one
    ``filetype`` determines which of these perturbations are used; it is
    decoded by ``process_filetype``.

    The original word is not added explicitly. It can still show up in the
    result, because replacing a letter with itself (or swapping two equal
    adjacent letters) reproduces the word.

    Args:
        word: the string to perturb.
        filetype: number whose decimal digits switch the perturbations on/off.
        sub_restrict: optional mapping from a character of ``word`` to the
            collection of lowercase letters allowed to replace it. ``None``
            allows every lowercase letter. The mapping is indexed with
            ``sub_restrict[char]``, so a missing character raises ``KeyError``
            unless the mapping supplies a default.

    Returns:
        A set of perturbed strings. The set is empty when ``word`` has at most
        two characters or starts with an uppercase or numeric character.
    """
    insert, delete, substitute, swap = process_filetype(filetype)
    ed1 = set()
    if len(word) <= 2 or word[:1].isupper() or word[:1].isnumeric():
        return ed1
    n_chars = len(word)
    # pos runs up to n_chars inclusive so letters can be appended at the end.
    for pos in range(1, n_chars + 1):
        # Only positions below n_chars name an existing character.
        inside_word = pos < n_chars
        if delete and inside_word:
            ed1.add(word[:pos] + word[pos + 1:])
        if swap and pos < n_chars - 1:
            # Swap the character at pos with the one at pos + 1.
            ed1.add(word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:])
        if insert:
            # Insert right after position pos - 1.
            for letter in string.ascii_lowercase:
                ed1.add(word[:pos] + letter + word[pos:])
        if substitute and inside_word:
            # The restriction is looked up only for in-word positions and only
            # when substituting: word[pos] does not exist at pos == n_chars, so
            # an unconditional lookup would raise IndexError.
            allowed = None if sub_restrict is None else sub_restrict[word[pos]]
            for letter in string.ascii_lowercase:
                if allowed is None or letter in allowed:
                    ed1.add(word[:pos] + letter + word[pos + 1:])
    return ed1
def get_all_internal_permutations(word):
    """
    Return every rearrangement of ``word`` that keeps its first and last
    characters in place and permutes only the characters between them.

    Words with fewer than three characters have no interior to permute and
    words longer than ten characters are not expanded; in both cases the
    result is the one-element set ``{word}``.

    Returns:
        A set of strings (duplicates caused by repeated letters collapse).
    """
    # Short words: the first and last character overlap (or do not exist), so
    # building first + interior + last would duplicate the character or fail.
    # Long words: the number of permutations grows factorially, so skip them.
    if len(word) < 3 or len(word) > 10:
        return {word}
    first_char, internal_chars, last_char = word[0], word[1:-1], word[-1]
    return {
        first_char + ''.join(int_perm) + last_char
        for int_perm in permutations(internal_chars)
    }
def sample_random_internal_permutations(word, n_perts = 5):
    """
    Draw ``n_perts`` random shuffles of the interior of ``word``.

    The first and last characters stay where they are; only the characters
    between them are shuffled (with ``random.shuffle``). Words shorter than
    four characters are returned unchanged as ``{word}``.

    Returns:
        A set of strings. It may hold fewer than ``n_perts`` entries because
        identical shuffles collapse, and it may contain ``word`` itself.
    """
    if len(word) < 4:
        return {word}
    head, tail = word[0], word[-1]
    samples = set()
    for _ in range(n_perts):
        # Start from the original interior each time, then shuffle in place.
        inner = list(word[1:-1])
        random.shuffle(inner)
        samples.add(head + ''.join(inner) + tail)
    return samples
def get_sorted_word(word):
    """
    Return ``word`` with its interior characters sorted.

    The first and last characters keep their positions; everything between
    them is put in sorted order. Words shorter than three characters are
    returned as they are.
    """
    if len(word) >= 3:
        interior = sorted(word[1:-1])
        return word[0] + ''.join(interior) + word[-1]
    return word
def get_sorted_word_set(word):
    """
    Return a one-element set holding ``word`` with its interior sorted.

    The first and last characters keep their positions and the characters
    between them are sorted. Words shorter than three characters are placed
    in the set unchanged.
    """
    if len(word) < 3:
        return {word}
    middle = list(word[1:-1])
    middle.sort()
    return {''.join([word[0], *middle, word[-1]])}
#Used to create agglomerative clusters.
def preprocess_ed1_neighbors(vocab, sub_restrict = None, filetype = 1111):
    """
    Map every vocabulary word to the words it shares a perturbation with.

    The vocabulary is lowercased and de-duplicated. Each word is expanded
    with ``get_all_edit_dist_one``, and two words count as neighbors when
    some perturbation is produced by both of them.

    Args:
        vocab: iterable of strings.
        sub_restrict: forwarded to ``get_all_edit_dist_one``.
        filetype: forwarded to ``get_all_edit_dist_one``.

    Returns:
        A ``defaultdict(set)`` from word to its set of neighbors. Only words
        that produced at least one perturbation appear as keys, and each such
        word is contained in its own neighbor set.
    """
    lowered_vocab = {entry.lower() for entry in vocab}
    # First pass: for every perturbation, collect the words that generate it.
    typo2words = defaultdict(set)
    for entry in lowered_vocab:
        for typo in get_all_edit_dist_one(entry, filetype = filetype, sub_restrict = sub_restrict):
            typo2words[typo].add(entry)
    # Second pass: all words behind one perturbation are mutual neighbors.
    word2neighbors = defaultdict(set)
    for sharers in typo2words.values():
        for entry in sharers:
            word2neighbors[entry] |= sharers
    return word2neighbors
#Used to create agglomerative clusters.
def ed1_neighbors_mat(vocab, sub_restrict = None, filetype = 1111):
    """
    Build a symmetric 0/1 matrix of the neighbor relation over ``vocab``.

    Row/column ``i`` corresponds to the ``i``-th entry of the lowercased
    vocabulary. A cell is 1 when ``preprocess_ed1_neighbors`` reports the two
    words as neighbors, which includes the diagonal cell of any word that has
    neighbors at all. If a lowercased word occurs more than once, only its
    last position is used as its index.

    Args:
        vocab: iterable of strings.
        sub_restrict: forwarded to ``preprocess_ed1_neighbors``.
        filetype: forwarded to ``preprocess_ed1_neighbors``.

    Returns:
        An integer numpy array of shape ``(len(vocab), len(vocab))``.
    """
    lowered = [entry.lower() for entry in vocab]
    index_of = {entry: position for position, entry in enumerate(lowered)}
    neighbor_map = preprocess_ed1_neighbors(lowered, sub_restrict = sub_restrict, filetype = filetype)
    # Sorting each pair makes (a, b) and (b, a) collapse into one edge.
    unique_pairs = {
        tuple(sorted((entry, other)))
        for entry, others in neighbor_map.items()
        for other in others
    }
    size = len(lowered)
    adjacency = np.zeros((size, size), dtype = int)
    for first, second in unique_pairs:
        row, col = index_of[first], index_of[second]
        adjacency[row, col] = 1
        adjacency[col, row] = 1
    return adjacency
if __name__ == '__main__':
    # Interactive check: keep asking for words and report how many
    # perturbations each one yields.
    while True:
        try:
            word = input("Enter a word: ")
        except (EOFError, KeyboardInterrupt):
            # stdin was closed or the user pressed Ctrl-C: finish the prompt
            # line and leave the loop instead of ending with a traceback.
            print()
            break
        print("Total number of possible perturbations: {}".format(len(get_all_edit_dist_one(word))))