File size: 3,489 Bytes
41fa981
 
 
 
 
c665729
41fa981
4108df0
 
41fa981
 
4108df0
41fa981
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f87c71
4108df0
2c96f08
41fa981
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f121b56
41fa981
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import streamlit as st
from streamlit_chat import message

# Directory holding the fine-tuned model and tokenizer files.
checkpoint = "."

# Decorator for objects that should be loaded once and shared across reruns.
# st.cache_resource is the current Streamlit API for this. On older Streamlit
# releases that lack it, fall back to st.cache with allow_output_mutation=True
# so Streamlit does not try to hash the (large) returned object on every run.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_resource
def get_tokenizer():
    """Load the tokenizer from ``checkpoint`` (cached across script reruns)."""
    return AutoTokenizer.from_pretrained(checkpoint)


@_cache_resource
def get_model():
    """Load the causal LM from ``checkpoint`` (cached across script reruns).

    Returns:
        The model instance produced by ``AutoModelForCausalLM.from_pretrained``.
    """
    return AutoModelForCausalLM.from_pretrained(checkpoint)


# Module-level name used by the rest of the script.
tokenizer = get_tokenizer()
     
st.title("Chat with myGPT 🦄")
st.write("This is a LLM that was fine-tuned on a dataset of daily conversations.")

# `count` advances by one each time this script runs and wraps back to 0 once it
# has reached 3. On a wrap (or the very first run) the model-side conversation
# state (chat_history_ids, old_response) is cleared; the displayed
# message/response histories below are NOT touched by this reset.
must_reset = 'count' not in st.session_state or st.session_state.count >= 3
if must_reset:
    st.session_state.count = 0
    st.session_state.chat_history_ids = None
    st.session_state.old_response = ''
else:
    st.session_state.count += 1

# One-time defaults for state that persists for the whole browser session.
for state_key, initial_value in (
    ('message_history', []),
    ('response_history', []),
    ('input', ''),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

def submit():
    """Text-box ``on_change`` callback.

    Copies the typed text into ``st.session_state.input`` (picked up later in
    the script run) and empties the text box.
    """
    typed_text = st.session_state.user_input
    st.session_state.input = typed_text
    st.session_state.user_input = ''

model = get_model()

# Decoding settings shared by every model.generate(...) call in this script.
# NOTE(review): per the transformers docs, penalty_alpha only takes effect in
# contrastive search (num_beams=1, do_sample=False); with num_beams=4 and
# do_sample=True it appears to be unused — confirm before relying on it.
# pad_token_id reuses EOS, presumably because the tokenizer defines no pad
# token — confirm.
generation_config = GenerationConfig(
    max_new_tokens=32,
    num_beams=4,
    early_stopping=True,
    no_repeat_ngram_size=2,
    do_sample=True,
    penalty_alpha=0.6,
    top_k=4,
    pad_token_id=tokenizer.eos_token_id,
)


# Replay every stored turn: each user message, followed by the bot reply with
# the same index when one exists.
stored_replies = st.session_state.response_history
for turn, user_text in enumerate(st.session_state.message_history):
    message(user_text, is_user=True, key=f"{turn}_user", avatar_style="identicon", seed='You')
    if turn < len(stored_replies):
        message(stored_replies[turn], key=f"{turn}", avatar_style="bottts", seed='mera GPT')

# Reserved slot (above the text box) where the newest exchange is drawn.
placeholder = st.empty()
st.text_input('You:', key='user_input', on_change=submit)

def _generate_reply(input_ids):
    """Run the model on ``input_ids`` with the shared ``generation_config``.

    Returns:
        A ``(output_ids, reply)`` pair. ``output_ids`` is the full tensor
        returned by ``model.generate``. ``reply`` is the decoded text of only
        the tokens after the prompt, with special tokens removed.
    """
    output_ids = model.generate(input_ids, generation_config)
    # Slice off the prompt portion so only newly generated tokens are decoded.
    new_tokens = output_ids[:, input_ids.shape[-1]:][0]
    return output_ids, tokenizer.decode(new_tokens, skip_special_tokens=True)


if st.session_state.input:
    user_text = st.session_state.input
    # Consume the pending submission so a rerun not triggered by submit()
    # does not append and answer the same message again.
    st.session_state.input = ''
    st.session_state.message_history.append(user_text)

    # The EOS token is prepended as a turn separator before the user's text.
    new_user_input_ids = tokenizer.encode(tokenizer.eos_token + user_text, return_tensors="pt")

    # Only extend the previous context when one actually exists: count can
    # exceed 1 after reruns with no submission, leaving chat_history_ids None.
    history_ids = st.session_state.chat_history_ids
    if st.session_state.count > 1 and history_ids is not None:
        bot_input_ids = torch.cat([history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    st.session_state.chat_history_ids, response = _generate_reply(bot_input_ids)

    # If the reply repeats the previous one, retry from the new message alone.
    if st.session_state.old_response == response:
        st.session_state.chat_history_ids, response = _generate_reply(new_user_input_ids)

    st.session_state.old_response = response
    st.session_state.response_history.append(response)

    # Draw the newest exchange into the placeholder reserved above the text box.
    with placeholder.container():
        message(user_text, is_user=True, key=str(-1)+'_user', avatar_style="identicon", seed='You')
        message(response, key=str(-1), avatar_style="bottts", seed='mera GPT')