Spaces:
Runtime error
Runtime error
| import torch | |
| import streamlit as st | |
| import transformers | |
| from transformers import BertTokenizer, BertForMaskedLM | |
| from transformers import BertForSequenceClassification, DataCollatorWithPadding | |
# --- Page setup and model initialisation ---
st.set_page_config(page_title="style transfer", layout="centered")
st.markdown("Welcome to text style transfer. Wait a few seconds for the model to load...")

# Prefer GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Two masked LMs in inference mode; fine-tuned weights are loaded further down.
bert_mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).eval()
bert_mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).eval()

# Sentiment classifier used to rank candidate sentences, also in inference mode.
bert_cls = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    return_dict=True,
    problem_type="multi_label_classification",
    num_labels=2,
).to(device).eval()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose candidate rewrites of ``sentence`` that should sound more positive.

    - tokenize the sentence with the shared BERT tokenizer
    - score each token by (p_positive + epsilon) / (p_negative + epsilon),
      using the two fine-tuned masked LMs
    - pick the ``num_tokens`` positions with the LOWEST ratio (the tokens the
      negative LM likes much more than the positive LM does)
    - at each such position, substitute the ``k_best`` tokens the positive LM
      prefers there

    :param sentence: input text to rewrite
    :param num_tokens: number of token positions to try replacing
    :param k_best: number of replacement tokens per position
    :param epsilon: smoothing term keeping the ratio finite near zero probability
    :return: a list of candidate strings (up to k_best * num_tokens)
    """
    res = []
    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        sentence_ix = tokenizer(sentence, return_tensors='pt')
        sentence_ix = {key: value.to(device) for key, value in sentence_ix.items()}
        length = len(sentence_ix['input_ids'][0])
        probs_positive = bert_mlm_positive(**sentence_ix).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**sentence_ix).logits.softmax(dim=-1)[0]
        # Probability each LM assigns to the token actually present at each position.
        p_tokens_positive = probs_positive[torch.arange(length), sentence_ix['input_ids'][0]]
        p_tokens_negative = probs_negative[torch.arange(length), sentence_ix['input_ids'][0]]
        p_relative = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)
        # Skip [CLS] (index 0) and [SEP] (last); ascending argsort puts the
        # most "negative-sounding" positions first. +1 restores absolute index.
        best_pos = torch.argsort(p_relative[1:-1], dim=0)[:num_tokens] + 1
        # Top-k_best replacement token ids per position under the positive LM.
        best_pos_tokens = torch.argsort(probs_positive, dim=1)[..., -k_best:]
        for pos in best_pos:
            for replace_token in best_pos_tokens[pos]:
                new_tensor = sentence_ix['input_ids'][0].cpu().numpy()
                new_tensor[pos] = replace_token
                # Drop [CLS]/[SEP] before decoding back to text.
                res.append(tokenizer.decode(new_tensor[1:-1]))
    return res
def beamSearch(sentence, n_rounds=5):
    """
    Greedily rewrite ``sentence`` for ``n_rounds`` rounds.

    Each round generates one-token-edit candidates via get_replacements and
    keeps the candidate the classifier scores as most positive (index 1 of
    the softmaxed logits).

    NOTE(review): ``num_tokens``, ``k_best`` and ``debug`` are read from
    module-level globals set by the Streamlit widgets below.

    :param sentence: input text to rewrite
    :param n_rounds: number of greedy improvement rounds
    :return: the rewritten sentence
    """
    # Pure inference: no gradients needed. The original also passed a dummy
    # ``labels`` tensor to bert_cls, which only forced a useless loss
    # computation -- the logits we rank by are unaffected, so it is dropped.
    with torch.no_grad():
        for _ in range(n_rounds):
            candidates = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
            if not candidates:
                break  # nothing to try -- keep the current sentence
            max_prob = -1.0
            best_sentence = None
            for candidate_sentence in candidates:
                inputs = tokenizer(candidate_sentence, return_tensors="pt").to(device)
                outputs = bert_cls(**inputs)
                prob_good = outputs.logits.softmax(dim=-1)[0][1]
                if prob_good > max_prob:
                    max_prob = prob_good
                    best_sentence = candidate_sentence
            if debug:
                st.markdown(f"cur_sentence: {best_sentence}")
            sentence = best_sentence
    return sentence
# Load the fine-tuned masked-LM checkpoints (shipped alongside the app);
# map_location='cpu' lets the same files load on CPU-only hosts.
bert_mlm_positive.load_state_dict(torch.load('mlm_positive.pth', map_location=torch.device('cpu')))
bert_mlm_negative.load_state_dict(torch.load('mlm_negative.pth', map_location=torch.device('cpu')))
# NOTE(review): the classifier checkpoint is deliberately (?) not loaded, so
# bert_cls runs with base pretrained weights -- confirm whether 'bert_cls.pth'
# should be enabled.
# bert_cls.load_state_dict(torch.load('bert_cls.pth', map_location=torch.device('cpu')))
# --- Streamlit UI: collect input and hyper-parameters, then run the search ---
user_input = st.text_input("Please enter a review")  # fixed broken label text
n_rounds = st.slider("Pick a number of rounds in beamSearch", 1, 10, value=5)
k_best = st.slider("Pick k_best parameter", 1, 5, value=3)
num_tokens = st.slider("Pick num_tokens parameter", 1, 5, value=3)
debug = st.radio("print intermediate steps", [True, False])

# Run only once the user typed something non-blank (idiomatic truthiness check
# replaces len(user_input.split()) > 0 -- same behavior).
if user_input.strip():
    res = beamSearch(user_input, n_rounds=n_rounds)
    st.markdown("Processed review:")
    st.markdown(f"{res}")