File size: 3,890 Bytes
2ce3f0c
ca802b1
2ce3f0c
adbb296
2ce3f0c
 
adbb296
2ce3f0c
 
adbb296
2ce3f0c
adbb296
2ce3f0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbb296
2ce3f0c
 
 
 
 
 
 
adbb296
2ce3f0c
 
 
 
 
 
 
 
 
 
adbb296
 
2ce3f0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05de302
 
2ce3f0c
 
 
adbb296
2ce3f0c
adbb296
2ce3f0c
 
 
 
7cb2bc6
adbb296
2ce3f0c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import streamlit as st
import transformers

from transformers import BertTokenizer, BertForMaskedLM
from transformers import BertForSequenceClassification, DataCollatorWithPadding

# Page chrome first: set_page_config is issued before any other Streamlit call.
st.set_page_config(page_title="style transfer", layout="centered")
st.markdown("Welcome to text style transfer. Wait a few seconds for the model to load...")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Every model below starts from the same pretrained checkpoint.
_BASE_CHECKPOINT = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(_BASE_CHECKPOINT)

# Two masked-LM copies with identical initial weights, moved to `device` and
# switched to inference mode (`.eval()` is `train(False)`).
# NOTE(review): Streamlit re-executes the script on each interaction, so these
# loads presumably repeat every time; consider caching them.
bert_mlm_positive = BertForMaskedLM.from_pretrained(_BASE_CHECKPOINT, return_dict=True)
bert_mlm_positive = bert_mlm_positive.to(device).eval()

bert_mlm_negative = BertForMaskedLM.from_pretrained(_BASE_CHECKPOINT, return_dict=True)
bert_mlm_negative = bert_mlm_negative.to(device).eval()

# Two-label sequence classifier used to score candidate sentences.
bert_cls = BertForSequenceClassification.from_pretrained(
    _BASE_CHECKPOINT,
    return_dict=True,
    problem_type="multi_label_classification",
    num_labels=2,
)
bert_cls = bert_cls.to(device).eval()

def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    Propose single-token rewrites of ``sentence``.

    - tokenize the sentence with the module-level BERT ``tokenizer``
    - for every token compute the smoothed ratio
      ``(p_positive + epsilon) / (p_negative + epsilon)``, where each ``p`` is
      the probability the corresponding masked-LM assigns to the token that is
      actually present at that position
    - pick the :num_tokens: positions with the LOWEST ratio, skipping the
      first and last positions (the special tokens added by the tokenizer)
    - replace each picked position, one at a time, with each of the :k_best:
      tokens that ``bert_mlm_positive`` rates most probable there

    :param sentence: text to rewrite
    :param num_tokens: maximum number of positions to edit
    :param k_best: number of replacement tokens tried per position
    :param epsilon: additive smoothing for the probability ratio
    :return: a list of decoded strings (up to k_best * num_tokens); each one
        differs from the input in exactly one token position
    """
    encoded = tokenizer(sentence, return_tensors='pt')
    encoded = {key: value.to(device) for key, value in encoded.items()}
    input_ids = encoded['input_ids'][0]
    # Build the row index on the same device as the tensors it indexes.
    positions = torch.arange(len(input_ids), device=input_ids.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

    p_tokens_positive = probs_positive[positions, input_ids]
    p_tokens_negative = probs_negative[positions, input_ids]
    p_relative = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)

    # Ascending argsort -> lowest ratios first; +1 undoes the [1:-1] offset.
    best_pos = (torch.argsort(p_relative[1:-1], dim=0)[:num_tokens] + 1).tolist()
    # Last k_best columns of an ascending argsort = the k_best likeliest tokens.
    best_pos_tokens = torch.argsort(probs_positive, dim=1)[..., -k_best:].tolist()

    # Work on plain Python ints. The previous `.cpu().numpy()` view shared
    # memory with the input tensor on CPU, so each in-place replacement leaked
    # into every later candidate.
    original_ids = input_ids.tolist()
    res = []
    for pos in best_pos:
        for replace_token in best_pos_tokens[pos]:
            candidate_ids = list(original_ids)
            candidate_ids[pos] = replace_token
            res.append(tokenizer.decode(candidate_ids[1:-1]))
    return res

def beamSearch(sentence, n_rounds=5):
    """
    Iteratively rewrite ``sentence``, keeping the best single-token edit per round.

    Each round asks ``get_replacements`` for candidate sentences, scores every
    candidate with ``bert_cls`` and continues from the candidate whose softmax
    probability for class index 1 is highest. Only one hypothesis is kept per
    round, so despite the name this is a greedy search.

    Reads the module-level ``num_tokens``, ``k_best`` and ``debug`` values,
    which must be defined before this function is called.

    :param sentence: text to rewrite
    :param n_rounds: maximum number of rewrite rounds
    :return: the sentence after the last completed round; if a round produces
        no usable candidate the search stops and the current sentence is returned
    """
    for _ in range(n_rounds):
        candidates = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not candidates:
            # Nothing to replace (e.g. no tokens between the special tokens):
            # keep what we have instead of propagating None into the next round.
            break

        max_prob = -1.0
        best_sentence = None
        # Scoring only: no labels (the loss was never used) and no autograd graph.
        with torch.no_grad():
            for candidate_sentence in candidates:
                inputs = tokenizer(candidate_sentence, return_tensors="pt").to(device)
                logits = bert_cls(**inputs).logits
                prob_good = logits.softmax(dim=-1)[0][1].item()
                if prob_good > max_prob:
                    max_prob = prob_good
                    best_sentence = candidate_sentence

        if best_sentence is None:
            # No candidate beat the initial score (only possible with NaN scores).
            break
        if debug:
            st.markdown(f"cur_sentence: {best_sentence}")
        sentence = best_sentence
    return sentence


# Overwrite the two masked-LM models with their checkpoints, in the same order
# as before (positive first). Tensors are mapped to CPU while being read.
for mlm_model, checkpoint_path in (
    (bert_mlm_positive, 'mlm_positive.pth'),
    (bert_mlm_negative, 'mlm_negative.pth'),
):
    mlm_model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
# NOTE(review): the classifier checkpoint load below is disabled, so bert_cls
# keeps whatever weights from_pretrained gave it.
# bert_cls.load_state_dict(torch.load('bert_cls.pth', map_location=torch.device('cpu')))

# Input widgets; beamSearch reads k_best, num_tokens and debug as globals.
user_input = st.text_input("Please enter something review")

n_rounds = st.slider("Pick a number of rounds in beamSearch", 1, 10, value=5)
k_best = st.slider("Pick k_best parameter", 1, 5, value=3)
num_tokens = st.slider("Pick num_tokens parameter", 1, 5, value=3)
debug = st.radio("print intermediate steps", [True, False])


# Run only once the user has typed at least one non-whitespace word.
if user_input.split():
    res = beamSearch(user_input, n_rounds=n_rounds)
    st.markdown("Processed review:")
    st.markdown(f"{res}")