ingsoc_censor / app.py
system's picture
system HF staff
Application code
44e043b
raw history blame
No virus
4.71 kB
import streamlit as st
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch
import numpy as np
# Run all models on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# NOTE(review): Streamlit documents set_page_config as having to be the first
# Streamlit command of the script - keep it above every other `st.*` call.
st.set_page_config(page_title="INGSOC review poster", layout="centered")
@st.cache
def load_models():
    """
    Build the tokenizer and the three BERT models used by the app.

    Every model is first instantiated from the stock 'bert-base-uncased'
    checkpoint, moved to ``device`` and switched out of training mode; its
    weights are then overwritten with the state dict stored in the matching
    local ``pytorch_model.bin`` file.

    :return: a 4-tuple ``(tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)``
    """
    # NOTE(review): the st.cache decorator is presumably here so that Streamlit
    # reruns reuse the already loaded models - confirm against the Streamlit
    # version this app is deployed with.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # (model class, checkpoint file) in the order the models are returned.
    specs = (
        (BertForMaskedLM, 'bert_mlm_positive/pytorch_model.bin'),
        (BertForMaskedLM, 'bert_mlm_negative/pytorch_model.bin'),
        (BertForSequenceClassification, 'bert_classifier/pytorch_model.bin'),
    )

    # Instantiate all models before touching any checkpoint file.
    models = []
    for model_cls, _ in specs:
        model = model_cls.from_pretrained('bert-base-uncased', return_dict=True)
        models.append(model.to(device).train(False))

    # Replace the stock weights with the fine-tuned ones.
    for model, (_, checkpoint_path) in zip(models, specs):
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)

    mlm_positive, mlm_negative, classifier = models
    return tokenizer, mlm_positive, mlm_negative, classifier
# Module-level handles used by get_replacements (tokenizer and both masked-LMs)
# and by get_positiveness (tokenizer and classifier).
tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose single-token rewrites of ``sentence``.

    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - score every token with ``(p_positive + epsilon) / (p_negative + epsilon)``,
      where ``p_*`` is the probability the respective masked-LM assigns to the
      token that is actually present at that position (the input is not masked)
    - pick the :num_tokens: tokens with the LOWEST score, i.e. the ones
      bert_mlm_negative favours most relative to bert_mlm_positive; the first
      and last positions (the special tokens added by the tokenizer) are never
      picked
    - replace each picked token with the :k_best: most probable tokens
      according to bert_mlm_positive

    :param sentence: text to rewrite
    :param num_tokens: number of positions to try replacing
    :param k_best: number of replacement tokens tried per position
    :param epsilon: smoothing term added to both probabilities before dividing
    :return: a list of decoded strings (up to k_best * num_tokens); each one
        differs from the tokenized input in at most one token. Empty when
        ``num_tokens`` or ``k_best`` is not positive.
    """
    # Without this guard a zero k_best would make the `[-k_best:]` slice below
    # select the whole vocabulary instead of nothing.
    if num_tokens <= 0 or k_best <= 0:
        return []

    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    token_ids = encoded['input_ids'][0]
    positions = torch.arange(token_ids.shape[0], device=token_ids.device)

    # Inference only: skip building autograd graphs for the two forward passes.
    with torch.no_grad():
        positive_probs = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        negative_probs = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each model gives to the token actually in the sentence.
    positive_token_probs = positive_probs[positions, token_ids]
    negative_token_probs = negative_probs[positions, token_ids]
    scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)

    # argsort is ascending, so the head holds the lowest ratios. The slice drops
    # the first/last position and `+ 1` maps back to indices in the full sequence.
    candidates = (torch.argsort(scores[1:-1])[:num_tokens] + 1).tolist()

    # Work with plain Python ints from here on so nothing depends on the device
    # the tensors live on.
    template = token_ids.tolist()
    new_variants = []
    for position in candidates:
        # Tail of an ascending argsort = the k_best most probable token ids.
        top_replaces = torch.argsort(positive_probs[position])[-k_best:].tolist()
        for replacement in top_replaces:
            variant = list(template)
            variant[position] = replacement
            new_variants.append(variant)

    # Strip the first/last (special) token before decoding back to text.
    return [tokenizer.decode(variant[1:-1]) for variant in new_variants]
def get_positiveness(sentence):
    """
    Score ``sentence`` with bert_classifier.

    :param sentence: text to score
    :return: the first logit (index 0) the classifier produces for the
        sentence, as a Python float
    """
    bert_classifier.eval()
    encoded = tokenizer(sentence, return_tensors='pt')
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    # Inference only - no gradients needed.
    with torch.no_grad():
        output = bert_classifier(**model_inputs)
    first_logit = output.logits[0, 0]
    return first_logit.item()
def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8):
    """
    Rewrite ``sentence`` with a beam search that maximizes get_positiveness.

    Each iteration expands every sentence in the beam with get_replacements,
    scores the previous beam together with the new suggestions, and keeps the
    ``beam_size`` highest-scoring sentences.

    :param sentence: text to rewrite
    :param num_iterations: number of expand-and-prune rounds
    :param num_tokens: forwarded to get_replacements
    :param k_best: forwarded to get_replacements
    :param beam_size: number of sentences kept after every round
    :return: the highest-scoring sentence in the final beam (``sentence``
        itself when ``num_iterations`` is 0)
    """
    variants = [sentence]
    for _ in range(num_iterations):
        suggestions = []
        for variant in variants:
            suggestions.extend(get_replacements(variant, num_tokens=num_tokens, k_best=k_best))
        # Keep the previous beam in the pool: a round can then never lower the
        # best score, so extra iterations are harmless. Identical strings are
        # merged (order-preserving) so they are scored once and cannot occupy
        # several beam slots.
        pool = list(dict.fromkeys(variants + suggestions))
        scores = np.array([get_positiveness(candidate) for candidate in pool])
        # argsort is ascending: the tail holds the beam_size best candidates,
        # with the very best one last.
        beam = np.argsort(scores)[-beam_size:]
        variants = [pool[ind] for ind in beam]
    # The beam is ordered worst -> best, so the answer is the last element.
    return variants[-1]
def process_review(review):
    """
    Run beam_rewrite on ``review`` with the settings stored in st.session_state.

    The keys 'num_iterations', 'beam_size', 'num_tokens' and 'k_best' are read
    from the session state and converted to int before being passed on.

    :param review: text to rewrite
    :return: whatever beam_rewrite returns for the review
    """
    state = st.session_state
    settings = {
        'num_iterations': int(state.num_iterations),
        'beam_size': int(state.beam_size),
        'num_tokens': int(state.num_tokens),
        'k_best': int(state.k_best),
    }
    return beam_rewrite(review, **settings)
# ---- Page body: rendered top to bottom in this order ----
st.markdown("# INGSOC-approved service for posting Your honest reviews!")
# The entered text is read back below through st.session_state.review.
st.text_input("Your honest review: ", key='review')
# Only rewrite once a non-empty review is present.
if st.session_state.review:
    with st.spinner('Wait for it...'):
        review = process_review(st.session_state.review)
    # str.capitalize upper-cases the first character and lower-cases the rest.
    review = review.capitalize()
    st.markdown("### Here is Your honest review:")
    st.markdown(f'## "{review}"')
# Search settings. process_review reads these four values from
# st.session_state by key, although the widgets are declared after the call
# above in script order.
# NOTE(review): this presumably relies on Streamlit keeping the widget values
# in session state from the previous run - confirm.
with st.expander("Only for class A412C citzens"):
    st.number_input('Number of beam search iterations: ',
                    min_value=1, max_value=20, value=5, key='num_iterations')
    st.number_input('Beam size: ',
                    min_value=1, max_value=20, value=8, key='beam_size')
    st.number_input('Number of tokens tested each iteration: ',
                    min_value=1, max_value=20, value=2, key='num_tokens')
    st.number_input('Number of best replacements tested each iteration: ',
                    min_value=1, max_value=20, value=3, key='k_best')