import streamlit as st
from transformers import AutoModelForCausalLM, GPT2Tokenizer, StoppingCriteria, StoppingCriteriaList
from transformers import TextIteratorStreamer
from threading import Thread
import torch

# Load the fine-tuned model once at startup and keep it in eval mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
PROJECT_MODEL = "RickMartel/GPT2_FT_By_NT_RAND_v11"
model = AutoModelForCausalLM.from_pretrained(PROJECT_MODEL)
model = model.to(device)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained(PROJECT_MODEL)

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the stop token has appeared `encounters` times
    and is also the most recently generated token."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = [stop.to(device) for stop in (stops or [])]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_tkn = input_ids[0][-1]
        stop_word_found = False
        for stop in self.stops:
            # Count how many times this stop token id occurs in the sequence so far.
            if (input_ids[0] == stop).sum() >= self.encounters:
                stop_word_found = True
        # Only halt on a sentence boundary: the newest token must be the stop token.
        return stop_word_found and bool(self.stops[0] == last_tkn)

# StoppingCriteriaSub assumes the period is the first entry in stop_words,
# since __call__ only checks self.stops[0] against the last generated token.
stop_words = ['.']
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids,
                                                              encounters=3)])
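
# Sanity check (illustrative; not required at runtime): '.' encodes to a
# single GPT-2 BPE token, so each entry in stop_words_ids is a 0-dim tensor
# and the elementwise comparison in __call__ counts occurrences of that id.
# assert all(ids.dim() == 0 for ids in stop_words_ids)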
st.set_page_config(page_title="GPT2 4 Bible")
st.sidebar.title("GPT2 4 Bible")
st.sidebar.markdown(
"""
Model notes:
- This is a fine-tuned Hugging Face distilgpt2 model.
- The training dataset was the Christian New Testament.
- This Space runs on CPU only, so generation is slow.
- This is a document-completion model, not Q&A. Enter prompts like "Jesus said".
"""
)

form = st.form(key='my-form')
txt = form.text_input('Enter a prompt')
submit = form.form_submit_button('Submit')

if submit:
    with st.spinner('Processing...'):
        st.markdown("<h4 style='text-align: left;'>Response:</h4>", unsafe_allow_html=True)
        ta = st.empty()
        # Prepend the BOS token so the prompt matches the fine-tuning format,
        # and move the encoded inputs to the same device as the model.
        inputs = tokenizer([tokenizer.bos_token + txt], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer)
        generation_kwargs = dict(inputs, streamer=streamer,
                                 stopping_criteria=stopping_criteria,
                                 do_sample=True,
                                 max_new_tokens=200)
        # Run generate() in a background thread so the streamer can yield
        # partial text to the UI as tokens are produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            # Strip quotes and the BOS token before updating the placeholder.
            generated_text += new_text.replace('"', "").replace(tokenizer.bos_token, "")
            ta.write(generated_text)
        thread.join()
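
# To run this app locally (a minimal sketch; the filename "app.py" is an
# assumption, as the source does not name the file):
#   pip install streamlit torch transformers
#   streamlit run app.py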