# heil-A.412C / app.py
import streamlit as st
import torch
import numpy as np
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch.nn.functional as F
from copy import copy
from torch.nn.functional import softmax
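
# Judging by the checkpoint names: two BERT MLMs fine-tuned on positive and
# negative reviews respectively, plus a binary sentiment classifier that is
# later used to re-rank candidate rewrites.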
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_mlm_positive = BertForMaskedLM.from_pretrained(
    'ewriji/heil-A.412C-positive', return_dict=True
)
bert_mlm_negative = BertForMaskedLM.from_pretrained(
    'ewriji/heil-A.412C-negative', return_dict=True
)
classification_model = BertForSequenceClassification.from_pretrained(
    'ewriji/heil-A.412C-classification', return_dict=True
)


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into words and mask each one in turn using the
      INGSOC-approved BERT tokenizer
    - find the :num_tokens: words with the lowest smoothed probability ratio
      (positive MLM vs. negative MLM), i.e. the most "negative"-looking words
    - replace each of them with :k_best: words suggested by bert_mlm_positive
    :return: a list of candidate sentences (up to k_best * num_tokens of them)
    """
    words = sentence.split()
    batch = []
    mask_word = []
    # mask out one word at a time, remembering the original word at each position
    for i in range(len(words)):
        masked = copy(words)
        mask_word.append(masked[i])
        masked[i] = tokenizer.mask_token
        batch.append(masked)
    inputs = tokenizer(batch, padding=True, return_tensors="pt", is_split_into_words=True)
    mask_ids = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero().cpu()
    # predict token probabilities at every masked position with both MLMs
    positive_logits = bert_mlm_positive(**inputs)
    negative_logits = bert_mlm_negative(**inputs)
    # id of the first word-piece of each original word
    word_idx = [tokenizer.encode(word, add_special_tokens=False)[0] for word in mask_word]
    positive_prob = softmax(
        positive_logits.logits[mask_ids[:, 0], mask_ids[:, 1]],
        dim=-1
    )
    positive_prob = positive_prob[np.arange(len(word_idx)), word_idx]
    negative_prob = softmax(
        negative_logits.logits[mask_ids[:, 0], mask_ids[:, 1]],
        dim=-1
    )
    negative_prob = negative_prob[np.arange(len(word_idx)), word_idx]
    ratio = (positive_prob + epsilon) / (negative_prob + epsilon)
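    # a low ratio means the original word is much more plausible under the negative
    # MLM than under the positive one, making it a good candidate for replacement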
    lowest_ratio = torch.topk(ratio, k=num_tokens, largest=False, dim=-1)
    # pick the k_best replacement candidates from bert_mlm_positive at each selected position
    logits_indices = mask_ids[lowest_ratio.indices]
    top_k_probs = positive_logits.logits[logits_indices[:, 0], logits_indices[:, 1]]
    top_k_probs = softmax(top_k_probs, dim=-1)
    top_k_probs = torch.topk(top_k_probs, k=k_best, dim=-1)
    # convert the candidate ids back to words for every low-ratio position
    top_k_words = []
    for i in range(top_k_probs.indices.shape[0]):
        top_words = tokenizer.convert_ids_to_tokens(top_k_probs.indices[i])
        top_k_words.append(top_words)
    # construct replaced sentences
    replaced_words = []
    for position, top_words in zip(lowest_ratio.indices, top_k_words):
        for word in top_words:
            replaced_sentence = copy(words)
            replaced_sentence[position] = word
            replaced_words.append(' '.join(replaced_sentence))
    return replaced_words
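
# Illustrative call (actual candidates depend on the fine-tuned MLMs):
#   get_replacements('the wait staff is horrible !', num_tokens=1, k_best=3)
# returns up to 1 * 3 = 3 sentences, each with the single lowest-ratio word
# swapped for one of bert_mlm_positive's top suggestions.
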
def evaluate_top(model, sentences):
    # score each candidate sentence with the classifier; the candidates are plain
    # strings here, so no is_split_into_words flag is needed
    predictions = []
    for sentence in sentences:
        inputs = tokenizer(sentence, padding=True, return_tensors="pt")
        prediction = model(**inputs)
        predictions.append(prediction.logits)
    predictions = torch.cat(predictions, dim=0)
    return predictions
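
# Re-rank the MLM candidates with the sentiment classifier: for each of the
# num_tokens masked positions, score its k_best candidate sentences and keep
# the m_best with the highest positive-class score (column 1 of the classifier
# logits is assumed to be the "positive" label).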
def get_replacements_with_classifier(model, sentence, num_tokens, k_best, m_best, epsilon=1e-3):
    replacements = get_replacements(sentence, num_tokens, k_best, epsilon=epsilon)
    top_m_replacements = []
    for i in range(num_tokens):
        top_k = replacements[i * k_best: (i + 1) * k_best]
        top_k_predictions = evaluate_top(model, top_k)[:, 1].flatten()
        top_m_prediction_idx = torch.topk(top_k_predictions, k=m_best)
        for idx in top_m_prediction_idx.indices:
            top_m_replacements.append(top_k[idx])
    return top_m_replacements
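
# As used in the UI below, this keeps a single rewrite built from 20 candidates
# at one masked position, i.e. num_tokens=1, k_best=20, m_best=1.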
st.set_page_config(page_title="A + B calculator pro max", layout="centered")
st.markdown("## Dude, let's convert some negative vibes to positive")

negative = st.text_input("Gimme ya review", value='great wings and decent drinks but the wait staff is horrible !')
positive = get_replacements_with_classifier(
    classification_model,
    negative,
    1,
    20,
    1
)[0]
st.text(positive)