Spaces:
Sleeping
Sleeping
File size: 2,770 Bytes
811c4cd b2bc435 87612e7 0a09f31 3baff9f 87612e7 94fab07 c4e8f07 87612e7 c4e8f07 811c4cd c4e8f07 89db4a1 c4e8f07 83162a1 c4e8f07 83162a1 ba20738 83162a1 c4e8f07 3542db6 c4e8f07 351306b 45bd730 efb6153 44cb37f 45bd730 612e92e 171363a 060557b 612e92e 6122c92 cdfd163 da0f5e9 c6afa28 dfc5f4d c6afa28 15611ec caff9f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import streamlit as st
from transformers import AutoModelForCausalLM, GPT2Tokenizer, StoppingCriteria, StoppingCriteriaList
from transformers import TextIteratorStreamer
from threading import Thread
import torch
import random
# Pick GPU when available; the Space itself runs on CPU (see sidebar note).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fine-tuned distilgpt2 checkpoint hosted on the Hugging Face Hub.
PROJECT_MODEL = "RickMartel/GPT2_FT_By_NT_RAND_v11"
model = AutoModelForCausalLM.from_pretrained(PROJECT_MODEL)
model = model.to( device )
# Inference only — disables dropout and other training-mode behavior.
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained(PROJECT_MODEL)
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop token has occurred `encounters` times
    AND the most recently generated token equals the FIRST stop token.

    NOTE(review): each entry in `stops` is assumed to be a single-token
    (0-dim) tensor — the elementwise compare in `__call__` would mis-count
    a multi-token stop sequence. Confirm against how `stops` is built.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # `stops=None` avoids the mutable-default-argument pitfall of the
        # original `stops=[]`; semantics for callers are unchanged.
        stops = [] if stops is None else stops
        self.stops = [stop.to(device) for stop in stops]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # `**kwargs` is accepted for forward compatibility: recent
        # transformers releases pass extra keyword arguments to criteria.
        last_tkn = input_ids[0][-1]
        stop_word_found = False
        for stop in self.stops:
            # Count occurrences of this stop token in the sequence so far.
            if sum(input_ids[0] == stop) >= self.encounters:
                stop_word_found = True
        # Only halt on a sentence boundary: the last token must equal the
        # first stop id (the period — see the comment where stops are built).
        # bool(...) normalizes the original's mixed bool/0-dim-tensor return.
        return bool(stop_word_found and self.stops[0] == last_tkn)
# Stop after the third sentence: the criteria fires once '.' has been seen
# three times. StoppingCriteriaSub compares the last token against the FIRST
# stop id, so the period must remain the first entry in stop_words.
stop_words = ['.']
stop_words_ids = []
for stop_word in stop_words:
    encoded = tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)
    stop_words_ids.append(encoded['input_ids'].squeeze())
stopping_criteria = StoppingCriteriaList(
    [StoppingCriteriaSub(stops=stop_words_ids, encounters=3)]
)
# Page chrome: browser-tab title plus a short model description in the sidebar.
st.set_page_config(page_title="GPT2 4 Bible")
st.sidebar.title("GPT2 4 Bible")
st.sidebar.markdown(
"""
Model notes:
- This is a fine-tuned Hugging Face distilgpt2 model.
- The dataset used was the Christian New Testament.
- This Space uses a CPU only. So, the app is slow.
- This is a document completion model. Not a Q&A. Input prompts like, "Jesus said".
"""
)
# Prompt form: one text input plus a submit button.
form = st.form(key='my-form')
txt = form.text_input('Enter a prompt')
submit = form.form_submit_button('Submit')
if submit:
    with st.spinner('Processing...'):
        st.markdown("<h4 style='text-align: left;'>Response:</h4>", unsafe_allow_html=True)
        # Placeholder that is rewritten as each streamed token arrives.
        ta = st.empty()
        # Renamed from `input` to avoid shadowing the builtin; moved to the
        # model's device — the original left the tensors on CPU, which fails
        # whenever CUDA is selected at the top of the file.
        model_inputs = tokenizer([tokenizer.bos_token + txt], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer( tokenizer )
        generation_kwargs = dict(
            model_inputs,
            streamer=streamer,
            stopping_criteria=stopping_criteria,
            do_sample=True,
            max_new_tokens=200,
        )
        # Run generate() on a worker thread so this thread can consume the
        # streamer and update the UI incrementally; the loop below ends when
        # the streamer is exhausted (generation finished).
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            # Strip double quotes and the BOS marker before display.
            generated_text += new_text.replace('"', "").replace(tokenizer.bos_token, "")
            ta.write( generated_text )
|