from celebbot import CelebBot
import streamlit as st
import re
import spacy
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from utils import *


@st.cache_resource
def get_seq2seq_model(model_id):
    return AutoModelForSeq2SeqLM.from_pretrained(model_id)


@st.cache_resource
def get_auto_model(model_id):
    return AutoModel.from_pretrained(model_id)


@st.cache_resource
def get_tokenizer(model_id):
    return AutoTokenizer.from_pretrained(model_id)


@st.cache_data
def get_celeb_data(fpath):
    with open(fpath) as json_file:
        return json.load(json_file)


@st.cache_resource
def preprocess_text(name, gender, text, model_id):
    """Rewrite the celebrity's third-person knowledge text into first person
    and split it into sentences with spaCy."""
    lname = name.split(" ")[-1]
    lname_regex = re.compile(rf'\b({lname})\b')
    name_regex = re.compile(rf'\b({name})\b')
    # Possessive forms, e.g. "Swift’s" -> "my"
    lnames = lname + "’s" if not lname.endswith("s") else lname + "’"
    lnames_regex = re.compile(rf'\b({lnames})\b')
    names = name + "’s" if not name.endswith("s") else name + "’"
    names_regex = re.compile(rf'\b({names})\b')
    # Pronoun patterns (he_regex, his_regex, she_regex, her_regex) come from utils
    if gender == "M":
        text = re.sub(he_regex, "I", text)
        text = re.sub(his_regex, "my", text)
    elif gender == "F":
        text = re.sub(she_regex, "I", text)
        text = re.sub(her_regex, "my", text)
    # Replace possessives before plain names so "X’s" is not left half-substituted
    text = re.sub(names_regex, "my", text)
    text = re.sub(lnames_regex, "my", text)
    text = re.sub(name_regex, "I", text)
    text = re.sub(lname_regex, "I", text)
    spacy_model = spacy.load(model_id)
    texts = [i.text.strip() for i in spacy_model(text).sents]
    return spacy_model, texts


def main():
    hide_footer()
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "QA_model_path" not in st.session_state:
        st.session_state["QA_model_path"] = "google/flan-t5-base"
    if "sentTr_model_path" not in st.session_state:
        st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
    if "start_chat" not in st.session_state:
        st.session_state["start_chat"] = False
    model_list = ["base", "large", "xl", "xxl"]

    # Replay the chat history on every rerun
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    celeb_data = get_celeb_data("data.json")

    # Sidebar widgets for choosing the celebrity and the Flan-T5 variant
    celeb_name = st.sidebar.selectbox("Choose a celebrity", options=list(celeb_data.keys()))
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]
    model_choice = st.sidebar.selectbox("Choose Your Flan-T5 model", options=model_list)
    st.session_state["QA_model_path"] = f"google/flan-t5-{model_choice}"
    # submitted = st.form_submit_button(label="Start Chatting")
    # if submitted:
    #     st.session_state["start_chat"] = True
    # if st.session_state["start_chat"]:
    celeb_bot = CelebBot(
        celeb_name,
        get_tokenizer(st.session_state["QA_model_path"]),
        get_seq2seq_model(st.session_state["QA_model_path"]),
        get_tokenizer(st.session_state["sentTr_model_path"]),
        get_auto_model(st.session_state["sentTr_model_path"]),
        *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm"),
    )

    prompt = st.chat_input("Say something")
    if prompt:
        celeb_bot.text = prompt
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state["messages"].append({"role": "user", "content": prompt})

        # Generate the assistant response
        response = celeb_bot.question_answer()
        # disable autoplay to play in HTML
        b64 = celeb_bot.text_to_speech(autoplay=False)
        md = f"""
{response}
"""
        # Display assistant response in chat message container
        st.chat_message("assistant").markdown(
            md,
            unsafe_allow_html=True,
        )
        # Add assistant response to chat history
        st.session_state["messages"].append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()