import streamlit as st
from pathlib import Path
import torch
from transformers import BertTokenizer


# allow_output_mutation=True stops st.cache from re-hashing the tokenizer on
# every rerun (hashing large objects is slow and triggers mutation warnings)
@st.cache(allow_output_mutation=True)
def get_tokenizer():
    return BertTokenizer.from_pretrained('bert-base-uncased')

@st.cache(allow_output_mutation=True)
def load_model_bert_mlm_positive():
    f_checkpoint = Path("bert_mlm_positive.model")

    if not f_checkpoint.exists():
        with st.spinner("Downloading bert_mlm_positive... this may take a while! Don't stop it!"):
            from GD_download import download_file_from_google_drive
            cloud_model_location = "12Gvgv6zaOLJ8oyYXVB5_GYNEfvsjudG_"
            download_file_from_google_drive(cloud_model_location, f_checkpoint)

    model = torch.load(f_checkpoint, map_location=torch.device('cpu'))
    model.eval()  # inference only: disable dropout etc.
    return model

@st.cache(allow_output_mutation=True)
def load_model_model_seq_classify():
    f_checkpoint = Path("model_seq_classify.model")

    if not f_checkpoint.exists():
        with st.spinner("Downloading model_seq_classify... this may take a while! Don't stop it!"):
            from GD_download import download_file_from_google_drive
            cloud_model_location = "13DwlCIM6aYc4WeOCIRqdGy-U0LGc8f0B"
            download_file_from_google_drive(cloud_model_location, f_checkpoint)

    model = torch.load(f_checkpoint, map_location=torch.device('cpu'))
    model.eval()  # inference only: disable dropout etc.
    return model
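
# GD_download (imported lazily above) is a small local helper module that is
# not shown in this file. Purely as a hedged sketch, an assumption about what
# such a helper might look like rather than the project's actual code, a
# publicly shared Google Drive file can be streamed to disk with `requests`:
#
#     import requests
#
#     def download_file_from_google_drive(file_id, destination):
#         url = "https://docs.google.com/uc?export=download"
#         with requests.Session() as session:
#             response = session.get(
#                 url, params={"id": file_id, "confirm": "t"}, stream=True
#             )
#             response.raise_for_status()
#             with open(destination, "wb") as f:
#                 for chunk in response.iter_content(chunk_size=32768):
#                     f.write(chunk)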


def get_replacements_beamsearch(tokenizer, bert_mlm_positive, seq_classify_model,
                                sentence: str, num_candidates=3):
    """Beam search over single-token replacements.

    At each position, the positive-style masked LM proposes candidate tokens;
    each rewritten sentence is scored by the sequence classifier's probability
    of the "positive" class, and the top `num_candidates` beams are kept.
    """
    sentence_ix = tokenizer(sentence, return_tensors='pt')
    tokens = [tokenizer.decode([t]) for t in sentence_ix['input_ids'][0].tolist()]
    length = len(sentence_ix['input_ids'][0])

    # Each beam is (token list, positivity score); the seed sentence starts at 0.
    current = [(tokens, 0.0)]
    for ix in range(1, length - 1):  # skip [CLS] and [SEP]
        new_current = []
        for seq, score in current:
            # Keep the unmodified beam as a candidate too.
            new_current.append((seq, score))

            # Note: the MLM scores position `ix` on the unmasked sentence, and
            # rejoining decoded tokens with spaces assumes no subword pieces.
            sent = " ".join(seq[1:-1])
            sent_ix = tokenizer(sent, return_tensors='pt')
            with torch.no_grad():
                logits_positive = bert_mlm_positive(**sent_ix).logits
            probs_positive = logits_positive.softmax(dim=-1)[0, ix]
            indices = torch.argsort(probs_positive, descending=True)

            for cand_ix in range(num_candidates):
                token_id = indices[cand_ix]
                new_seq = seq.copy()
                new_seq[ix] = tokenizer.decode([token_id])

                # Score the rewritten sentence: P(positive class).
                with torch.no_grad():
                    logits = seq_classify_model(
                        **tokenizer(" ".join(new_seq[1:-1]), return_tensors='pt')
                    ).logits
                prob = logits.softmax(dim=-1)[0, 1].item()

                new_current.append((new_seq, prob))

        current = sorted(new_current, key=lambda x: -x[1])[:num_candidates]

    return [" ".join(seq[1:-1]) for seq, _ in current]
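
# Hedged usage sketch (illustrative input, assumes both checkpoints are
# already on disk):
#
#     tok = get_tokenizer()
#     mlm = load_model_bert_mlm_positive()
#     clf = load_model_model_seq_classify()
#     get_replacements_beamsearch(tok, mlm, clf, "the service was slow",
#                                 num_candidates=3)
#     # -> three rewritten phrases, ranked by the classifier's
#     #    positive-class probability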



negative_phrase = st.text_input("Input negative phrase")
num_candidates = st.slider("Number of candidates", min_value=1, max_value=5)


if negative_phrase:
    bert_mlm_positive = load_model_bert_mlm_positive()
    model_seq_classify = load_model_model_seq_classify()

    ret = get_replacements_beamsearch(
        get_tokenizer(), bert_mlm_positive, model_seq_classify,
        negative_phrase, num_candidates=num_candidates,
    )

    st.caption("Output positive phrases:")
    for phrase in ret:
        st.caption(phrase)
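
# To launch the app (assuming this file is saved as app.py):
#     streamlit run app.py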