# Author: any0019 — last commit "Update app.py" (0f2b728)
import streamlit as st
from termcolor import colored
import torch
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# st.cache is deprecated; st.cache_resource is Streamlit's cache for shared,
# unhashable objects such as models. Fall back to the legacy decorator (with
# mutation allowed, as before) on Streamlit versions that predate cache_resource.
_cache_models = (
    st.cache_resource
    if hasattr(st, 'cache_resource')
    else st.cache(allow_output_mutation=True)
)


@_cache_models
def load_models():
    """
    Load the tokenizer and the three fine-tuned BERT models (cached by Streamlit).

    :return: (tokenizer, positive-style MLM, negative-style MLM, two-label classifier);
        every model is moved to :device: and put in eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Models are only used for inference here, so load them in eval mode
    # (dropout off) instead of train mode.
    bert_mlm_positive = BertForMaskedLM.from_pretrained(
        'any0019/text_style_mlm_positive', return_dict=True
    ).to(device).eval()
    bert_mlm_negative = BertForMaskedLM.from_pretrained(
        'any0019/text_style_mlm_negative', return_dict=True
    ).to(device).eval()
    bert_classifier = BertForSequenceClassification.from_pretrained(
        'any0019/text_style_classifier', num_labels=2
    ).to(device).eval()
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def highlight_diff(sent, sent_main):
    """
    Mark the tokens of `sent` that differ from `sent_main`.

    Both sentences are split with the BERT tokenizer and compared position by
    position. A token of `sent` is wrapped in '***' (markdown bold-italic) when
    it differs from the token at the same position in `sent_main`, or when
    `sent_main` has no token at that position.

    :return: all tokens of `sent` joined with single spaces
        (wordpieces keep their '##' prefix)
    """
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)
    # Iterate over every token of `sent` (not zip()), so trailing tokens are not
    # dropped when a replacement makes the retokenized hypothesis longer.
    new_toks = [
        tok if i < len(tokens_main) and tok == tokens_main[i] else '***' + tok + '***'
        for i, tok in enumerate(tokens)
    ]
    return ' '.join(new_toks)
def get_classifier_prob(sent):
    """
    Return the classifier's class probabilities for a single sentence.

    The classifier is switched to eval mode and run without gradient tracking.
    The softmax over its logits for the only batch element is returned as a
    NumPy array on the CPU; callers in this file read index 1 as the
    positive-class probability.
    """
    encoded = tokenizer(sent, return_tensors='pt')
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    bert_classifier.eval()
    with torch.no_grad():
        logits = bert_classifier(**batch).logits
        probs = logits.softmax(dim=-1)
    return probs[0].cpu().numpy()
def _replace_one_token(sentence, beam_size, epsilon, used):
    """Return (position, token-id lists with that position replaced) for one sentence, or (None, [])."""
    encoded = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
    probs_positive = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
    probs_negative = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]
    input_ids = encoded['input_ids'][0]
    seq_positions = torch.arange(input_ids.shape[0], device=input_ids.device)
    # Probability each MLM assigns to the token actually present at every position.
    p_pos = probs_positive[seq_positions, input_ids]
    p_neg = probs_negative[seq_positions, input_ids]
    # Ascending ratio: tokens the positive model finds least likely relative to
    # the negative model come first.
    order = ((p_pos + epsilon) / (p_neg + epsilon)).argsort().tolist()
    # Plain Python ints avoid writing (possibly CUDA) tensors into NumPy arrays.
    ids = input_ids.tolist()
    last = len(ids) - 1
    for pos in order:
        # Skip the first/last special tokens and positions replaced in earlier steps.
        if pos in used or pos == 0 or pos == last:
            continue
        # Ask for one extra candidate so that skipping the original token
        # still leaves beam_size replacements.
        top = probs_positive[pos].topk(beam_size + 1).indices.tolist()
        replacements = [tok for tok in top if tok != ids[pos]][:beam_size]
        if replacements:
            return pos, [ids[:pos] + [tok] + ids[pos + 1:] for tok in replacements]
    return None, []


def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """
    Expand every hypothesis in the beam by one token substitution.

    For each sentence in :current_beam: the sentence is tokenized with the BERT
    tokenizer and scored by the positive and negative MLMs. The first position
    (excluding the first/last special tokens and :used_positions:) with the
    lowest (p_positive + epsilon) / (p_negative + epsilon) ratio is replaced by
    up to :beam_size: of the positive MLM's most likely tokens, skipping the
    original token. The new sentences are scored by the classifier and merged
    with :current_beam:, and the :beam_size: best are kept.

    :param current_beam: dict mapping sentence -> positive-class probability
    :param beam_size: replacements tried per sentence and size of the returned beam
    :param epsilon: smoothing term added to both probabilities in the ratio
    :param used_positions: token positions that must not be replaced (not mutated)
    :return: (new_beam, used_position), where used_position is the position chosen
        for the last sentence that produced a replacement; or
        (current_beam, None) when no new hypothesis could be produced
    """
    used = {int(p) for p in (used_positions or ())}
    bert_mlm_positive.eval()
    bert_mlm_negative.eval()
    candidates = []
    used_position = None
    with torch.no_grad():
        for sentence in current_beam:
            pos, new_ids = _replace_one_token(sentence, beam_size, epsilon, used)
            if new_ids:
                used_position = pos
                candidates.extend(new_ids)

    if not candidates:
        st.write("No more new hypotheses")
        return current_beam, None

    # Strip [CLS]/[SEP] before decoding; dict.fromkeys de-duplicates so each
    # distinct sentence is classified once.
    decoded = dict.fromkeys(tokenizer.decode(ids[1:-1]) for ids in candidates)
    new_beam = {sent: get_classifier_prob(sent)[1] for sent in decoded}
    # Hypotheses already in the beam keep their scores and compete with the new ones.
    new_beam.update(current_beam)
    if len(new_beam) > beam_size:
        new_beam = dict(sorted(new_beam.items(), key=lambda el: el[1], reverse=True)[:beam_size])
    return new_beam, used_position
def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
    """
    Run the beam search for up to :max_steps: steps, printing each beam with Streamlit.

    :param sentence: the text to rewrite
    :param beam_size: beam width passed to beam_get_replacements
    :param max_steps: maximum number of substitution steps
    :param epsilon: smoothing term forwarded to beam_get_replacements
    :param pretty_output: if True, each hypothesis is shown with its differences
        from :sentence: highlighted instead of as plain text
    :return: (final beam as {sentence: positive probability}, list of replaced positions)
    """
    current_beam = {sentence: get_classifier_prob(sentence)[1]}
    used_poss = []
    st.write("step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})")
    # Raw f-strings: "\q" is not a valid escape sequence (SyntaxWarning on
    # Python 3.12+); the emitted text "$\qquad$" is unchanged.
    st.write(rf"$\qquad${sentence}")
    for step in range(max_steps):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
        st.write(f"\nstep #{step+1}:")
        for i, (sent, prob) in enumerate(current_beam.items()):
            st.write(f"-- {i+1}: (positive probability ~ {round(prob, 5)})")
            st.write(rf"$\qquad${highlight_diff(sent, sentence) if pretty_output else sent}")
        # None means no new hypothesis could be produced: stop searching.
        if used_pos is None:
            break
        used_poss.append(used_pos)
    return current_beam, used_poss
# Streamlit page: read the review text and search settings, then run the beam search.
st.title("Correcting opinions of fellow comrades")
default_value = "write your review here (in lower case - vocab reasons)"
sentence = st.text_area("Text", default_value, height=275)
beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
# An on/off option is a checkbox, not a 0/1 slider.
prettyfy = st.sidebar.checkbox("Highlight changes", value=False)
beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=prettyfy)