import streamlit as st
import torch
import transformers
from copy import copy

st.markdown('Welcome, citizen!')


@st.cache()
def get_models():
    """Load and cache all models on CPU in eval mode.

    :return: (classifier, bert_mlm_positive, bert_mlm_negative, tokenizer)
    """
    classifier = transformers.BertForSequenceClassification.from_pretrained(
        "vovaf709/bert_classifier", return_dict=True, num_labels=2
    ).cpu().train(False)
    bert_mlm_positive = transformers.BertForMaskedLM.from_pretrained(
        'vovaf709/bert_mlm_positive', return_dict=True
    ).cpu().train(False)
    bert_mlm_negative = transformers.BertForMaskedLM.from_pretrained(
        'vovaf709/bert_mlm_negative', return_dict=True
    ).cpu().train(False)
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    return classifier, bert_mlm_positive, bert_mlm_negative, tokenizer


def bert_INGSOC_enhancer(sentence, num_tokens, k_best, n_iter=1):
    """Iteratively rewrite ``sentence`` to maximize the classifier's positive score.

    Each iteration generates candidate replacements for the current best
    sentence via :func:`get_replacements` and keeps the candidate the
    classifier scores highest for class 1 ("positive").

    :param sentence: input sentence to improve
    :param num_tokens: how many tokens to consider replacing per iteration
    :param k_best: how many replacement words to try per token
    :param n_iter: number of greedy improvement rounds
    :return: the best-scoring sentence found
    """
    winner = sentence
    for _ in range(n_iter):
        candidates = get_replacements(winner, num_tokens, k_best)
        if not candidates:
            # Nothing to try — keep the current best sentence.
            break
        best_score = float('-inf')  # original used -1e1337, i.e. -inf
        best_candidate = None
        for candidate in candidates:
            candidate_ix = tokenizer(candidate, return_tensors='pt')
            with torch.no_grad():  # inference only — skip autograd bookkeeping
                score = classifier(**candidate_ix).logits.softmax(dim=1)[0][1].item()
            if score > best_score:
                best_candidate = candidate
                best_score = score
        winner = best_candidate
    # BUG FIX: the original returned `cand` — the *last* candidate examined —
    # instead of the best-scoring one, discarding the whole argmax loop.
    return winner


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - find :num_tokens: tokens with the highest negative/positive probability ratio
    - replace them with :k_best: words according to bert_mlm_positive
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    bert_mlm_positive.train(False)
    bert_mlm_negative.train(False)

    sentence_ix = tokenizer(sentence, return_tensors='pt')
    length = len(sentence_ix['input_ids'][0])

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        probs_positive = bert_mlm_positive(**sentence_ix).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**sentence_ix).logits.softmax(dim=-1)[0]

    token_ids = sentence_ix['input_ids'][0]
    p_tokens_positive = probs_positive[torch.arange(length), token_ids]
    p_tokens_negative = probs_negative[torch.arange(length), token_ids]

    # How much more "negative" than "positive" each token looks;
    # epsilon guards against division by (near-)zero probabilities.
    diff = (p_tokens_negative + epsilon) / (p_tokens_positive + epsilon)

    # Decode each token id individually so replacements can be spliced in.
    precious_tokens = [tokenizer.decode([t]) for t in token_ids.cpu().numpy()]

    topk = diff.topk(num_tokens)[1]
    # num_tokens x k_best: most likely positive-MLM words at each chosen position
    replace_ids = probs_positive[topk].topk(k_best, dim=-1)[1]

    result = []
    for i, replace_me in enumerate(topk):
        for replace_by in replace_ids[i]:
            new_tokens = copy(precious_tokens)
            new_tokens[replace_me] = tokenizer.decode([replace_by])
            # Strip the [CLS]/[SEP] special tokens at either end.
            result.append(" ".join(new_tokens[1:-1]))
    return result


classifier, bert_mlm_positive, bert_mlm_negative, tokenizer = get_models()

user_input = st.text_input('What are you thinking about?')
num_words = len(user_input.split())
if num_words > 0:
    st.markdown(f'INGSOC sertified: {bert_INGSOC_enhancer(user_input, min(num_words, 4), 4, 5)}')