# NOTE(review): the original first lines ("Spaces:" / "Runtime error" x2) were
# Hugging Face Spaces page residue from scraping, not source code — removed.
# Fix: every line carried trailing " | |" table-scrape residue, which made the
# whole file a syntax error. Restored a conventional import block.
import streamlit as st
import torch
import transformers
from copy import copy

st.markdown('Welcome, citizen!')
# Cache the model loading across Streamlit reruns: the script is re-executed on
# every widget interaction, so an uncached call would re-download/reload all
# three BERT checkpoints each time the user types. `st.cache_resource` is the
# modern API; fall back to the legacy `st.cache` on older Streamlit versions.
_cache_models = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_models
def get_models():
    """Load the frozen sentiment classifier, both masked LMs and the tokenizer.

    :return: tuple ``(classifier, bert_mlm_positive, bert_mlm_negative, tokenizer)``,
        all models on CPU and switched to eval mode (``train(False)``).
    """
    classifier = transformers.BertForSequenceClassification.from_pretrained(
        "vovaf709/bert_classifier", return_dict=True, num_labels=2
    ).cpu().train(False)
    bert_mlm_positive = transformers.BertForMaskedLM.from_pretrained(
        'vovaf709/bert_mlm_positive', return_dict=True
    ).cpu().train(False)
    bert_mlm_negative = transformers.BertForMaskedLM.from_pretrained(
        'vovaf709/bert_mlm_negative', return_dict=True
    ).cpu().train(False)
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    return classifier, bert_mlm_positive, bert_mlm_negative, tokenizer
def bert_INGSOC_enhancer(sentence, num_tokens, k_best, n_iter=1):
    """Iteratively rewrite ``sentence`` to maximize the classifier's positive score.

    Each iteration generates candidate replacements via ``get_replacements``
    and keeps the candidate the classifier scores highest on the positive class.

    :param sentence: input sentence to "enhance"
    :param num_tokens: how many token positions to try replacing per iteration
    :param k_best: how many replacement words to try per position
    :param n_iter: number of greedy improvement rounds
    :return: the best-scoring sentence found
    """
    winner = sentence
    for _ in range(n_iter):
        candidates = get_replacements(winner, num_tokens, k_best)
        if not candidates:
            break  # nothing to try; keep the current winner
        best_score = float("-inf")
        best = None
        for cand in candidates:
            sentence_ix = tokenizer(cand, return_tensors='pt')
            # probability of the positive class (index 1)
            score = classifier(**sentence_ix).logits.softmax(dim=1)[0][1].item()
            if score > best_score:
                best = cand
                best_score = score
        winner = best
    # BUG FIX: the original returned `cand` — the *last* candidate scored in the
    # final iteration — instead of `winner`, the best-scoring one.
    return winner
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - find :num_tokens: positions whose token the negative MLM favors most
      relative to the positive MLM (highest (p_neg + eps) / (p_pos + eps))
    - replace each with the :k_best: top words according to bert_mlm_positive
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    bert_mlm_positive.train(False)
    bert_mlm_negative.train(False)

    encoded = tokenizer(sentence, return_tensors='pt')
    token_ids = encoded['input_ids'][0]
    n = len(token_ids)

    # Per-position vocabulary distributions from each MLM.
    probs_pos = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
    probs_neg = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each model assigns to the token actually present.
    positions = torch.arange(n)
    p_here_pos = probs_pos[positions, token_ids]
    p_here_neg = probs_neg[positions, token_ids]
    ratio = (p_here_neg + epsilon) / (p_here_pos + epsilon)

    # Decode each token id individually so we can splice replacements in.
    token_strings = [tokenizer.decode([tid]) for tid in token_ids.cpu().numpy()]

    worst_positions = ratio.topk(num_tokens)[1]
    # num_tokens x k_best suggested replacement ids from the positive MLM
    suggestion_ids = probs_pos[worst_positions].topk(k_best, dim=-1)[1]

    candidates = []
    for row, target in enumerate(worst_positions):
        for new_id in suggestion_ids[row]:
            variant = copy(token_strings)
            variant[target] = tokenizer.decode([new_id])
            # drop [CLS]/[SEP] before joining back into a sentence
            candidates.append(" ".join(variant[1:-1]))
    return candidates
classifier, bert_mlm_positive, bert_mlm_negative, tokenizer = get_models() | |
user_input = st.text_input('What are you thinking about?') | |
length = len(user_input.split()) | |
if length > 0: | |
st.markdown(f'INGSOC sertified: {bert_INGSOC_enhancer(user_input, min([length, 4]), 4, 5)}') | |