Spaces:
Runtime error
Runtime error
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import os | |
import torch | |
import torch.nn as nn | |
from transformers.activations import get_activation | |
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForCausalLM | |
# --- Page header --------------------------------------------------------------
st.title('DeepWords')
st.text('Still under Construction.')
st.text('Tip: Try writing a sentence and making the model predict final word.')

# Run inference on the GPU when torch reports CUDA is usable; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Streamlit re-executes this whole script on every widget interaction, so the
# loaded model must be cached or it is reloaded (and possibly re-downloaded)
# each time. ``st.cache_resource`` exists from Streamlit 1.18; releases that
# lack it still provide the legacy ``st.cache`` (short-circuit: only touched
# when ``cache_resource`` is missing).
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def get_model(model_name="ml6team/gpt-2-medium-conditional-quote-generator"):
    """Load the causal language model and its tokenizer (cached across reruns).

    Args:
        model_name: Hugging Face Hub id (or local path) passed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to the
        module-level ``device`` and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Inference only: eval() disables dropout; .to(device) keeps the weights on
    # the same device the rest of the app selected (GPU when available).
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
    return model, tokenizer
# Upper bound on how many candidate words a user may request (the original app
# ranked the top 350 tokens, so larger requests were silently capped there).
MAX_CANDIDATES = 350


def _top_next_words(model, tokenizer, token_ids, k):
    """Return the *k* most likely next tokens after *token_ids*, decoded.

    A single forward pass is made, so the result is a ranked list of
    alternative next tokens, not a *k*-token continuation. Each decoded token
    is stripped of surrounding whitespace (GPT-2 style tokenizers attach a
    leading space to word-initial tokens).
    """
    with torch.no_grad():
        # Build the batch on whatever device holds the model's weights so the
        # forward pass never mixes CPU and GPU tensors.
        input_ids = torch.tensor([token_ids], device=model.device)
        logits = model(input_ids, return_dict=False)[0]
    next_token_logits = logits[0, -1]
    k = min(k, next_token_logits.size(-1))
    best_indices = next_token_logits.topk(k).indices.tolist()
    return [tokenizer.decode([idx]).strip() for idx in best_indices]


model, tokenizer = get_model()

with st.form(key='my_form'):
    prompt = st.text_input('Enter sentence:', '')
    c = st.number_input('Enter Number of words: ', min_value=1, max_value=MAX_CANDIDATES)
    submit_button = st.form_submit_button(label='Submit')

if submit_button:
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        # An empty prompt encodes to no tokens; the model output would then
        # have no last position to read (logits[0, -1] would fail).
        st.warning('Please enter a sentence first.')
    else:
        best_words = ' '.join(_top_next_words(model, tokenizer, token_ids, int(c)))
        final_string = f"{prompt.rstrip()} {best_words}"
        st.write(final_string)