# style_transfer/app.py — author: Philipenko Vladimir, commit f998fc5
import streamlit as st
import torch
import transformers
from copy import copy
# Greeting rendered at the top of the page (1984-themed app).
st.markdown('Welcome, citizen!')
@st.cache()
def get_models():
    """Load and return (classifier, positive MLM, negative MLM, tokenizer).

    All three BERT models are moved to CPU and switched to eval mode;
    the result is cached by streamlit so the download happens once.
    """
    clf = transformers.BertForSequenceClassification.from_pretrained(
        "vovaf709/bert_classifier", return_dict=True, num_labels=2)
    mlm_pos = transformers.BertForMaskedLM.from_pretrained(
        'vovaf709/bert_mlm_positive', return_dict=True)
    mlm_neg = transformers.BertForMaskedLM.from_pretrained(
        'vovaf709/bert_mlm_negative', return_dict=True)
    tok = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    return (clf.cpu().train(False),
            mlm_pos.cpu().train(False),
            mlm_neg.cpu().train(False),
            tok)
def bert_INGSOC_enhancer(sentence, num_tokens, k_best, n_iter=1):
    """Iteratively rewrite *sentence* to maximize the positive-class score.

    Each of the :n_iter: rounds generates candidate rewrites via
    get_replacements and keeps the one the classifier scores highest
    on the positive class (logit index 1).

    :param sentence: input string to "enhance"
    :param num_tokens: how many tokens to consider replacing per round
    :param k_best: how many replacement words per token position
    :param n_iter: number of greedy refinement rounds
    :return: the best-scoring rewrite found (the original sentence if no
             candidate ever beats it)
    """
    winner = sentence
    for _ in range(n_iter):
        candidates = get_replacements(winner, num_tokens, k_best)
        best_score = float('-inf')
        # BUG FIX: the original returned `cand` (the *last* candidate) and
        # reset `winner` to None, which could crash on an empty candidate
        # list. Keep the current winner as the fallback instead.
        best_cand = winner
        for cand in candidates:
            cand_ix = tokenizer(cand, return_tensors='pt')
            # Probability of the positive class for this candidate.
            score = classifier(**cand_ix).logits.softmax(dim=1)[0][1].item()
            if score > best_score:
                best_cand = cand
                best_score = score
        winner = best_cand
    return winner
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - find :num_tokens: tokens with the highest ratio (see above)
    - replace them with :k_best: words according to bert_mlm_positive
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    bert_mlm_positive.train(False)
    bert_mlm_negative.train(False)

    encoded = tokenizer(sentence, return_tensors='pt')
    token_ids = encoded['input_ids'][0]
    n_positions = len(token_ids)

    # Per-position vocabulary distributions under each language model.
    pos_probs = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
    neg_probs = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]
    positions = torch.arange(n_positions)
    p_pos = pos_probs[positions, token_ids]
    p_neg = neg_probs[positions, token_ids]

    # High ratio = the negative LM likes this token much more than the
    # positive one does, so it is a good candidate for replacement.
    ratio = (p_neg + epsilon) / (p_pos + epsilon)

    original_tokens = [tokenizer.decode([t]) for t in token_ids.cpu().numpy()]
    worst_positions = ratio.topk(num_tokens)[1]
    # num_tokens x k_best replacement ids suggested by the positive LM.
    candidate_ids = pos_probs[worst_positions].topk(k_best, dim=-1)[1]

    variants = []
    for row, position in enumerate(worst_positions):
        for new_id in candidate_ids[row]:
            tokens = copy(original_tokens)
            tokens[position] = tokenizer.decode([new_id])
            # Drop [CLS]/[SEP] (first and last tokens) when joining.
            variants.append(" ".join(tokens[1:-1]))
    return variants
# Load models once (cached by streamlit) and expose them as the module
# globals used by bert_INGSOC_enhancer / get_replacements above.
classifier, bert_mlm_positive, bert_mlm_negative, tokenizer = get_models()

user_input = st.text_input('What are you thinking about?')
length = len(user_input.split())
if length > 0:
    # Replace up to 4 of the most "negative" tokens, 4 candidate words
    # each, over 5 greedy refinement passes.
    # Fixed user-facing typo ("sertified" -> "certified") and the
    # non-idiomatic min([length, 4]).
    st.markdown(f'INGSOC certified: {bert_INGSOC_enhancer(user_input, min(length, 4), 4, 5)}')