import streamlit as st
from termcolor import colored
import torch
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification

device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache(allow_output_mutation=True)
def load_models():
    """Load the tokenizer, both style MLMs and the style classifier.

    Cached by Streamlit so the weights are loaded once rather than on every
    script rerun. All models are used for inference only, so they are put in
    eval mode (dropout disabled) to make scores deterministic.

    :return: (tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_mlm_positive = BertForMaskedLM.from_pretrained(
        'any0019/text_style_mlm_positive', return_dict=True).to(device).eval()
    bert_mlm_negative = BertForMaskedLM.from_pretrained(
        'any0019/text_style_mlm_negative', return_dict=True).to(device).eval()
    bert_classifier = BertForSequenceClassification.from_pretrained(
        'any0019/text_style_classifier', num_labels=2).to(device).eval()
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()


def highlight_diff(sent, sent_main):
    """Return ``sent`` as space-joined word-pieces, wrapping changed ones in ``***``.

    Word-pieces are compared position by position against ``sent_main``.
    Pieces of ``sent`` past the end of ``sent_main`` have no counterpart and
    are therefore marked as changed instead of being dropped.
    """
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)
    new_toks = []
    for i, tok in enumerate(tokens):
        if i < len(tokens_main) and tok == tokens_main[i]:
            new_toks.append(tok)
        else:
            new_toks.append('***' + tok + '***')
    return ' '.join(new_toks)


def get_classifier_prob(sent):
    """Return the classifier's softmax probabilities for ``sent`` as a NumPy array.

    Callers in this script read index 1 as the "positive" probability.
    """
    # Guard against the cached model having been switched to train mode elsewhere.
    bert_classifier.eval()
    inputs = {k: v.to(device) for k, v in tokenizer(sent, return_tensors='pt').items()}
    with torch.no_grad():
        logits = bert_classifier(**inputs).logits
    return logits.softmax(dim=-1)[0].cpu().numpy()


def _propose_replacements(sentence, beam_size, epsilon, used_set):
    """Propose up to ``beam_size`` one-token edits of ``sentence``.

    :return: (list of edited input-id arrays, replaced position) or ([], None)
        when every editable position is already in ``used_set``.
    """
    inputs = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
    token_ids = inputs['input_ids'][0]
    with torch.no_grad():
        probs_negative = bert_mlm_negative(**inputs).logits.softmax(dim=-1)[0]
        probs_positive = bert_mlm_positive(**inputs).logits.softmax(dim=-1)[0]
        positions = torch.arange(token_ids.shape[0], device=token_ids.device)
        p_pos = probs_positive[positions, token_ids]
        p_neg = probs_negative[positions, token_ids]
        # Ascending ratio: tokens the positive MLM finds least plausible relative
        # to the negative MLM come first.
        order = ((p_pos + epsilon) / (p_neg + epsilon)).argsort().tolist()
        ids = token_ids.cpu().numpy()
        last = len(ids) - 1
        vocab_size = probs_positive.shape[-1]
        for pos in order:
            # Skip the first/last positions (special tokens added by the
            # tokenizer) and positions edited in earlier steps.
            if pos == 0 or pos == last or pos in used_set:
                continue
            original_id = int(ids[pos])
            # One extra candidate so that skipping the original token still
            # leaves ``beam_size`` real replacements.
            top_ids = probs_positive[pos].topk(min(beam_size + 1, vocab_size)).indices.tolist()
            candidates = []
            for replacement_id in top_ids:
                if replacement_id == original_id:
                    continue
                new_ids = ids.copy()
                new_ids[pos] = replacement_id
                candidates.append(new_ids)
            return candidates[:beam_size], pos
    return [], None


def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """Expand every hypothesis in the beam by a single-token replacement.

    For each sentence in ``current_beam``, the position whose token has the
    smallest (p_pos + epsilon) / (p_neg + epsilon) ratio under the two style
    MLMs (excluding special and already-used positions) is replaced with up
    to ``beam_size`` of the positive MLM's top predictions. New hypotheses are
    scored with the classifier, merged with the old beam, and the best
    ``beam_size`` by positive probability are kept.

    :param current_beam: dict mapping sentence -> positive probability
    :param beam_size: hypotheses kept, and replacements tried per position
    :param epsilon: smoothing added to both probabilities in the ratio
    :param used_positions: token positions that must not be replaced again
    :return: (new_beam, used_position). When no new hypothesis can be built,
        returns (current_beam, None).
    """
    used_set = {int(p) for p in (used_positions or [])}
    bert_mlm_positive.eval()
    bert_mlm_negative.eval()

    candidate_ids = []
    used_position = None
    for sentence in current_beam:
        candidates, pos = _propose_replacements(sentence, beam_size, epsilon, used_set)
        if pos is not None:
            # NOTE(review): with several hypotheses in the beam only the position
            # chosen for the last one is reported, as in the original design.
            used_position = pos
        candidate_ids.extend(candidates)

    if not candidate_ids:
        st.write("No more new hypotheses")
        return current_beam, None

    decoded = [tokenizer.decode(new_ids[1:-1].tolist()) for new_ids in candidate_ids]
    scored = {sent: get_classifier_prob(sent)[1] for sent in decoded}
    # Keep previous hypotheses so the beam can never get worse.
    scored.update(current_beam)
    ranked = sorted(scored.items(), key=lambda el: el[1], reverse=True)[:beam_size]
    return dict(ranked), used_position


def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
    """Run up to ``max_steps`` beam-search steps, rendering each step in Streamlit.

    Stops early when no new hypotheses can be generated.

    :return: (final beam dict sentence -> positive probability,
        list of token positions replaced so far)
    """
    current_beam = {sentence: get_classifier_prob(sentence)[1]}
    used_poss = []
    st.write("step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})")
    st.write(rf"$\qquad${sentence}")
    for step in range(max_steps):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
        st.write(f"\nstep #{step+1}:")
        for i, (sent, prob) in enumerate(current_beam.items()):
            shown = highlight_diff(sent, sentence) if pretty_output else sent
            st.write(f"-- {i+1}: (positive probability ~ {round(prob, 5)})")
            st.write(rf"$\qquad${shown}")
        if used_pos is None:
            return current_beam, used_poss
        used_poss.append(used_pos)
    return current_beam, used_poss


st.title("Correcting opinions of fellow comrades")
default_value = "write your review here (in lower case - vocab reasons)"
sentence = st.text_area("Text", default_value, height=275)
beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
prettyfy = st.sidebar.slider("Highlight changes", value=0, min_value=0, max_value=1, step=1)
beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=bool(prettyfy))