Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification | |
import torch | |
import numpy as np | |
# Use the first CUDA GPU when one is present, otherwise run everything on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Browser tab title and page layout; issued before any other st.* call in this script.
st.set_page_config(layout="centered", page_title="INGSOC review poster")
def _resource_cache(func):
    """Memoize ``func`` in Streamlit's process-wide resource cache.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would re-read every checkpoint on each rerun. Newer
    Streamlit releases provide ``st.cache_resource``. Older ones only have
    ``st.cache``, where ``allow_output_mutation=True`` skips hashing the
    returned model objects.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


def _load_finetuned(model_cls, checkpoint_path):
    """Return a ``bert-base-uncased`` ``model_cls`` on ``device``, in eval mode,
    with its weights replaced by the state dict stored at ``checkpoint_path``."""
    model = model_cls.from_pretrained('bert-base-uncased', return_dict=True)
    model = model.to(device).train(False)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model


@_resource_cache
def load_models():
    """Load the tokenizer and the three fine-tuned BERT models (once per process).

    :return: ``(tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)``
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_mlm_positive = _load_finetuned(BertForMaskedLM, 'bert_mlm_positive/pytorch_model.bin')
    bert_mlm_negative = _load_finetuned(BertForMaskedLM, 'bert_mlm_negative/pytorch_model.bin')
    bert_classifier = _load_finetuned(BertForSequenceClassification, 'bert_classifier/pytorch_model.bin')
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token edits of ``sentence`` that should read more positively.

    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - score every token by ``(P_pos + epsilon) / (P_neg + epsilon)``, where
      ``P_pos`` / ``P_neg`` are the probabilities that ``bert_mlm_positive`` /
      ``bert_mlm_negative`` assign to the token actually present at that position
    - pick the :num_tokens: tokens with the LOWEST ratio, i.e. the tokens the
      negative model favours most relative to the positive one
    - for each picked token, build one variant per replacement, using the
      :k_best: most likely tokens at that position according to
      ``bert_mlm_positive``

    The first and last positions (the special tokens added by the tokenizer)
    are never replaced and are dropped when decoding.

    :param epsilon: smoothing added to both probabilities so the ratio stays finite
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    token_ids = encoded['input_ids'][0]
    positions = torch.arange(token_ids.shape[0], device=token_ids.device)

    # Inference only: avoid building autograd graphs for two full BERT passes.
    with torch.no_grad():
        positive_probs = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        negative_probs = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    positive_token_probs = positive_probs[positions, token_ids]
    negative_token_probs = negative_probs[positions, token_ids]
    scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)

    # Rank only the inner tokens; +1 maps the slice index back to the full sequence.
    candidates = (torch.argsort(scores[1:-1])[:num_tokens] + 1).tolist()

    # topk(k=0) yields nothing, whereas the slice [-0:] would yield the whole vocabulary.
    k = max(0, min(k_best, positive_probs.shape[-1]))

    # Plain Python ints: writing device tensors into a numpy array fails on CUDA.
    template = token_ids.tolist()
    new_variants = []
    for position in candidates:
        top_replaces = torch.topk(positive_probs[position], k).indices.tolist()
        for replace in top_replaces:
            new_variant = list(template)
            new_variant[position] = replace
            new_variants.append(new_variant)
    return [tokenizer.decode(variant[1:-1]) for variant in new_variants]
def get_positiveness(sentence):
    """Return the classifier's raw logit for label 0 on ``sentence`` as a float.

    NOTE(review): beam search keeps the highest values of this score, so label 0
    is presumably the "positive" class of the fine-tuned classifier; confirm
    against the training setup.
    """
    bert_classifier.eval()
    batch = tokenizer(sentence, return_tensors='pt')
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        logits = bert_classifier(**batch).logits
    return logits[0, 0].item()
def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8,
                 *, propose=None, score=None):
    """Beam-search single-token rewrites of ``sentence`` and return the highest-scoring one.

    Each iteration expands every beam member with ``propose`` and pools the
    results with the current beam. The pool is de-duplicated, and the
    ``beam_size`` highest-scoring strings form the next beam.

    :param num_iterations: number of expansion rounds (0 returns ``sentence`` unchanged)
    :param num_tokens: forwarded to ``propose``: how many tokens to try replacing
    :param k_best: forwarded to ``propose``: replacements tried per token
    :param beam_size: beam width; values below 1 are treated as 1
    :param propose: ``propose(text, num_tokens=..., k_best=...) -> list[str]``;
        defaults to ``get_replacements``
    :param score: ``score(text) -> float``, higher is better; defaults to
        ``get_positiveness``
    :return: the best-scoring string found
    """
    if propose is None:
        propose = get_replacements
    if score is None:
        score = get_positiveness
    beam_size = max(1, beam_size)

    # A string's score does not change between iterations, so score each one only once.
    score_cache = {}

    def cached_score(text):
        if text not in score_cache:
            score_cache[text] = score(text)
        return score_cache[text]

    beam = [sentence]
    for _ in range(num_iterations):
        # Keep the current beam in the pool: a variant only leaves the beam when
        # something scores higher, so extra iterations never make the result worse.
        pool = list(beam)
        for variant in beam:
            pool.extend(propose(variant, num_tokens=num_tokens, k_best=k_best))
        pool = list(dict.fromkeys(pool))  # order-preserving de-duplication
        beam = sorted(pool, key=cached_score, reverse=True)[:beam_size]
    # The beam is sorted best-first.
    return beam[0]
def process_review(review):
    """Rewrite ``review`` with ``beam_rewrite`` using the search settings from the UI.

    The settings are read from ``st.session_state`` under the keys of the
    number inputs created later in this script. If a key is not present, the
    value falls back to that widget's default.
    """
    settings = st.session_state
    num_iterations = int(settings.get('num_iterations', 5))
    beam_size = int(settings.get('beam_size', 8))
    num_tokens = int(settings.get('num_tokens', 2))
    k_best = int(settings.get('k_best', 3))
    return beam_rewrite(review,
                        num_iterations=num_iterations,
                        num_tokens=num_tokens,
                        k_best=k_best,
                        beam_size=beam_size)
st.markdown("# INGSOC-approved service for posting Your honest reviews!")
st.text_input("Your honest review: ", key='review')

# Whitespace-only input has nothing to rewrite, so treat it like an empty box.
if st.session_state.review.strip():
    with st.spinner('Wait for it...'):
        review = process_review(st.session_state.review)
        review = review.capitalize()
    st.markdown("### Here is Your honest review:")
    st.markdown(f'## "{review}"')

# Beam-search settings; process_review reads them back from st.session_state by key.
with st.expander("Only for class A412C citizens"):
    st.number_input('Number of beam search iterations: ',
                    min_value=1, max_value=20, value=5, key='num_iterations')
    st.number_input('Beam size: ',
                    min_value=1, max_value=20, value=8, key='beam_size')
    st.number_input('Number of tokens tested each iteration: ',
                    min_value=1, max_value=20, value=2, key='num_tokens')
    st.number_input('Number of best replacements tested each iteration: ',
                    min_value=1, max_value=20, value=3, key='k_best')