import streamlit as st
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    GPT2Tokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

PROJECT_MODEL = "RickMartel/GPT2_FT_By_NT_RAND_v11"

model = AutoModelForCausalLM.from_pretrained(PROJECT_MODEL)
model = model.to(device)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained(PROJECT_MODEL)


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once `encounters` copies of a stop token have been
    produced and the most recent token is that stop token."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = [stop.to(device) for stop in (stops or [])]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_tkn = input_ids[0][-1]
        stop_word_found = False
        for stop in self.stops:
            # Count how many times this stop token appears in the sequence so far.
            if (input_ids[0] == stop).sum() >= self.encounters:
                stop_word_found = True
        # Only halt when the sequence also *ends* with the first stop token,
        # so generation finishes on a sentence boundary.
        return stop_word_found and bool(self.stops[0] == last_tkn)


# Stop after the third "." — note that StoppingCriteriaSub checks the *first*
# stop token id against the last generated token.
stop_words = ['.']
stop_words_ids = [
    tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
    for stop_word in stop_words
]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=3)])

st.set_page_config(page_title="GPT2 4 Bible")
st.sidebar.title("GPT2 4 Bible")
st.sidebar.markdown(
    """
    Model notes:
    - This is a fine-tuned Hugging Face distilgpt2 model.
    - The dataset used was the Christian New Testament.
    - This Space uses a CPU only, so the app is slow.
    - This is a document-completion model, not Q&A. Enter prompts like "Jesus said".
    """
)

form = st.form(key='my-form')
txt = form.text_input('Enter a prompt')
submit = form.form_submit_button('Submit')

if submit:
    with st.spinner('Processing...'):
        st.markdown("\n\nResponse:\n\n", unsafe_allow_html=True)
        ta = st.empty()
        # Prepend BOS and move the encoded prompt to the model's device.
        inputs = tokenizer([tokenizer.bos_token + txt], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            stopping_criteria=stopping_criteria,
            do_sample=True,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        )
        # Run generation on a background thread so tokens can be streamed
        # into the UI as they arrive.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text.replace('"', "").replace(tokenizer.bos_token, "")
            ta.write(generated_text)
        thread.join()