ingsoc_censor / app.py
system's picture
system HF staff
Application code
44e043b
import streamlit as st
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch
import numpy as np
# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Browser-tab title and page layout; this is the first st.* call in the script.
st.set_page_config(layout="centered", page_title="INGSOC review poster")
@st.cache(allow_output_mutation=True)
def load_models():
    """Load the BERT tokenizer and the three fine-tuned BERT models used by the app.

    Each model is built from the ``bert-base-uncased`` architecture, moved to
    ``device``, switched to inference mode (``train(False)``), and then has its
    weights replaced by the local checkpoint ``<name>/pytorch_model.bin``.

    The result is cached by Streamlit so reruns reuse the same objects.
    ``allow_output_mutation=True`` stops the legacy ``st.cache`` from hashing
    the returned models on every rerun to check them for mutation. Hashing
    three BERT-sized models on each interaction is slow and serves no purpose
    here. On Streamlit >= 1.18 the dedicated replacement is
    ``st.cache_resource``.

    :return: (tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def _load(model_cls, checkpoint_dir):
        # Build the base architecture on `device` in eval mode, then overwrite
        # its weights with the fine-tuned checkpoint from `checkpoint_dir`.
        model = model_cls.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
        model.load_state_dict(torch.load(f'{checkpoint_dir}/pytorch_model.bin', map_location=device))
        return model

    bert_mlm_positive = _load(BertForMaskedLM, 'bert_mlm_positive')
    bert_mlm_negative = _load(BertForMaskedLM, 'bert_mlm_negative')
    bert_classifier = _load(BertForSequenceClassification, 'bert_classifier')
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose single-token rewrites of `sentence` that bert_mlm_positive prefers.

    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - score every token by (P_pos + epsilon) / (P_neg + epsilon), where P_pos and
      P_neg are the probabilities bert_mlm_positive and bert_mlm_negative assign
      to that (unmasked) token at its own position
    - pick the :num_tokens: tokens with the LOWEST score, i.e. the ones the
      negative model likes most relative to the positive one; the first and last
      positions (the special tokens the tokenizer adds) are never picked
    - replace each picked token, one at a time, with each of the :k_best: tokens
      bert_mlm_positive rates most likely at that position
    :param epsilon: smoothing term that keeps the ratio finite when a probability is ~0
    :return: a list of all candidate strings (up to k_best * num_tokens); each
             differs from the tokenized input in at most one token position.
             Empty when num_tokens < 1 or k_best < 1.
    """
    # Guard: with k_best == 0 the slice [-0:] below would select the whole vocabulary.
    if num_tokens < 1 or k_best < 1:
        return []

    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    token_ids = encoded['input_ids'][0]
    positions = torch.arange(token_ids.shape[0], device=token_ids.device)

    # Inference only: no_grad avoids building autograd graphs for two full
    # BERT forward passes on every call.
    with torch.no_grad():
        positive_probs = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        negative_probs = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    positive_token_probs = positive_probs[positions, token_ids]
    negative_token_probs = negative_probs[positions, token_ids]
    scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)
    # Skip the first and last positions; +1 maps slice indices back to token positions.
    # .tolist() yields plain ints, which are safe to use as numpy indices on any device.
    candidates = (torch.argsort(scores[1:-1])[:num_tokens] + 1).tolist()

    template = token_ids.cpu().numpy()
    new_variants = []
    for candidate in candidates:
        # argsort is ascending, so the last k_best ids are the most probable ones.
        top_replaces = torch.argsort(positive_probs[candidate])[-k_best:].tolist()
        for replace in top_replaces:
            new_variant = template.copy()
            new_variant[candidate] = replace
            new_variants.append(new_variant)
    # Drop the first/last special tokens before decoding back to text.
    return [tokenizer.decode(variant[1:-1]) for variant in new_variants]
def get_positiveness(sentence):
    """Return bert_classifier's raw logit for class 0 on `sentence` as a float.

    NOTE(review): callers treat larger values as "more positive", which assumes
    class 0 is the positive label of the fine-tuned classifier. The label
    mapping is not visible here; confirm it against the training setup.
    """
    bert_classifier.eval()
    encoded = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        logits = bert_classifier(**on_device).logits
    return logits[0, 0].item()
def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8,
                 *, replace_fn=None, score_fn=None):
    """Rewrite `sentence` towards a higher score using beam search over token edits.

    Each round expands every sentence in the beam with the candidates from
    `replace_fn`. The current beam stays in the pool, so a sentence that cannot
    be improved survives and `num_iterations` needs no careful tuning. Duplicate
    strings are dropped, and the `beam_size` highest-scoring sentences are kept.

    :param sentence: text to rewrite
    :param num_iterations: number of expand-and-prune rounds; 0 returns `sentence` unchanged
    :param num_tokens: forwarded to `replace_fn` (token positions tried per sentence)
    :param k_best: forwarded to `replace_fn` (replacements tried per position)
    :param beam_size: sentences kept after each round; values < 1 are treated as 1
    :param replace_fn: candidate generator called as
        ``replace_fn(text, num_tokens=..., k_best=...)``; defaults to get_replacements
    :param score_fn: ``text -> float``, higher is better; defaults to get_positiveness
    :return: the highest-scoring sentence in the final beam
    """
    if replace_fn is None:
        replace_fn = get_replacements
    if score_fn is None:
        score_fn = get_positiveness

    score_cache = {}

    def _score(text):
        # Scoring is a model forward pass and the same strings recur across
        # rounds (the beam is carried over), so score each string only once.
        if text not in score_cache:
            score_cache[text] = score_fn(text)
        return score_cache[text]

    beam = [sentence]
    for _ in range(num_iterations):
        pool = list(beam)
        for variant in beam:
            pool.extend(replace_fn(variant, num_tokens=num_tokens, k_best=k_best))
        # dict.fromkeys removes duplicates while keeping first-seen order, so the
        # beam is not filled with copies of the same sentence.
        pool = list(dict.fromkeys(pool))
        # Stable descending sort: best candidate first.
        pool.sort(key=_score, reverse=True)
        beam = pool[:max(1, beam_size)]
    # The beam is ordered best-first, so index 0 is the best sentence found.
    return beam[0]
def process_review(review):
    """Rewrite `review` with beam_rewrite, using the search settings chosen in the UI.

    The settings are read from st.session_state under the number_input widget
    keys 'num_iterations', 'num_tokens', 'k_best' and 'beam_size'. The widgets
    are declared later in the script than the call to this function, so a key
    may be missing from session state. In that case the setting is not passed
    and beam_rewrite's own default applies, instead of raising AttributeError.
    """
    setting_names = ('num_iterations', 'num_tokens', 'k_best', 'beam_size')
    settings = {
        name: int(st.session_state[name])
        for name in setting_names
        if name in st.session_state
    }
    return beam_rewrite(review, **settings)
st.markdown("# INGSOC-approved service for posting Your honest reviews!")
st.text_input("Your honest review: ", key='review')
if st.session_state.review:
    # Several BERT forward passes per iteration; show a spinner while they run.
    with st.spinner('Wait for it...'):
        review = process_review(st.session_state.review)
    # str.capitalize upper-cases the first character and lower-cases the rest.
    review = review.capitalize()
    st.markdown("### Here is Your honest review:")
    st.markdown(f'## "{review}"')
# Beam-search settings; process_review reads them back through st.session_state
# using the widget keys below.
with st.expander("Only for class A412C citizens"):
    st.number_input('Number of beam search iterations: ',
                    min_value=1, max_value=20, value=5, key='num_iterations')
    st.number_input('Beam size: ',
                    min_value=1, max_value=20, value=8, key='beam_size')
    st.number_input('Number of tokens tested each iteration: ',
                    min_value=1, max_value=20, value=2, key='num_tokens')
    st.number_input('Number of best replacements tested each iteration: ',
                    min_value=1, max_value=20, value=3, key='k_best')