# style-transfer / app.py
# Last changed by vosatorp in commit 2ce3f0c ("Update app.py").
import torch
import streamlit as st
import transformers
from transformers import BertTokenizer, BertForMaskedLM
from transformers import BertForSequenceClassification, DataCollatorWithPadding
# Page chrome goes first: Streamlit expects set_page_config before other st.* calls.
st.set_page_config(page_title="style transfer", layout="centered")
st.markdown("Welcome to text style transfer. Wait a few seconds for the model to load...")

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def _pretrained_for_inference(model_class, **extra_kwargs):
    """Load a ``bert-base-uncased`` model of *model_class*, move it to ``device``, set eval mode."""
    model = model_class.from_pretrained('bert-base-uncased', return_dict=True, **extra_kwargs)
    return model.to(device).eval()


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Both masked-LM heads start from identical base weights; the script later overwrites
# them with the mlm_positive.pth / mlm_negative.pth checkpoints.
# NOTE(review): Streamlit re-executes the script on every widget interaction, so these
# loads presumably repeat per rerun; consider st.cache_resource if the installed
# Streamlit version supports it.
bert_mlm_positive = _pretrained_for_inference(BertForMaskedLM)
bert_mlm_negative = _pretrained_for_inference(BertForMaskedLM)
bert_cls = _pretrained_for_inference(
    BertForSequenceClassification,
    problem_type="multi_label_classification",
    num_labels=2,
)
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token rewrites of *sentence* toward the positive style.

    Steps:

    - tokenize *sentence* with the shared BERT ``tokenizer``;
    - for every token, take the probability each masked-LM head assigns to the
      token actually present and compute
      ``(p_positive + epsilon) / (p_negative + epsilon)``;
    - pick the ``num_tokens`` positions with the *lowest* ratio, i.e. the tokens
      the negative model favours most relative to the positive one (the first
      and last special tokens are never picked);
    - for each picked position, substitute each of the ``k_best`` tokens that
      ``bert_mlm_positive`` rates most likely at that position.

    Every candidate differs from the tokenized input in exactly one position.

    :param sentence: text to rewrite.
    :param num_tokens: number of positions to try replacing.
    :param k_best: number of replacement tokens to try per position.
    :param epsilon: smoothing added to both probabilities before taking the ratio.
    :return: decoded candidate strings grouped by position (up to
        ``k_best * num_tokens``); empty when there is nothing to replace.
    """
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    input_ids = encoded['input_ids'][0]
    # Keep the index tensor on the same device as the tensors it indexes.
    positions = torch.arange(len(input_ids), device=input_ids.device)

    # Inference only: skip autograd bookkeeping for the two forward passes.
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each model assigns to the token actually present at each position.
    p_tokens_positive = probs_positive[positions, input_ids]
    p_tokens_negative = probs_negative[positions, input_ids]
    p_relative = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)

    # Exclude the first/last special tokens; +1 maps slice indices back to input_ids.
    best_pos = torch.argsort(p_relative[1:-1], dim=0)[:num_tokens] + 1
    # topk only over the chosen rows (handles k_best == 0 as "no replacements");
    # flip keeps replacements in ascending-probability order.
    best_pos_tokens = torch.topk(probs_positive[best_pos], k_best, dim=-1).indices.flip(-1)

    base_ids = input_ids.tolist()
    res = []
    for pos, replacement_ids in zip(best_pos.tolist(), best_pos_tokens.tolist()):
        for replace_token in replacement_ids:
            # Fresh copy per candidate so earlier substitutions never leak into later ones.
            new_ids = list(base_ids)
            new_ids[pos] = replace_token
            # Drop the leading/trailing special tokens before decoding.
            res.append(tokenizer.decode(new_ids[1:-1]))
    return res
def beamSearch(sentence, n_rounds=5):
    """Iteratively rewrite *sentence* toward the class ``bert_cls`` scores at index 1.

    Despite the name, only one hypothesis is kept (beam width 1). Each round
    generates candidates with ``get_replacements`` and keeps the candidate whose
    softmax probability at class index 1 under ``bert_cls`` is highest. The
    winner replaces the current sentence even if it scores lower than it.
    (Index 1 is presumably the "positive" class; confirm against training.)

    Reads the module-level ``num_tokens``, ``k_best`` and ``debug`` values that
    the Streamlit widgets set.

    :param sentence: text to rewrite.
    :param n_rounds: maximum number of replacement rounds.
    :return: the rewritten sentence; the input is returned unchanged when no
        candidates can be generated.
    """
    for _ in range(n_rounds):
        candidates = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not candidates:
            # Nothing replaceable: every further round would produce the same empty set.
            break
        # Score all candidates in one padded batch; no labels are passed, so no loss is computed.
        with torch.no_grad():
            inputs = tokenizer(candidates, return_tensors="pt", padding=True).to(device)
            prob_good = bert_cls(**inputs).logits.softmax(dim=-1)[:, 1]
        # argmax returns the first maximal index, i.e. the earliest best candidate wins ties.
        sentence = candidates[int(torch.argmax(prob_good))]
        if debug:
            st.markdown(f"cur_sentence: {sentence}")
    return sentence
# Fine-tuned style weights; map_location remaps the stored tensors onto the CPU while loading.
bert_mlm_positive.load_state_dict(torch.load('mlm_positive.pth', map_location=torch.device('cpu')))
bert_mlm_negative.load_state_dict(torch.load('mlm_negative.pth', map_location=torch.device('cpu')))
# NOTE(review): the classifier checkpoint load is disabled, so bert_cls keeps whatever
# from_pretrained produced (its classification head is presumably untrained); confirm
# before trusting beamSearch's ranking.
# bert_cls.load_state_dict(torch.load('bert_cls.pth', map_location=torch.device('cpu')))

# Widgets, in display order.
user_input = st.text_input("Please enter something review")
n_rounds = st.slider("Pick a number of rounds in beamSearch", min_value=1, max_value=10, value=5)
k_best = st.slider("Pick k_best parameter", min_value=1, max_value=5, value=3)
num_tokens = st.slider("Pick num_tokens parameter", min_value=1, max_value=5, value=3)
debug = st.radio("print intermediate steps", [True, False])

# Run the search only once the input contains something other than whitespace.
if user_input.strip():
    res = beamSearch(user_input, n_rounds=n_rounds)
    st.markdown("Processed review:")
    st.markdown(f"{res}")