import torch
import streamlit as st
import transformers
from transformers import BertTokenizer, BertForMaskedLM
from transformers import BertForSequenceClassification, DataCollatorWithPadding

st.set_page_config(page_title="style transfer", layout="centered")
st.markdown("Welcome to text style transfer. Wait a few seconds for the model to load...")

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _cache_resource(func):
    """Cache ``func``'s result across Streamlit reruns when ``st.cache_resource`` exists.

    Streamlit re-executes this whole script on every widget interaction. Without a
    cache, all three BERT models would be rebuilt and their checkpoints reloaded each
    time. On Streamlit versions without ``st.cache_resource``, ``func`` is returned
    unchanged, which keeps the original (uncached) behaviour.
    """
    cache = getattr(st, "cache_resource", None)
    return func if cache is None else cache(func)


@_cache_resource
def _load_models(device_name):
    """Build the tokenizer and the three BERT models on ``device_name`` in eval mode.

    The positive and negative masked-LM models receive their weights from
    ``mlm_positive.pth`` and ``mlm_negative.pth`` in the working directory.

    :return: ``(tokenizer, mlm_positive, mlm_negative, classifier)``
    """
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device_name).train(False)
    mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device_name).train(False)
    classifier = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', return_dict=True, problem_type="multi_label_classification", num_labels=2
    ).to(device_name).train(False)
    # torch.load unpickles the file, so only point it at checkpoints you trust.
    mlm_positive.load_state_dict(torch.load('mlm_positive.pth', map_location=torch.device('cpu')))
    mlm_negative.load_state_dict(torch.load('mlm_negative.pth', map_location=torch.device('cpu')))
    # NOTE(review): no classifier checkpoint is loaded. The sequence-classification
    # head is therefore most likely freshly initialised, so its scores may not be
    # meaningful. To enable the fine-tuned classifier:
    # classifier.load_state_dict(torch.load('bert_cls.pth', map_location=torch.device('cpu')))
    return tok, mlm_positive, mlm_negative, classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_cls = _load_models(device)


@torch.no_grad()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Generate single-position edits of ``sentence`` that push it toward the positive style.

    Steps:

    - Tokenize ``sentence`` with the BERT tokenizer.
    - Score every token by ``(p_pos + epsilon) / (p_neg + epsilon)``. Here ``p_pos``
      and ``p_neg`` are the probabilities that ``bert_mlm_positive`` and
      ``bert_mlm_negative`` assign to the token actually present at that position.
    - Pick the ``num_tokens`` non-special tokens with the LOWEST ratio, i.e. the
      tokens the negative model favours most relative to the positive one.
    - Replace each picked token, one at a time, with each of the ``k_best`` tokens
      that ``bert_mlm_positive`` rates most likely at that position.

    :param sentence: input text.
    :param num_tokens: number of positions to try replacing.
    :param k_best: number of replacement candidates per position.
    :param epsilon: smoothing term that keeps the ratio finite for tiny probabilities.
    :return: decoded candidate sentences (at most ``num_tokens * k_best``), decoded
        without the leading/trailing special tokens. Each candidate has exactly one
        position overwritten; the replacement may equal the original token.
        Returns an empty list if the sentence has no content tokens.
    """
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    input_ids = encoded['input_ids'][0]
    positions = torch.arange(len(input_ids), device=input_ids.device)

    probs_positive = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
    probs_negative = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]
    p_tokens_positive = probs_positive[positions, input_ids]
    p_tokens_negative = probs_negative[positions, input_ids]
    p_relative = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)

    # Exclude the first ([CLS]) and last ([SEP]) positions.
    # "+ 1" maps slice indices back to full-sequence positions.
    best_pos = torch.argsort(p_relative[1:-1], dim=0)[:num_tokens] + 1
    best_pos_tokens = probs_positive.topk(k_best, dim=-1).indices

    # Build every candidate from a fresh copy of the ids. On CPU,
    # ``input_ids.cpu().numpy()`` is a view of the same storage, so writing into it
    # would mutate ``input_ids`` and leak earlier replacements into later candidates.
    base_ids = input_ids.tolist()
    res = []
    for pos in best_pos.tolist():
        for replace_token in best_pos_tokens[pos].tolist():
            new_ids = list(base_ids)
            new_ids[pos] = replace_token
            res.append(tokenizer.decode(new_ids[1:-1]))
    return res


@torch.no_grad()
def _positive_score(text):
    """Return ``bert_cls``'s softmax probability of label 1 for ``text`` as a float."""
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # NOTE(review): the classifier is configured with
    # problem_type="multi_label_classification", which pairs with sigmoid outputs,
    # but candidates are ranked by softmax here. Confirm this is intended.
    return bert_cls(**inputs).logits.softmax(dim=-1)[0][1].item()


@torch.no_grad()
def beamSearch(sentence, n_rounds=5, num_tokens=3, k_best=3, debug=False):
    """Greedily rewrite ``sentence`` over ``n_rounds`` rounds.

    Each round generates candidates with ``get_replacements``. It then keeps the
    candidate with the highest ``bert_cls`` probability for label 1; the first one
    wins on ties. Despite the name, this is a beam of width 1 (greedy search).

    :param sentence: starting text.
    :param n_rounds: number of replacement rounds.
    :param num_tokens: positions tried per round (forwarded to ``get_replacements``).
    :param k_best: replacement candidates per position (forwarded to ``get_replacements``).
    :param debug: if true, write the sentence chosen in each round to the page.
    :return: the final sentence. If a round yields no candidates, the current
        sentence is returned unchanged.
    """
    for _ in range(n_rounds):
        candidates = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not candidates:
            # Nothing to replace (no content tokens). Stop here rather than letting
            # ``sentence`` become None and crash the next round.
            break
        # max() returns the first maximal element, matching a strict ">" scan.
        sentence = max(candidates, key=_positive_score)
        if debug:
            st.markdown(f"cur_sentence: {sentence}")
    return sentence


user_input = st.text_input("Please enter something review")
n_rounds = st.slider("Pick a number of rounds in beamSearch", 1, 10, value=5)
k_best = st.slider("Pick k_best parameter", 1, 5, value=3)
num_tokens = st.slider("Pick num_tokens parameter", 1, 5, value=3)
debug = st.radio("print intermediate steps", [True, False])

if user_input.strip():
    res = beamSearch(user_input, n_rounds=n_rounds, num_tokens=num_tokens, k_best=k_best, debug=debug)
    st.markdown("Processed review:")
    st.markdown(f"{res}")