chat-with-myGPT / app.py
Asankhaya Sharma
x
4108df0
raw
history blame contribute delete
No virus
3.49 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import streamlit as st
from streamlit_chat import message
# Location of the fine-tuned model and tokenizer files: the current working
# directory the app is launched from.
checkpoint: str = "."

# Streamlit re-executes this whole script on every rerun, so this line reloads
# the tokenizer from `checkpoint` each time the page reruns.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)
# `st.cache` is deprecated (Streamlit >= 1.18). Without allow_output_mutation=True
# it also hashes the returned object on every cache hit to detect mutation,
# which is slow, and can fail, for a large torch model. Prefer `st.cache_resource`,
# which keeps one shared instance without hashing it. Fall back to the old API
# (with mutation hashing disabled) on Streamlit versions that predate it.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def get_model():
    """Load the causal language model from ``checkpoint`` once and reuse it.

    The decorator keeps the loaded model across Streamlit reruns, so the
    weights are read from disk only on the first call.

    Returns:
        The ``AutoModelForCausalLM`` instance loaded from ``checkpoint``.
    """
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return model
# Page header. The title emoji was mis-encoded (UTF-8 bytes of the unicorn
# emoji decoded with the wrong codec); restored to the intended character.
st.title("Chat with myGPT 🦄")
st.write("This is a LLM that was fine-tuned on a dataset of daily conversations.")

# Model-context bookkeeping. `count` advances by one on every script rerun.
# On the first run of a session, or on the run after it reaches 3, it is reset
# to 0. At the same moment the accumulated model context (`chat_history_ids`)
# and the last reply kept for repeat detection (`old_response`) are discarded.
if 'count' not in st.session_state or st.session_state.count >= 3:
    st.session_state.count = 0
    st.session_state.chat_history_ids = None
    st.session_state.old_response = ''
else:
    st.session_state.count += 1

# Per-session state that survives reruns and is NOT cleared by the reset above:
# the displayed transcript of user messages and bot replies, and the pending
# user message handed over by the text-input callback.
if 'message_history' not in st.session_state:
    st.session_state.message_history = []
if 'response_history' not in st.session_state:
    st.session_state.response_history = []
if 'input' not in st.session_state:
    st.session_state.input = ''
def submit():
    """``on_change`` callback for the chat text box.

    Hand the typed text over to ``st.session_state.input`` for the script body
    to process, then empty the text box so it is ready for the next message.
    """
    state = st.session_state
    typed_text = state.user_input
    state.input = typed_text
    state.user_input = ''
model = get_model()

# Decoding settings shared by every generate() call below.
# NOTE: num_beams=4 with do_sample=True makes transformers run beam-search
# multinomial sampling. penalty_alpha is only consulted by contrastive search
# (num_beams=1, do_sample=False), so it has no effect with these settings.
generation_config = GenerationConfig(max_new_tokens=32,
                                     num_beams=4,
                                     early_stopping=True,
                                     no_repeat_ngram_size=2,
                                     do_sample=True,
                                     penalty_alpha=0.6,
                                     top_k=4,
                                     #top_p=0.95,
                                     #temperature=0.8,
                                     pad_token_id=tokenizer.eos_token_id)


def _generate_reply(bot_input_ids):
    """Generate a continuation of ``bot_input_ids`` and decode only the new tokens.

    Args:
        bot_input_ids: token-id tensor as returned by ``tokenizer.encode(...,
            return_tensors="pt")``; only row 0 is decoded.

    Returns:
        ``(output_ids, reply)``. ``output_ids`` is the full sequence returned by
        ``model.generate`` (prompt followed by the continuation). ``reply`` is
        the continuation decoded with special tokens skipped.
    """
    # pad_token_id equals eos_token_id, so generate() cannot tell padding from
    # real tokens and falls back to attending to everything. An explicit
    # all-ones mask gives the same result without the missing-mask warning.
    output_ids = model.generate(bot_input_ids,
                                attention_mask=torch.ones_like(bot_input_ids),
                                generation_config=generation_config)
    reply = tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0],
                             skip_special_tokens=True)
    return output_ids, reply


# Replay the stored transcript. A message submitted on this run is not in the
# history yet; it is drawn into `placeholder` further down.
for i, past_user_text in enumerate(st.session_state.message_history):
    message(past_user_text, is_user=True, key=str(i)+'_user', avatar_style="identicon", seed='You')
    if i < len(st.session_state.response_history):
        message(st.session_state.response_history[i], key=str(i), avatar_style="bottts", seed='mera GPT')

placeholder = st.empty()  # placeholder for latest message
st.text_input('You:', key='user_input', on_change=submit)

if st.session_state.input:
    user_text = st.session_state.input
    # Consume the pending message. Otherwise any rerun not triggered by the text
    # box (e.g. Streamlit's "Rerun" menu item) would append it and answer it again.
    st.session_state.input = ''
    st.session_state.message_history.append(user_text)

    # NOTE(review): the eos token is prepended to the user turn; presumably this
    # matches the fine-tuning data format. Confirm against the training script.
    new_user_input_ids = tokenizer.encode(tokenizer.eos_token + user_text, return_tensors="pt")

    # Reuse the accumulated context only when one actually exists. After a
    # reset, `count` can exceed 1 via reruns that generated nothing, which
    # would otherwise hand None to torch.cat.
    history_ids = st.session_state.chat_history_ids
    if st.session_state.count > 1 and history_ids is not None:
        bot_input_ids = torch.cat([history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    st.session_state.chat_history_ids, response = _generate_reply(bot_input_ids)
    if st.session_state.old_response == response:
        # Same reply as last time: retry from the new user turn alone.
        st.session_state.chat_history_ids, response = _generate_reply(new_user_input_ids)

    st.session_state.old_response = response
    st.session_state.response_history.append(response)
    with placeholder.container():
        message(st.session_state.message_history[-1], is_user=True, key=str(-1)+'_user', avatar_style="identicon", seed='You')  # display the latest message
        message(st.session_state.response_history[-1], key=str(-1), avatar_style="bottts", seed='mera GPT')  # display the latest message