import streamlit as st
import torch
import numpy as np
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
from copy import copy
from torch.nn.functional import softmax

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Masked LM finetuned on positive reviews, masked LM finetuned on negative reviews,
# and a sentiment classifier used to re-rank candidate rewrites.
bert_mlm_positive = BertForMaskedLM.from_pretrained(
    'ewriji/heil-A.412C-positive', return_dict=True
)
bert_mlm_negative = BertForMaskedLM.from_pretrained(
    'ewriji/heil-A.412C-negative', return_dict=True
)
classification_model = BertForSequenceClassification.from_pretrained(
    'ewriji/heil-A.412C-classification', return_dict=True
)


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into words and run them through the INGSOC-approved BERT tokenizer
    - find :num_tokens: words with the lowest P_positive / P_negative ratio
    - replace each of them with :k_best: words according to bert_mlm_positive
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    words = sentence.split()

    # Build a batch where each entry has exactly one word replaced by [MASK].
    batch = []
    mask_word = []
    for i in range(len(words)):
        masked = copy(words)
        mask_word.append(masked[i])
        masked[i] = tokenizer.mask_token
        batch.append(masked)

    inputs = tokenizer(batch, padding=True, return_tensors="pt", is_split_into_words=True)
    mask_ids = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero().cpu()

    # Predict token probabilities at the masked positions with both MLMs.
    with torch.no_grad():
        positive_logits = bert_mlm_positive(**inputs)
        negative_logits = bert_mlm_negative(**inputs)

    # First wordpiece id of every original (masked-out) word.
    word_idx = [tokenizer.encode(word, add_special_tokens=False)[0] for word in mask_word]

    positive_prob = softmax(
        positive_logits.logits[mask_ids[:, 0], mask_ids[:, 1]], dim=-1
    )
    positive_prob = positive_prob[np.arange(len(word_idx)), word_idx]

    negative_prob = softmax(
        negative_logits.logits[mask_ids[:, 0], mask_ids[:, 1]], dim=-1
    )
    negative_prob = negative_prob[np.arange(len(word_idx)), word_idx]

    # Words that the negative MLM prefers much more than the positive one get the
    # lowest ratio and are selected for replacement.
    ratio = (positive_prob + epsilon) / (negative_prob + epsilon)
    lowest_ratio = torch.topk(ratio, k=num_tokens, largest=False, dim=-1)

    # Pick the k_best replacement candidates proposed by the positive MLM at those positions.
    logits_indices = mask_ids[lowest_ratio.indices]
    top_k_probs = positive_logits.logits[logits_indices[:, 0], logits_indices[:, 1]]
    top_k_probs = softmax(top_k_probs, dim=-1)
    top_k_probs = torch.topk(top_k_probs, k=k_best, dim=-1)

    # Convert the candidate token ids back to words for every selected position.
    top_k_words = []
    for i in range(top_k_probs.indices.shape[0]):
        top_words = tokenizer.convert_ids_to_tokens(top_k_probs.indices[i].tolist())
        top_k_words.append(top_words)

    # Construct the rewritten sentences.
    replaced_sentences = []
    for replace_idx, top_words in zip(lowest_ratio.indices, top_k_words):
        for word in top_words:
            replaced_sentence = copy(words)
            replaced_sentence[replace_idx] = word
            replaced_sentences.append(' '.join(replaced_sentence))

    return replaced_sentences


def evaluate_top(model, sentences):
    """Return the classifier logits for every candidate sentence."""
    predictions = []
    with torch.no_grad():
        for sentence in sentences:
            # The candidates are plain strings here, so no is_split_into_words flag.
            inputs = tokenizer(sentence, padding=True, return_tensors="pt")
            prediction = model(**inputs)
            predictions.append(prediction.logits)
    predictions = torch.cat(predictions, dim=0)
    return predictions


def get_replacements_with_classifier(model, sentence, num_tokens, k_best, m_best, epsilon=1e-3):
    """Generate k_best candidates per masked position, then keep the m_best of them
    with the highest positive-class logit (column 1 of the classifier output)."""
    replacements = get_replacements(sentence, num_tokens, k_best, epsilon=epsilon)

    top_m_replacements = []
    for i in range(num_tokens):
        top_k = replacements[i * k_best: (i + 1) * k_best]
        top_k_predictions = evaluate_top(model, top_k)[:, 1].flatten()
        top_m_prediction_idx = torch.topk(top_k_predictions, k=m_best)
        for idx in top_m_prediction_idx.indices:
            top_m_replacements.append(top_k[idx])

    return top_m_replacements


st.set_page_config(page_title="A + B calculator pro max", layout="centered")
st.markdown("## Dude, let's convert some negative vibes to positive")

negative = st.text_input(
    "Gimme ya review",
    value='great wings and decent drinks but the wait staff is horrible !'
)
positive = get_replacements_with_classifier(classification_model, negative, 1, 20, 1)[0]
st.text(positive)