Spaces:
Runtime error
Runtime error
import streamlit as st | |
from pathlib import Path | |
import torch | |
from transformers import BertTokenizer | |
def get_tokenizer():
    """Return the ``bert-base-uncased`` tokenizer loaded via ``BertTokenizer.from_pretrained``."""
    return BertTokenizer.from_pretrained('bert-base-uncased')
def _load_checkpoint(filename, drive_file_id, label):
    """Load the pickled model stored in *filename*, downloading it first if it is missing.

    The object is loaded onto the CPU and switched to eval mode before being returned.

    Args:
        filename: local checkpoint path (relative to the working directory).
        drive_file_id: Google Drive file id handed to ``download_file_from_google_drive``.
        label: model name shown in the download spinner.

    Returns:
        The loaded model object.
    """
    import inspect

    f_checkpoint = Path(filename)
    if not f_checkpoint.exists():
        with st.spinner(f"Downloading {label}... this may take awhile! \n Don't stop it!"):
            from GD_download import download_file_from_google_drive
            # Download to a side file and rename only after the helper returns, so an
            # interrupted download cannot leave a truncated checkpoint that `exists()`
            # would accept on every later run.
            # NOTE(review): assumes the helper writes to whatever path it is given — confirm.
            partial = f_checkpoint.with_name(f_checkpoint.name + ".part")
            download_file_from_google_drive(drive_file_id, partial)
            partial.replace(f_checkpoint)

    load_kwargs = {'map_location': torch.device('cpu')}
    # The loaded object is used directly as a model (`.eval()` is called on it), so the
    # file is a pickled module, not a state dict. torch>=2.6 defaults to
    # weights_only=True, which refuses to unpickle such objects, so opt out explicitly
    # on versions that know the flag (older versions do not accept it).
    # SECURITY: full unpickling can execute arbitrary code; only load trusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(f_checkpoint, **load_kwargs)
    model.eval()
    return model


def load_model_bert_mlm_positive():
    """Return the positive masked-LM checkpoint (downloaded on first use), in eval mode."""
    return _load_checkpoint("bert_mlm_positive.model",
                            "12Gvgv6zaOLJ8oyYXVB5_GYNEfvsjudG_",
                            "bert_mlm_positive")


def load_model_model_seq_classify():
    """Return the sequence-classifier checkpoint (downloaded on first use), in eval mode."""
    return _load_checkpoint("model_seq_classify.model",
                            "13DwlCIM6aYc4WeOCIRqdGy-U0LGc8f0B",
                            "model_seq_classify")
def _model_inputs(token_ids):
    """Return forward-pass keyword arguments (a batch of one) for a list of token ids."""
    input_ids = torch.tensor([token_ids], dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': torch.ones_like(input_ids)}


def get_replacements_beamsearch(tokenizer, bert_mlm_positive, seq_classify_model, sentence: str, num_candidates=3):
    """Rewrite *sentence* token by token towards class 1 of the classifier using beam search.

    The sentence is tokenized once. Then each position between the first and last
    token is visited in order. At each position, every hypothesis in the beam is kept
    unchanged and is also expanded with the ``num_candidates`` non-special tokens that
    ``bert_mlm_positive`` ranks highest at that position. Each expansion is scored with
    the softmax probability of class 1 (presumably "positive") from
    ``seq_classify_model``. The best ``num_candidates`` distinct hypotheses survive.

    Hypotheses are kept as token ids rather than decoded strings. Re-tokenizing decoded
    word pieces ("##ing") can change the sequence length and misalign the positions.

    Args:
        tokenizer: callable with ``return_tensors='pt'``; must provide ``decode`` and
            (optionally) ``all_special_ids``.
        bert_mlm_positive: model whose ``.logits`` is indexed as ``[0, position]``
            to rank replacement token ids.
        seq_classify_model: model whose ``.logits`` is indexed as ``[0, 1]`` after a softmax.
        sentence: input text.
        num_candidates: beam width and number of replacement tokens tried per position.

    Returns:
        Up to ``num_candidates`` decoded sentences, highest-scoring first. Returns an
        empty list when ``num_candidates < 1``.
    """
    if num_candidates < 1:
        return []

    encoded = tokenizer(sentence, return_tensors='pt')
    start_ids = [int(t) for t in encoded['input_ids'][0].tolist()]
    # Never propose [CLS]/[SEP]/[PAD]/[UNK]/... as replacements inside the sentence.
    special_ids = set(getattr(tokenizer, 'all_special_ids', ()))

    # (token ids, score). The untouched input starts at 0 so any rewrite outranks it.
    beam = [(start_ids, 0.0)]
    with torch.no_grad():  # inference only: don't build autograd graphs
        # Positions 0 and -1 are the boundary tokens added by the tokenizer.
        for ix in range(1, len(start_ids) - 1):
            # tuple(ids) -> best score seen; dict order breaks ties by first appearance.
            scores = {}
            for ids, score in beam:
                candidates = [(ids, score)]
                # NOTE(review): position ix is not masked, so the MLM sees the current
                # token (same as the original code) — confirm whether masking was intended.
                mlm_logits = bert_mlm_positive(**_model_inputs(ids)).logits[0, ix]
                # Ranking by logits equals ranking by softmax probabilities. Fetch a few
                # extra tokens so that dropping specials still leaves num_candidates.
                k = min(mlm_logits.numel(), num_candidates + len(special_ids))
                top_ids = torch.topk(mlm_logits, k).indices.tolist()
                replacements = [t for t in top_ids if t not in special_ids][:num_candidates]
                for token_id in replacements:
                    new_ids = ids.copy()
                    new_ids[ix] = token_id
                    clf_logits = seq_classify_model(**_model_inputs(new_ids)).logits
                    candidates.append((new_ids, clf_logits.softmax(dim=-1)[0, 1].item()))
                for cand_ids, cand_score in candidates:
                    key = tuple(cand_ids)
                    if cand_score > scores.get(key, float('-inf')):
                        scores[key] = cand_score
            ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:num_candidates]
            beam = [(list(key), score) for key, score in ranked]

    return [tokenizer.decode(ids[1:-1]) for ids, _ in beam]
# Streamlit page: read a negative phrase and list positive rewrites of it.
negative_phrase = st.text_input("Input negative phrase")
num_candidates = st.slider("Number of candidates", min_value=1, max_value=5)

# Only load the (large) models once the user has typed something non-blank.
if negative_phrase.strip():
    bert_mlm_positive = load_model_bert_mlm_positive()
    model_seq_classify = load_model_model_seq_classify()
    ret = get_replacements_beamsearch(get_tokenizer(), bert_mlm_positive,
                                      model_seq_classify, negative_phrase,
                                      num_candidates=num_candidates)
    st.caption("Output positive phrases:")
    for phrase in ret:
        st.caption(phrase)