import streamlit as st
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch
import numpy as np

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

st.set_page_config(page_title="INGSOC review poster", layout="centered")


@st.cache(allow_output_mutation=True)  # the returned models/tokenizer are not hashable, so skip output hashing
def load_models():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
    bert_mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
    bert_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)

    # Load the fine-tuned weights on top of the base architectures.
    bert_mlm_positive.load_state_dict(torch.load('bert_mlm_positive/pytorch_model.bin', map_location=device))
    bert_mlm_negative.load_state_dict(torch.load('bert_mlm_negative/pytorch_model.bin', map_location=device))
    bert_classifier.load_state_dict(torch.load('bert_classifier/pytorch_model.bin', map_location=device))

    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - find :num_tokens: tokens that bert_mlm_negative likes much more than bert_mlm_positive
      (i.e. the lowest positive/negative probability ratio)
    - replace each of them with :k_best: words according to bert_mlm_positive
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    inputs = tokenizer(sentence, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    sent_len = inputs['input_ids'].shape[1]

    with torch.no_grad():
        positive_probs = bert_mlm_positive(**inputs).logits.softmax(dim=-1)[0]
        negative_probs = bert_mlm_negative(**inputs).logits.softmax(dim=-1)[0]

    # Probability each MLM assigns to the token that is actually in the sentence.
    positive_token_probs = positive_probs[torch.arange(sent_len), inputs['input_ids'][0]]
    negative_token_probs = negative_probs[torch.arange(sent_len), inputs['input_ids'][0]]

    # A low score means the positive model dislikes this token relative to the negative one,
    # so it is a good candidate for replacement. Skip the [CLS] and [SEP] positions.
    scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)
    candidates = torch.argsort(scores[1:-1])[:num_tokens] + 1

    new_variants = []
    template = inputs['input_ids'][0].cpu().numpy()
    for candidate in candidates:
        top_replaces = torch.argsort(positive_probs[candidate])[-k_best:]
        for replace in top_replaces:
            new_variant = template.copy()
            new_variant[candidate.item()] = replace.item()
            new_variants.append(new_variant)

    # Decode without the [CLS] / [SEP] special tokens.
    return [tokenizer.decode(variant[1:-1]) for variant in new_variants]


def get_positiveness(sentence):
    bert_classifier.eval()
    with torch.no_grad():
        tokenized = tokenizer(sentence, return_tensors='pt')
        tokenized = {key: value.to(device) for key, value in tokenized.items()}
        res = bert_classifier(**tokenized)
    return res.logits[0][0].item()


def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8):
    variants = [sentence]
    for _ in range(num_iterations):
        suggestions = []
        for variant in variants:
            suggestions.extend(get_replacements(variant, num_tokens=num_tokens, k_best=k_best))
        # Keep the old variants as candidates too, so extra iterations can never hurt the result.
        variants.extend(suggestions)

        scores = np.array([get_positiveness(variant) for variant in variants])
        beam = np.argsort(scores)[-beam_size:]  # ascending order: the best variant ends up last
        variants = [variants[ind] for ind in beam]
    return variants[-1]  # highest-scored variant in the final beam


def process_review(review):
    num_iterations = int(st.session_state.num_iterations)
    beam_size = int(st.session_state.beam_size)
    num_tokens = int(st.session_state.num_tokens)
    k_best = int(st.session_state.k_best)
    return beam_rewrite(review, num_iterations=num_iterations, num_tokens=num_tokens,
                        k_best=k_best, beam_size=beam_size)


st.markdown("# INGSOC-approved service for posting Your honest reviews!")
st.text_input("Your honest review: ", key='review')

if st.session_state.review:
    with st.spinner('Wait for it...'):
        review = process_review(st.session_state.review)
    review = review.capitalize()
    st.markdown("### Here is Your honest review:")
    st.markdown(f'## "{review}"')

with st.expander("Only for class A412C citizens"):
    st.number_input('Number of beam search iterations: ', min_value=1, max_value=20, value=5, key='num_iterations')
    st.number_input('Beam size: ', min_value=1, max_value=20, value=8, key='beam_size')
    st.number_input('Number of tokens tested each iteration: ', min_value=1, max_value=20, value=2, key='num_tokens')
    st.number_input('Number of best replacements tested each iteration: ', min_value=1, max_value=20, value=3, key='k_best')
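# Launch sketch (assumptions: this script is saved as app.py, and the fine-tuned checkpoints
# referenced in load_models() live in ./bert_mlm_positive, ./bert_mlm_negative and
# ./bert_classifier next to it):
#
#   streamlit run app.py
#
# `streamlit run` is the standard Streamlit CLI entry point; the file name app.py is only
# an example, not something fixed by this code.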