from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import streamlit as st
from streamlit_chat import message

# The fine-tuned checkpoint (model + tokenizer files) lives in the app's
# working directory.
checkpoint = "."
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


# allow_output_mutation=True: st.cache hashes the cached return value on every
# rerun to detect mutation; a torch model is huge and effectively unhashable,
# so opt out of that check. NOTE(review): st.cache is deprecated in newer
# Streamlit — migrate to st.cache_resource once the app's Streamlit version
# is confirmed.
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the fine-tuned causal LM once and reuse it across reruns."""
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return model


st.title("Chat with myGPT 🦄")
st.write("This is a LLM that was fine-tuned on a dataset of daily conversations.")

# Reset the generation context every 3 turns so the prompt fed back into the
# model does not grow without bound (and to break out of degenerate loops).
if 'count' not in st.session_state or st.session_state.count >= 3:
    st.session_state.count = 0
    st.session_state.chat_history_ids = None
    st.session_state.old_response = ''
else:
    st.session_state.count += 1

# Per-session transcript (user turns / bot turns) and the pending input.
if 'message_history' not in st.session_state:
    st.session_state.message_history = []
if 'response_history' not in st.session_state:
    st.session_state.response_history = []
if 'input' not in st.session_state:
    st.session_state.input = ''


def submit():
    """on_change callback: stash the text-box content and clear the widget."""
    st.session_state.input = st.session_state.user_input
    st.session_state.user_input = ''
model = get_model()

# Decoding settings: short sampled continuations with a contrastive-search
# penalty; pad with EOS so open-ended generation does not warn about a
# missing pad token.
generation_config = GenerationConfig(
    max_new_tokens=32,
    num_beams=4,
    early_stopping=True,
    no_repeat_ngram_size=2,
    do_sample=True,
    penalty_alpha=0.6,
    top_k=4,
    # top_p=0.95,
    # temperature=0.8,
    pad_token_id=tokenizer.eos_token_id,
)

# Replay the transcript accumulated so far (response may lag the message by
# one entry while the current reply is being generated).
for i in range(len(st.session_state.message_history)):
    message(st.session_state.message_history[i], is_user=True,
            key=str(i) + '_user', avatar_style="identicon", seed='You')
    if i < len(st.session_state.response_history):
        message(st.session_state.response_history[i], key=str(i),
                avatar_style="bottts", seed='mera GPT')

placeholder = st.empty()  # placeholder for the latest exchange
st.text_input('You:', key='user_input', on_change=submit)

if st.session_state.input:
    st.session_state.message_history.append(st.session_state.input)

    # DialoGPT-style turn encoding: each user turn is TERMINATED with EOS.
    # BUG FIX: EOS was previously *prepended* (tokenizer.eos_token + input),
    # which starts the very first prompt with EOS and leaves the newest turn
    # unterminated when concatenated with history.
    new_user_input_ids = tokenizer.encode(
        st.session_state.input + tokenizer.eos_token, return_tensors="pt")

    # ROBUSTNESS FIX: decide on the actual presence of history instead of the
    # rerun counter — `count > 1` can hold while chat_history_ids is still
    # None (reruns without a submit), which would crash torch.cat.
    if st.session_state.chat_history_ids is not None:
        bot_input_ids = torch.cat(
            [st.session_state.chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Pass the config by keyword — positional arguments to generate() are
    # fragile across transformers versions.
    st.session_state.chat_history_ids = model.generate(
        bot_input_ids, generation_config=generation_config)
    # Decode only the newly generated tokens (everything past the prompt).
    response = tokenizer.decode(
        st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True)

    # If the model repeats its previous reply verbatim, retry once with the
    # history dropped so the conversation can move on.
    if st.session_state.old_response == response:
        bot_input_ids = new_user_input_ids
        st.session_state.chat_history_ids = model.generate(
            bot_input_ids, generation_config=generation_config)
        response = tokenizer.decode(
            st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True)

    st.session_state.old_response = response
    st.session_state.response_history.append(response)

    # Render the newest exchange immediately; the replay loop above only saw
    # earlier turns because it ran before this reply existed.
    with placeholder.container():
        message(st.session_state.message_history[-1], is_user=True,
                key=str(-1) + '_user', avatar_style="identicon", seed='You')
        message(st.session_state.response_history[-1], key=str(-1),
                avatar_style="bottts", seed='mera GPT')