# ysda_nlp_task12 / app.py
# kruntuid's picture
# ysda next commit
# d22ab1f
import streamlit as st
from pathlib import Path
import torch
from transformers import BertTokenizer
# allow_output_mutation=True: the tokenizer is a shared, read-only resource, so
# Streamlit should hand back the cached object itself instead of hashing it on
# every call to look for mutations.
@st.cache(allow_output_mutation=True)
def get_tokenizer():
    """Return the pretrained ``bert-base-uncased`` ``BertTokenizer``.

    The result is cached by Streamlit, so the tokenizer is only constructed
    once rather than on every script rerun.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    return tokenizer
# allow_output_mutation=True: return the cached model object as-is instead of
# letting Streamlit hash the whole model on every call to detect mutations.
@st.cache(allow_output_mutation=True)
def load_model_bert_mlm_positive():
    """Load the ``bert_mlm_positive`` checkpoint, downloading it if missing.

    The checkpoint is expected at ``bert_mlm_positive.model`` relative to the
    current working directory. When it is not there, it is fetched from
    Google Drive while a spinner is shown.

    Returns:
        The object stored in the checkpoint, mapped to the CPU and switched
        to eval mode.
    """
    f_checkpoint = Path("bert_mlm_positive.model")
    if not f_checkpoint.exists():
        with st.spinner("Downloading bert_mlm_positive... this may take awhile! \n Don't stop it!"):
            from GD_download import download_file_from_google_drive
            cloud_model_location = "12Gvgv6zaOLJ8oyYXVB5_GYNEfvsjudG_"
            # Download under a temporary name and rename only on success, so an
            # interrupted download is never mistaken for a complete checkpoint
            # by the exists() check on a later run.
            f_partial = f_checkpoint.with_name(f_checkpoint.name + ".part")
            download_file_from_google_drive(cloud_model_location, f_partial)
            f_partial.replace(f_checkpoint)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. The checkpoint comes from a remote location, so it must be trusted.
    model = torch.load(f_checkpoint, map_location=torch.device('cpu'))
    model.eval()
    return model
# allow_output_mutation=True: return the cached model object as-is instead of
# letting Streamlit hash the whole model on every call to detect mutations.
@st.cache(allow_output_mutation=True)
def load_model_model_seq_classify():
    """Load the ``model_seq_classify`` checkpoint, downloading it if missing.

    The checkpoint is expected at ``model_seq_classify.model`` relative to the
    current working directory. When it is not there, it is fetched from
    Google Drive while a spinner is shown.

    Returns:
        The object stored in the checkpoint, mapped to the CPU and switched
        to eval mode.
    """
    f_checkpoint = Path("model_seq_classify.model")
    if not f_checkpoint.exists():
        with st.spinner("Downloading model_seq_classify... this may take awhile! \n Don't stop it!"):
            from GD_download import download_file_from_google_drive
            cloud_model_location = "13DwlCIM6aYc4WeOCIRqdGy-U0LGc8f0B"
            # Download under a temporary name and rename only on success, so an
            # interrupted download is never mistaken for a complete checkpoint
            # by the exists() check on a later run.
            f_partial = f_checkpoint.with_name(f_checkpoint.name + ".part")
            download_file_from_google_drive(cloud_model_location, f_partial)
            f_partial.replace(f_checkpoint)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. The checkpoint comes from a remote location, so it must be trusted.
    model = torch.load(f_checkpoint, map_location=torch.device('cpu'))
    model.eval()
    return model
def get_replacements_beamsearch(tokenizer, bert_mlm_positive, seq_classify_model, sentence: str, num_candidates=3):
    """Rewrite ``sentence`` token by token using a beam search.

    For every token position between the first and last token produced by
    ``tokenizer`` the search, for each hypothesis currently in the beam:

    1. takes the ``num_candidates`` most probable tokens at that position
       according to ``bert_mlm_positive``;
    2. substitutes each of them into the hypothesis;
    3. scores the result with the probability ``seq_classify_model`` assigns
       to class index 1.

    The unchanged hypothesis also stays in the pool. The ``num_candidates``
    best *distinct* sequences survive to the next position.

    Args:
        tokenizer: Callable as ``tokenizer(text, return_tensors='pt')``,
            returning a mapping with an ``'input_ids'`` tensor; must also
            provide ``decode(ids)``.
        bert_mlm_positive: Model called with the tokenizer output as keyword
            arguments; its result must expose ``.logits`` indexable as
            ``[0, position]`` over the vocabulary.
        seq_classify_model: Model called the same way; its result must expose
            ``.logits`` whose ``[0][1]`` entry, after softmax, is the score.
        sentence: Text to rewrite.
        num_candidates: Beam width and number of substitutions tried per
            position.

    Returns:
        Up to ``num_candidates`` space-joined token strings, best score first.
    """
    sentence_ix = tokenizer(sentence, return_tensors='pt')
    tokens = [tokenizer.decode([t]) for t in sentence_ix['input_ids'].cpu().numpy()[0]]
    length = len(tokens)

    # Each beam entry is (token strings including the first/last token, score).
    # The untouched input starts with score 0.0, i.e. below any real score.
    beam = [(tokens, 0.0)]

    # Pure inference: no autograd graph is needed for any of the forward passes.
    with torch.no_grad():
        for ix in range(1, length - 1):
            pool = []
            for seq, score in beam:
                pool.append((seq, score))
                # NOTE(review): the tokens are re-joined with spaces and
                # re-tokenized; if a decoded token does not tokenize back to
                # exactly one token, position `ix` may no longer line up with
                # `seq` - confirm against the tokenizer in use.
                sent_ix = tokenizer(" ".join(seq[1:-1]), return_tensors='pt')
                probs_positive = bert_mlm_positive(**sent_ix).logits.softmax(dim=-1)[0, ix]
                best_ids = torch.argsort(probs_positive, descending=True)[:num_candidates]
                for token_id in best_ids:
                    new_seq = seq.copy()
                    new_seq[ix] = tokenizer.decode([token_id])
                    cand_ix = tokenizer(" ".join(new_seq[1:-1]), return_tensors='pt')
                    logits = seq_classify_model(**cand_ix).logits
                    # Store a plain float so every score in the pool has one type.
                    pool.append((new_seq, float(logits.softmax(dim=-1)[0][1])))

            # Best score first; the sort is stable, so ties keep insertion order.
            pool.sort(key=lambda entry: -entry[1])

            # A hypothesis is regenerated whenever its current token is among
            # the top candidates, so the pool contains repeats. Keep only the
            # first (best-scored) copy, otherwise repeats crowd out real
            # alternatives and the same phrase is returned several times.
            beam = []
            seen = set()
            for seq, score in pool:
                if len(beam) >= num_candidates:
                    break
                key = tuple(seq)
                if key in seen:
                    continue
                seen.add(key)
                beam.append((seq, score))

    return [" ".join(seq[1:-1]) for seq, _ in beam]
# --- Streamlit page: read the inputs, run the search, show the results ---
negative_phrase = st.text_input("Input negative phrase")
num_candidates = st.slider("Number of candidates", min_value=1, max_value=5)

# Nothing is loaded or computed until the user has typed a phrase.
if negative_phrase:
    bert_mlm_positive = load_model_bert_mlm_positive()
    model_seq_classify = load_model_model_seq_classify()
    ret = get_replacements_beamsearch(
        get_tokenizer(),
        bert_mlm_positive,
        model_seq_classify,
        negative_phrase,
        num_candidates=num_candidates,
    )

    st.caption("Output positive phrases:")
    # One caption per returned phrase, in the order the search returned them.
    for phrase in ret:
        st.caption(phrase)