import inspect
from pathlib import Path

import streamlit as st
import torch
from transformers import BertTokenizer

# Local checkpoint file names and the Google Drive file ids they are fetched from.
BERT_MLM_POSITIVE_FILE = "bert_mlm_positive.model"
BERT_MLM_POSITIVE_DRIVE_ID = "12Gvgv6zaOLJ8oyYXVB5_GYNEfvsjudG_"
SEQ_CLASSIFY_FILE = "model_seq_classify.model"
SEQ_CLASSIFY_DRIVE_ID = "13DwlCIM6aYc4WeOCIRqdGy-U0LGc8f0B"

# ``st.cache`` is deprecated; ``st.cache_resource`` is its replacement for
# unhashable singletons such as tokenizers and models. On Streamlit versions
# that predate it, fall back to ``st.cache`` with output-mutation checking
# disabled so the (large) returned model is not re-hashed on every rerun.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def get_tokenizer():
    """Return the ``bert-base-uncased`` tokenizer, loaded once and cached."""
    return BertTokenizer.from_pretrained('bert-base-uncased')


def _load_checkpoint(filename, drive_id):
    """Load a pickled model from *filename*, downloading it from Google Drive if absent.

    The model is mapped to CPU and switched to eval mode before being returned.
    """
    f_checkpoint = Path(filename)
    if not f_checkpoint.exists():
        with st.spinner(f"Downloading {f_checkpoint.stem}... this may take awhile! \n Don't stop it!"):
            from GD_download import download_file_from_google_drive
            # Download into a side file and rename only once complete, so an
            # interrupted download never leaves a truncated checkpoint that
            # the exists() check above would accept on the next run.
            partial = f_checkpoint.with_name(f_checkpoint.name + ".part")
            download_file_from_google_drive(drive_id, partial)
            partial.replace(f_checkpoint)

    load_kwargs = {"map_location": torch.device('cpu')}
    # torch >= 2.6 defaults to weights_only=True, which refuses whole pickled
    # nn.Module objects such as these checkpoints; torch < 1.13 lacks the
    # argument entirely, so only pass it when it is supported.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    # SECURITY: full unpickling can execute arbitrary code embedded in the
    # file -- only load checkpoints that come from a trusted source.
    model = torch.load(f_checkpoint, **load_kwargs)
    model.eval()
    return model


@_cache_resource
def load_model_bert_mlm_positive():
    """Return the ``bert_mlm_positive`` masked-LM model (downloaded on first use, cached)."""
    return _load_checkpoint(BERT_MLM_POSITIVE_FILE, BERT_MLM_POSITIVE_DRIVE_ID)


@_cache_resource
def load_model_model_seq_classify():
    """Return the ``model_seq_classify`` sequence classifier (downloaded on first use, cached)."""
    return _load_checkpoint(SEQ_CLASSIFY_FILE, SEQ_CLASSIFY_DRIVE_ID)


def _model_inputs(batch_ids):
    """Build BERT keyword inputs for a batch of equal-length token-id sequences.

    Mirrors what the tokenizer returns for a single unpadded sentence:
    all-zero segment ids and an all-ones attention mask.
    """
    input_ids = torch.tensor([list(ids) for ids in batch_ids], dtype=torch.long)
    return {
        "input_ids": input_ids,
        "token_type_ids": torch.zeros_like(input_ids),
        "attention_mask": torch.ones_like(input_ids),
    }


def _class1_probs(seq_classify_model, batch_ids):
    """Return the softmax probability of class index 1 for each sequence, as floats."""
    logits = seq_classify_model(**_model_inputs(batch_ids)).logits
    return logits.softmax(dim=-1)[:, 1].tolist()


def get_replacements_beamsearch(tokenizer, bert_mlm_positive, seq_classify_model, sentence: str, num_candidates=3):
    """Rewrite *sentence* token by token with beam search, maximizing classifier class 1.

    Positions are visited left to right, skipping the leading and trailing
    special tokens. At each position, every beam entry asks *bert_mlm_positive*
    for its ``num_candidates`` most likely non-special tokens there. Each
    substitution is scored with the class-1 probability of
    *seq_classify_model*. Unchanged entries stay in the pool, duplicate
    sequences keep their best score, and the top ``num_candidates`` survive.

    Args:
        tokenizer: tokenizer matching both models (from ``get_tokenizer``).
        bert_mlm_positive: masked-LM model whose ``.logits`` propose replacements.
        seq_classify_model: classifier whose class-1 probability is maximized.
        sentence: input text; truncated to the tokenizer's maximum length.
        num_candidates: beam width, and the number of replacements tried per
            entry and position. Must be >= 1.

    Returns:
        Up to ``num_candidates`` decoded sentences, best score first.

    Raises:
        ValueError: if ``num_candidates`` < 1.
    """
    if num_candidates < 1:
        raise ValueError(f"num_candidates must be >= 1, got {num_candidates}")

    special_ids = set(tokenizer.all_special_ids)
    # Work on token ids throughout. Decoding single wordpieces and re-tokenizing
    # the joined text changes the sequence length (e.g. "##ing" -> "#", "#",
    # "ing"), which would misalign position ix between steps.
    start_ids = tuple(tokenizer(sentence, truncation=True)["input_ids"])

    with torch.no_grad():  # inference only: don't build autograd graphs
        beam = [(start_ids, _class1_probs(seq_classify_model, [start_ids])[0])]
        for ix in range(1, len(start_ids) - 1):
            # Substitutions never change length, so the whole beam is one batch.
            # NOTE(review): position ix is not masked before querying the MLM,
            # so proposals are conditioned on the current token (as in the
            # original design) -- confirm this is intended.
            logits = bert_mlm_positive(**_model_inputs([ids for ids, _ in beam])).logits[:, ix]
            # Softmax is monotonic, so ranking raw logits gives the same order.
            # Over-fetch by the number of special ids so that filtering them
            # out still leaves num_candidates proposals.
            k = min(num_candidates + len(special_ids), logits.shape[-1])
            ranked_per_entry = logits.topk(k, dim=-1).indices.tolist()

            # Unchanged entries stay eligible with their existing score.
            best = dict(beam)
            new_seqs = []
            for (ids, _), ranked in zip(beam, ranked_per_entry):
                proposals = [t for t in ranked if t not in special_ids][:num_candidates]
                for token_id in proposals:
                    new_ids = list(ids)
                    new_ids[ix] = token_id
                    new_seqs.append(tuple(new_ids))

            if new_seqs:
                for ids, score in zip(new_seqs, _class1_probs(seq_classify_model, new_seqs)):
                    # Deduplicate: identical sequences keep their best score.
                    best[ids] = max(score, best.get(ids, score))

            beam = sorted(best.items(), key=lambda entry: entry[1], reverse=True)[:num_candidates]

    # Drop the leading/trailing special tokens; decode merges wordpieces.
    return [tokenizer.decode(list(ids[1:-1])) for ids, _ in beam]


negative_phrase = st.text_input("Input negative phrase")
num_candidates = st.slider("Number of candidates", min_value=1, max_value=5)
if negative_phrase:
    bert_mlm_positive = load_model_bert_mlm_positive()
    model_seq_classify = load_model_model_seq_classify()
    ret = get_replacements_beamsearch(get_tokenizer(), bert_mlm_positive, model_seq_classify, negative_phrase, num_candidates=num_candidates)
    st.caption("Output positive phrases:")
    for phrase in ret:
        st.caption(phrase)