live-lm-critic / critic /critic.py
Olivia Figueira
Upload code with streamlit addition
b6e5241
raw history blame
No virus
6.98 kB
import sys
import torch
import random
import hashlib
import numpy as np
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
import nltk
nltk.download('punkt')
sys.path.insert(0, '.')
from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
from utils.spacy_tokenizer import spacy_tokenize_gec
# Load the GPT-2 tokenizer and language model once at import time; every scoring
# call in this module reuses these globals.
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token so run_gpt2 can pad batches (padding=True);
# padded positions are masked out of the loss in get_gpt2_loss.
tokenizer.pad_token = tokenizer.eos_token
# Inference only, on CPU. NOTE(review): run_gpt2(cuda=True) moves the inputs to
# the GPU, so that path would also need the model moved with .cuda().
model = GPT2LMHeadModel.from_pretrained(model_name).eval().cpu()
print(f'Loaded {model_name}')
def get_gpt2_loss(input_ids, attention_mask, labels, lm=None):
    """Return the summed next-token cross-entropy of every sequence in a batch.

    Args:
        input_ids: token ids, shape ``[bsize, seqlen]``.
        attention_mask: 1 for real tokens and 0 for padding, same shape as
            ``input_ids``; ``None`` treats every position as real.
        labels: target ids aligned with ``input_ids`` (the logits at position
            ``t`` are scored against ``labels[:, t + 1]``); ``None`` uses
            ``input_ids`` itself, i.e. the usual causal-LM objective.
        lm: model to score with; defaults to the module-level GPT-2 ``model``.
            It is called as ``lm(input_ids=..., attention_mask=...)`` and its
            first output must be the logits ``[bsize, seqlen, vocab]``.

    Returns:
        Float tensor of shape ``[bsize]``: each sequence's negative
        log-likelihood summed over the non-padding target positions. The
        first token is never a target, so it contributes nothing.
    """
    if labels is None:
        labels = input_ids
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    lm = model if lm is None else lm
    with torch.no_grad():
        # Labels are deliberately not passed to the model: the per-sequence loss
        # is computed below, so any loss the model derived from them would be unused.
        lm_logits = lm(input_ids=input_ids, attention_mask=attention_mask)[0]  # [bsize, seqlen, vocab]
        # Position t predicts token t+1: drop the last logit and the first label/mask entry.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_mask = attention_mask[..., 1:].contiguous()
        token_loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='none',
        ).view(shift_labels.size())  # [bsize, seqlen - 1]
        # Zero out padded targets, then sum per sequence.
        return (token_loss * shift_mask).sum(dim=1)  # [bsize]
# Longest batch (in tokens, measured on the BOS-prefixed, padded batch) that
# run_gpt2 will score; longer batches are rejected.
MAX_LENGTH = 66


def run_gpt2(sents, cuda=False, model_name=None):
    """Score every sentence in ``sents`` with the module-level GPT-2 model.

    Each sentence is prefixed with the tokenizer's BOS token, and the batch is
    padded to a common length before scoring.

    Args:
        sents: list of sentence strings.
        cuda: move the input tensors to the GPU before scoring.
        model_name: unused.

    Returns:
        CPU tensor with one summed log-probability per sentence (the negated
        ``get_gpt2_loss``), or None when the padded batch is longer than
        ``MAX_LENGTH`` tokens. The check applies to the whole batch, so one
        long sentence rejects all of them.
    """
    assert isinstance(sents, list)
    batch = tokenizer([tokenizer.bos_token + text for text in sents], return_tensors="pt", padding=True)
    if batch['input_ids'].size(1) > MAX_LENGTH:
        return None
    if cuda:
        batch = {key: tensor.cuda() for key, tensor in batch.items()}
    token_ids = batch['input_ids']
    nll = get_gpt2_loss(input_ids=token_ids, attention_mask=batch['attention_mask'], labels=token_ids)
    return -nll.detach().cpu()
def gpt2_critic_char_level_only(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100):
    """Judge ``sent`` against its character-level neighbors only.

    Neighbors come from ``get_local_neighbors_char_level`` (called with
    ``max_n_samples=n_samples``). The sentence and its neighbors are scored in
    one GPT-2 batch, and the sentence is judged good when it is the
    highest-scoring entry (index 0) of that batch.

    Args:
        sent: sentence to judge.
        verbose: 0 = silent, 1 = print the verdict, >1 = also print the seed
            and the number of perturbations.
        cuda: forwarded to ``run_gpt2`` (moves the input tensors to the GPU).
        fp16: run the forward pass under CUDA autocast; only takes effect
            when ``cuda`` is true.
        seed: seed for ``random`` and ``numpy.random``; ``'auto'`` derives a
            reproducible seed from the MD5 of ``sent``.
        n_samples: passed as ``max_n_samples`` to the perturbation generator.

    Returns:
        ``(is_good, logp_of_sent, counter_example)``, where ``counter_example``
        is ``[best_neighbor, its_logp]`` when ``is_good`` is false and None
        otherwise. Returns None when the batch is too long for ``run_gpt2``.
    """
    if seed == 'auto':
        # Seed must be between 0 and 2**32 - 1; hashing the text makes runs reproducible per sentence.
        seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32)
    if verbose > 1:
        print('seed', seed)
    np.random.seed(seed)
    random.seed(seed)

    sent_perturbations = get_local_neighbors_char_level(sent, max_n_samples=n_samples)
    if verbose > 1:
        print("#sent_perturbations (char-level)", len(sent_perturbations))
    # Sort so the batch order (and therefore argmax tie-breaking) does not depend on
    # set iteration order, which for strings can change between interpreter runs.
    sents = [sent] + sorted(sent_perturbations)
    # CPU inputs never run CUDA kernels, so autocast is only enabled when it can
    # matter; this also avoids torch's "CUDA is not available" warning on CPU hosts.
    with torch.cuda.amp.autocast(enabled=bool(fp16 and cuda)):
        logps = run_gpt2(sents, cuda)
    if logps is None:
        if verbose:
            print('Invalid input. Maybe the sentence is too long.')
        return None

    best_idx = int(logps.argmax())
    is_good = best_idx == 0
    if verbose:
        if is_good:
            print('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
        else:
            print('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            print('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
    counter_example = None if is_good else [sents[best_idx], float(logps[best_idx])]
    return is_good, float(logps[0]), counter_example
def gpt2_critic(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100, word_level_mode='refine'):
    """Judge ``sent`` against its word- and character-level neighbors.

    ``sent`` is tokenized with ``spacy_tokenize_gec``. Word-level neighbors
    are generated with ``max_n_samples=n_samples // 2``. Character-level
    neighbors are then generated from ``orig_sent`` (the sentence returned by
    the word-level generator), also with ``max_n_samples=n_samples // 2``.
    ``orig_sent`` and all neighbors are scored in one GPT-2 batch. The
    sentence is judged good when ``orig_sent`` is the highest-scoring entry
    (index 0).

    Args:
        sent: sentence to judge.
        verbose: 0 = silent, 1 = report the verdict, >1 = also report the seed
            and the number of perturbations.
        cuda: forwarded to ``run_gpt2`` (moves the input tensors to the GPU).
        fp16: run the forward pass under CUDA autocast; only takes effect
            when ``cuda`` is true.
        seed: seed for ``random`` and ``numpy.random``; ``'auto'`` derives a
            reproducible seed from the MD5 of ``sent``.
        n_samples: perturbation budget, split evenly between the word- and
            character-level generators.
        word_level_mode: passed as ``mode`` to ``get_local_neighbors_word_level``.

    Returns:
        ``(is_good, logp, counter_example, return_string)``.
        - ``logp`` is the log-probability of ``orig_sent``.
        - ``counter_example`` is ``[best_neighbor, its_logp]`` when
          ``is_good`` is false, else None.
        - ``return_string`` is the list of newline-terminated report lines
          that were printed (empty when ``verbose`` is 0).

        Returns None when the batch is too long for ``run_gpt2``.
    """
    return_string = []

    def report(line):
        # Print the line and keep a newline-terminated copy for the caller.
        print(line)
        return_string.append(line + '\n')

    if seed == 'auto':
        # Seed must be between 0 and 2**32 - 1; hashing the text makes runs reproducible per sentence.
        seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32)
    if verbose > 1:
        report(f'seed {seed}')
    np.random.seed(seed)
    random.seed(seed)

    sent_toked = spacy_tokenize_gec(sent)
    sent_perturbations_w, orig_sent = get_local_neighbors_word_level(sent_toked, max_n_samples=n_samples//2, mode=word_level_mode)
    sent_perturbations_c = get_local_neighbors_char_level(orig_sent, max_n_samples=n_samples//2)
    if verbose > 1:
        # Previously list.append(text, count), which raised TypeError; format the count into one line.
        report(f"#sent_perturbations (char-level) {len(sent_perturbations_c)}")
        report(f"#sent_perturbations (word-level) {len(sent_perturbations_w)}")
    # Sort so the batch order (and therefore argmax tie-breaking) does not depend on
    # set iteration order, which for strings can change between interpreter runs.
    sents = [orig_sent] + sorted(sent_perturbations_c.union(sent_perturbations_w))
    # CPU inputs never run CUDA kernels, so autocast is only enabled when it can
    # matter; this also avoids torch's "CUDA is not available" warning on CPU hosts.
    with torch.cuda.amp.autocast(enabled=bool(fp16 and cuda)):
        logps = run_gpt2(sents, cuda)
    if logps is None:
        if verbose:
            print('Invalid input. Maybe the sentence is too long.')
        # Callers only receive None here, so no report lines are returned.
        return None

    best_idx = int(logps.argmax())
    is_good = best_idx == 0
    if verbose:
        if is_good:
            report('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
        else:
            report('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            report('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
    counter_example = None if is_good else [sents[best_idx], float(logps[best_idx])]
    return is_good, float(logps[0]), counter_example, return_string
def main():
    """Streamlit entry point: read a sentence from the page and show the critic's report."""
    import streamlit as st
    st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
    sent = st.text_input('Enter a sentence:', value="")
    # Ignore empty or whitespace-only input instead of sending it to the critic.
    if not sent.strip():
        return
    st.markdown(f"**Sentence**: {sent}")
    result = gpt2_critic(sent)
    st.markdown("**Results:**")
    # gpt2_critic returns None (not a 4-tuple) when run_gpt2 rejects the batch as
    # too long; unpacking it unconditionally raised TypeError and broke the page.
    if result is None:
        st.write('Invalid input. Maybe the sentence is too long.')
        return
    _, _, _, return_string = result
    st.write('\n'.join(return_string))
if __name__ == "__main__":
    # Launch the Streamlit UI only when this file is executed directly, not when imported.
    main()