"""Streamlit demo: fill ``<mask>`` tokens with a formal-style RoBERTa and compare to roberta-base."""
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.activations import get_activation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of candidate fills shown for each mask token.
TOP_K = 100

# Streamlit re-executes this whole script on every widget interaction, so without
# caching the models would be reloaded from disk on each rerun.
# `st.cache_resource` only exists in newer Streamlit releases; on older ones we
# fall back to no caching (the previous behaviour).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_models(device_name):
    """Load the shared tokenizer and both masked-LM models onto ``device_name``.

    Returns:
        (tokenizer, formal_model, base_model)
    """
    # NOTE(review): one roberta-large tokenizer is used for both models, which
    # assumes both checkpoints share its vocabulary — confirm.
    tok = AutoTokenizer.from_pretrained("roberta-large")
    formal = AutoModelForMaskedLM.from_pretrained("BigSalmon/FormalRobertaLincoln").to(device_name)
    # Alternative checkpoint that was tried before: "BigSalmon/MrLincolnBerta".
    base = AutoModelForMaskedLM.from_pretrained("roberta-base").to(device_name)
    return tok, formal, base


# A plain string is passed so the cache key is trivially hashable.
tokenizer, model, model2 = _load_models(str(device))


def _top_k_fills(mlm, text, k=TOP_K):
    """Return the ``k`` highest-scoring fills for every mask token in ``text``.

    Args:
        mlm: masked-LM model whose first output is the per-position vocabulary logits.
        text: user input containing zero or more ``tokenizer.mask_token`` markers.
        k: number of candidates returned per mask.

    Returns:
        One list per mask (in input order), each holding ``k`` decoded,
        whitespace-stripped tokens, best first. Empty if ``text`` has no mask.
    """
    token_ids = tokenizer.encode(text, return_tensors='pt').to(device)
    with torch.no_grad():
        # output[0] holds the logits; [0] selects the single sequence in the batch,
        # giving one row of vocabulary scores per input position.
        logits = mlm(token_ids)[0][0]
    mask_positions = (token_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].tolist()
    fills = []
    for pos in mask_positions:
        top_ids = torch.topk(logits[pos], k=k, dim=0).indices
        fills.append([tokenizer.decode(token_id.item()).strip() for token_id in top_ids])
    return fills


def _show_fills(fills, key_prefix):
    """Render one 'Infill:' text area per mask, listing its candidate fills."""
    for i, words in enumerate(fills):
        # Explicit unique keys prevent DuplicateWidgetID errors when two text
        # areas would otherwise have identical label and value.
        st.text_area(label='Infill:', value=str(words), key=f'{key_prefix}_{i}')


with st.expander('BigSalmon/FormalRobertaa'):
    with st.form(key='my_form'):
        prompt = st.text_area(
            label=f'Enter Text. Put {tokenizer.mask_token} where you want the model to fill in the blank. You can use more than one at a time.'
        )
        submit_button = st.form_submit_button(label='Submit')
        if submit_button:
            formal_fills = _top_k_fills(model, prompt)
            if not formal_fills:
                st.warning(f'No {tokenizer.mask_token} token found in the input.')
            _show_fills(formal_fills, 'formal_infill')

# Streamlit does not allow an expander nested inside another expander, so the
# roberta-base comparison is rendered as a separate top-level section.
if submit_button:
    with st.expander('roberta-base result'):
        _show_fills(_top_k_fills(model2, prompt), 'base_infill')