# Hugging Face Spaces app: BERT-based text style transfer demo.
# (Removed page-scrape residue: "Spaces:" header and "Runtime error" banners.)
import torch
import streamlit as st
import transformers
from transformers import BertTokenizer, BertForMaskedLM
from transformers import BertForSequenceClassification, DataCollatorWithPadding

# --- Page setup -------------------------------------------------------------
st.set_page_config(page_title="style transfer", layout="centered")
st.markdown("Welcome to text style transfer. Wait a few seconds for the model to load...")

# Run on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One WordPiece tokenizer shared by all three BERT models below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Two masked-LM heads; their fine-tuned weights (positive / negative review
# corpora — presumably, judging by the checkpoint names) are loaded further
# down via load_state_dict.  .train(False) = eval mode (disables dropout).
bert_mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
bert_mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)

# Sentiment classifier used to rank candidate sentences during beam search.
bert_cls = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', return_dict=True, problem_type="multi_label_classification", num_labels=2
).to(device).train(False)
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose candidate rewrites of `sentence` that sound more positive.

    - tokenize the sentence with the shared BERT tokenizer
    - score every token by p_positive / p_negative (probability of the token
      under each masked LM) and pick the :num_tokens: positions with the
      LOWEST ratio, i.e. the most negative-sounding tokens
    - substitute each picked token with the :k_best: most likely words
      according to bert_mlm_positive

    Uses module-level `tokenizer`, `bert_mlm_positive`, `bert_mlm_negative`
    and `device`.

    :param sentence: input text to rewrite
    :param num_tokens: number of token positions to try replacing
    :param k_best: number of replacement candidates per position
    :param epsilon: smoothing constant that avoids division by zero
    :return: a list of candidate strings (up to k_best * num_tokens)
    """
    candidates = []
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    token_ids = encoded['input_ids'][0]
    length = len(token_ids)

    # Inference only — no gradients needed.
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    # Probability each original token receives under each LM.
    p_tokens_positive = probs_positive[torch.arange(length), token_ids]
    p_tokens_negative = probs_negative[torch.arange(length), token_ids]
    p_relative = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)

    # [1:-1] skips [CLS] and [SEP]; +1 restores absolute positions.
    # argsort is ascending, so this selects the lowest-ratio positions.
    best_pos = torch.argsort(p_relative[1:-1], dim=0)[:num_tokens] + 1
    # For every position, the k_best highest-probability replacement tokens
    # under the positive LM (last k_best columns of an ascending argsort).
    best_pos_tokens = torch.argsort(probs_positive, dim=1)[..., -k_best:]

    for pos in best_pos:
        for replace_token in best_pos_tokens[pos]:
            # BUG FIX: .cpu().numpy() shares memory with the tensor when the
            # model runs on CPU, so writing into it mutated the original
            # token ids and corrupted every later candidate.  Copy instead.
            new_ids = token_ids.cpu().numpy().copy()
            new_ids[pos] = replace_token
            # Drop [CLS]/[SEP] before decoding back to text.
            candidates.append(tokenizer.decode(new_ids[1:-1]))
    return candidates
def beamSearch(sentence, n_rounds=5):
    """
    Greedy iterative rewriting (width-1 "beam" search).

    Each round generates candidate replacements with get_replacements and
    keeps the one the sentiment classifier scores as most positive.

    Relies on module-level `num_tokens`, `k_best` and `debug` set by the
    Streamlit widgets below, plus `tokenizer`, `bert_cls` and `device`.

    :param sentence: the review to rewrite
    :param n_rounds: number of rewriting iterations
    :return: the rewritten sentence
    """
    for _ in range(n_rounds):
        cur_res = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        max_prob = -1
        best_sentence = None
        for candidate_sentence in cur_res:
            inputs = tokenizer(candidate_sentence, return_tensors="pt").to(device)
            # No labels kwarg: we only need logits, not a training loss;
            # no_grad because this is pure inference.
            with torch.no_grad():
                outputs = bert_cls(**inputs)
            # Probability the candidate is classified as positive (class 1).
            prob_good = outputs.logits.softmax(dim=-1)[0][1]
            if prob_good > max_prob:
                max_prob = prob_good
                best_sentence = candidate_sentence
        if best_sentence is None:
            # No candidates at all — stop early instead of feeding None
            # back into the tokenizer on the next round.
            break
        if debug:
            st.markdown(f"cur_sentence: {best_sentence}")
        sentence = best_sentence
    return sentence
# Swap in the fine-tuned masked-LM weights.  map_location='cpu' keeps the
# load working on machines without a GPU; load_state_dict copies the values
# into the (possibly CUDA) model parameters in place.
bert_mlm_positive.load_state_dict(torch.load('mlm_positive.pth', map_location=torch.device('cpu')))
bert_mlm_negative.load_state_dict(torch.load('mlm_negative.pth', map_location=torch.device('cpu')))
# bert_cls.load_state_dict(torch.load('bert_cls.pth', map_location=torch.device('cpu')))

# --- UI controls (these module-level names are read by beamSearch) ----------
user_input = st.text_input("Please enter something review")
n_rounds = st.slider("Pick a number of rounds in beamSearch", 1, 10, value=5)
k_best = st.slider("Pick k_best parameter", 1, 5, value=3)
num_tokens = st.slider("Pick num_tokens parameter", 1, 5, value=3)
debug = st.radio("print intermediate steps", [True, False])

# Only run once the user has typed a non-empty review.
if len(user_input.split()) > 0:
    res = beamSearch(user_input, n_rounds=n_rounds)
    st.markdown("Processed review:")
    st.markdown(f"{res}")