Spaces:
Runtime error
Runtime error
import torch | |
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import os | |
import torch | |
import torch.nn as nn | |
from transformers.activations import get_activation | |
from transformers import AutoTokenizer | |
# Preferred compute device. NOTE: the models loaded below are not moved onto
# it here, so they stay wherever `from_pretrained` places them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Streamlit re-runs this entire script on every widget interaction, so
# without caching every checkpoint below would be reloaded on each click.
# `st.cache_resource` (newer Streamlit) shares one loaded object across
# reruns and sessions; older releases only provide `st.cache`, where
# `allow_output_mutation=True` stops it from hashing the returned objects.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_tokenizer(name):
    """Load the Hugging Face tokenizer *name*, cached across reruns."""
    return AutoTokenizer.from_pretrained(name)


@_cache_model
def _load_masked_lm(name):
    """Load the masked-LM checkpoint *name*, cached across reruns."""
    return AutoModelForMaskedLM.from_pretrained(name)


tokenizer = _load_tokenizer("roberta-large")
model = _load_masked_lm("BigSalmon/FormalRobertaLincoln")
# Alternative checkpoint (disabled):
# model = _load_masked_lm("BigSalmon/MrLincolnBerta")
model2 = _load_masked_lm("roberta-base")
def _top_mask_fills(mlm, text, k=100):
    """Return the top-*k* candidate fills for every mask token in *text*.

    Args:
        mlm: A masked-language model; element ``[0]`` of ``mlm(token_ids)``
            is used as the per-position vocabulary scores.
        text: User text containing zero or more ``tokenizer.mask_token``.
        k: Number of candidates per mask (clamped to the vocabulary size).

    Returns:
        One list per mask, in order of appearance, each holding up to *k*
        decoded tokens with surrounding whitespace stripped. Empty when
        *text* contains no mask token.
    """
    # truncation=True keeps over-long prompts within the model's maximum
    # length instead of crashing the forward pass; inputs must sit on the
    # same device as the model's weights.
    token_ids = tokenizer.encode(text, return_tensors='pt', truncation=True).to(mlm.device)
    # encode() returns a batch of one sequence; search only that row.
    mask_positions = (token_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].tolist()
    if not mask_positions:
        return []
    with torch.no_grad():
        scores = mlm(token_ids)[0][0]
    k = min(k, scores.size(-1))
    fills = []
    for pos in mask_positions:
        top_ids = torch.topk(scores[pos], k=k, dim=0)[1]
        fills.append([tokenizer.decode(tok.item()).strip() for tok in top_ids])
    return fills


def _show_fills(fills, key_prefix):
    """Render each mask's candidate list in its own text area.

    Explicit keys keep the widgets distinct even if two masks happen to get
    identical candidate lists.
    """
    for n, words in enumerate(fills):
        st.text_area(label='Infill:', value=', '.join(words), key=f'{key_prefix}_{n}')


with st.expander('BigSalmon/FormalRobertaa'):
    with st.form(key='my_form'):
        prompt = st.text_area(label='Enter Text. Put <mask> where you want the model to fill in the blank. You can use more than one at a time.')
        submit_button = st.form_submit_button(label='Submit')
    if submit_button:
        formal_fills = _top_mask_fills(model, prompt)
        if not formal_fills:
            st.warning(f'No {tokenizer.mask_token} found in the text.')
        _show_fills(formal_fills, 'formal')

# Streamlit does not allow an expander inside another expander, so the
# roberta-base comparison gets its own top-level expander.
if submit_button:
    with st.expander('roberta-base result'):
        _show_fills(_top_mask_fills(model2, prompt), 'base')