# GPT2_4_Bible / app.py — Streamlit demo app for a GPT-2 model fine-tuned on the New Testament.
import streamlit as st
from transformers import AutoModelForCausalLM, GPT2Tokenizer, StoppingCriteria, StoppingCriteriaList
from transformers import TextIteratorStreamer
from threading import Thread
import torch
import random
# Run on GPU when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned distilgpt2 checkpoint (trained on the Christian New Testament).
PROJECT_MODEL = "RickMartel/GPT2_FT_By_NT_RAND_v11"

model = AutoModelForCausalLM.from_pretrained(PROJECT_MODEL).to(device)
model.eval()  # inference only — disable dropout etc.
tokenizer = GPT2Tokenizer.from_pretrained(PROJECT_MODEL)
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation after a stop token has appeared `encounters` times.

    Generation halts only when BOTH conditions hold: some stop token has
    occurred at least `encounters` times in the sequence, AND the most
    recent token equals the first stop token — so output ends exactly on
    that token (a period, given how `stops` is built below).
    """

    def __init__(self, stops=None, encounters=1):
        """
        Args:
            stops: list of 0-dim token-id tensors to watch for.
                   (Default is None rather than a mutable `[]`, which would
                   be shared across instances.)
            encounters: minimum number of occurrences before stopping.
        """
        super().__init__()
        self.stops = [stop.to(device) for stop in (stops or [])]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # **kwargs: newer transformers versions pass extra arguments to
        # stopping criteria; accepting them keeps this forward-compatible.
        last_tkn = input_ids[0][-1]
        stop_word_found = False
        for stop in self.stops:
            # Count occurrences of this stop token in the generated sequence.
            if int(torch.sum(input_ids[0] == stop)) >= self.encounters:
                stop_word_found = True
        # Return a plain Python bool, as the StoppingCriteria API expects
        # (the original returned a 0-dim tensor when the count was met).
        return bool(stop_word_found and self.stops[0] == last_tkn)
# StoppingCriteriaSub treats the FIRST entry of `stops` as the token that
# must end the sequence, so the period goes first (and is the only entry).
_stop_strings = ['.']
stop_words_ids = []
for _stop in _stop_strings:
    token_ids = tokenizer(_stop, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
    stop_words_ids.append(token_ids)

# Stop after the third period, i.e. roughly three sentences.
stopping_criteria = StoppingCriteriaList(
    [StoppingCriteriaSub(stops=stop_words_ids, encounters=3)]
)
st.set_page_config(page_title="GPT2 4 Bible")
st.sidebar.title("GPT2 4 Bible")

# Usage notes rendered in the sidebar (markdown text unchanged).
_SIDEBAR_NOTES = """
Model notes:
- This is a fine-tuned Hugging Face distilgpt2 model.
- The dataset used was the Christian New Testament.
- This Space uses a CPU only. So, the app is slow.
- This is a document completion model. Not a Q&A. Input prompts like, "Jesus said".
"""
st.sidebar.markdown(_SIDEBAR_NOTES)
form = st.form(key='my-form')
txt = form.text_input('Enter a prompt')
submit = form.form_submit_button('Submit')

if submit:
    with st.spinner('Processing...'):
        st.markdown("<h4 style='text-align: left;'>Response:</h4>", unsafe_allow_html=True)
        ta = st.empty()  # placeholder that is rewritten as tokens stream in
        # Renamed from `input` (shadowed the builtin). Also move the tensors
        # to the model's device — the model is moved to `device` at load
        # time, but the original left the inputs on CPU, which would fail
        # on a CUDA host.
        inputs = tokenizer([tokenizer.bos_token + txt], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            stopping_criteria=stopping_criteria,
            do_sample=True,
            max_new_tokens=200,
        )
        # Generate in a background thread so the streamer can yield tokens
        # to the UI as they are produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            # Strip quotation marks and the BOS marker before display.
            generated_text += new_text.replace('"', "").replace(tokenizer.bos_token, "")
            ta.write(generated_text)
        # The streamer is exhausted only when generation has finished, so
        # this join returns immediately; it just avoids a dangling thread.
        thread.join()