# Page-header residue from the source listing (author BigSalmon,
# commit f5e8576 "Update app.py"), commented out so the module parses.
import torch
import streamlit as st
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from transformers.activations import get_activation
from transformers import AutoTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Streamlit re-executes this whole script on every interaction, so uncached
# `from_pretrained` calls would reload both checkpoints each time. Cache the
# loaded objects once per server process. `st.cache_resource` exists on
# Streamlit >= 1.18; older releases only provide `st.cache`, which needs
# `allow_output_mutation=True` because model objects are not hashable.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_tokenizer(name):
    """Load (once per process) the Hugging Face tokenizer called *name*."""
    return AutoTokenizer.from_pretrained(name)


@_cache_model
def _load_mlm(name):
    """Load (once per process) masked-LM checkpoint *name* on `device`, in eval mode."""
    return AutoModelForMaskedLM.from_pretrained(name).to(device).eval()


# A single tokenizer is shared by both models below.
# NOTE(review): this assumes BigSalmon/FormalRobertaLincoln uses the same
# vocabulary as roberta-large / roberta-base — confirm.
tokenizer = _load_tokenizer("roberta-large")
model = _load_mlm("BigSalmon/FormalRobertaLincoln")
# Alternative checkpoint tried previously: "BigSalmon/MrLincolnBerta".
model2 = _load_mlm("roberta-base")


def _top_k_infills(mlm, tok, text, k=100):
    """Return the *k* highest-scoring fillers for every mask token in *text*.

    Args:
        mlm: masked-language model whose first output is the per-position
            vocabulary logits.
        tok: tokenizer matching *mlm*'s vocabulary.
        text: prompt containing zero or more ``tok.mask_token`` markers.
        k: number of candidates returned per mask.

    Returns:
        One list per mask, in left-to-right order. Each list holds *k* decoded,
        whitespace-stripped tokens, best first. The result is empty if *text*
        contains no mask.

    Raises:
        ValueError: if the encoded prompt is longer than
            ``tok.model_max_length``, which the model cannot process.
    """
    token_ids = tok.encode(text, return_tensors="pt")  # shape (1, seq_len)
    n_tokens = token_ids.shape[1]
    if n_tokens > tok.model_max_length:
        raise ValueError(
            f"Input is {n_tokens} tokens long; the model accepts at most "
            f"{tok.model_max_length}."
        )
    mask_positions = (
        (token_ids[0] == tok.mask_token_id).nonzero(as_tuple=True)[0].tolist()
    )
    if not mask_positions:
        return []
    with torch.no_grad():
        # output[0] holds the logits, (1, seq_len, vocab_size); drop the batch axis.
        logits = mlm(token_ids.to(mlm.device))[0][0]
    return [
        [tok.decode(int(t)).strip() for t in torch.topk(logits[pos], k=k).indices]
        for pos in mask_positions
    ]


def _show_infills(mlm, text):
    """Render one text area of comma-separated candidate fillers per mask."""
    try:
        infills = _top_k_infills(mlm, tokenizer, text)
    except ValueError as err:
        st.error(str(err))
        return
    if not infills:
        st.info("No `<mask>` token found in the input.")
        return
    for words in infills:
        st.text_area(label='Infill:', value=", ".join(words))


# NOTE(review): the expander title says "FormalRobertaa", but the checkpoint
# loaded above is BigSalmon/FormalRobertaLincoln — confirm the intended label.
with st.expander('BigSalmon/FormalRobertaa'):
    with st.form(key='my_form'):
        prompt = st.text_area(label='Enter Text. Put <mask> where you want the model to fill in the blank. You can use more than one at a time.')
        submit_button = st.form_submit_button(label='Submit')
    # Results are outputs, so they are rendered after (outside) the form.
    if submit_button:
        _show_infills(model, prompt)

# Kept at top level: Streamlit does not allow an expander inside another one.
with st.expander('roberta-base result'):
    if submit_button:
        _show_infills(model2, prompt)