import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

TOKEN_LIMIT = 2048
TEMPERATURE = 0.3
REPETITION_PENALTY = 1.05
MAX_NEW_TOKENS = 500
MODEL_NAME = "ericzzz/falcon-rw-1b-chat"

# fmt: off
st.write(
    "**💬Tiny Chat with [Falcon-RW-1B-Chat](https://huggingface.co/ericzzz/falcon-rw-1b-chat)**"
)
st.write(
    "*The model operates on free-tier hardware, which may lead to slower performance during periods of high demand.*"
)
# fmt: on

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# inference only, no gradients needed
torch.set_grad_enabled(False)


@st.cache_resource()
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model


# Non-streaming variant, kept for reference:
# def chat_func(tokenizer, model, chat_history):
#     input_ids = tokenizer.apply_chat_template(
#         chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
#     ).to(model.device)
#     output_tokens = model.generate(
#         input_ids,
#         do_sample=True,
#         temperature=TEMPERATURE,
#         repetition_penalty=REPETITION_PENALTY,
#         max_new_tokens=MAX_NEW_TOKENS,
#     )
#     output_text = tokenizer.decode(
#         output_tokens[0][len(input_ids[0]):], skip_special_tokens=True
#     )
#     return output_text


def chat_func_stream(tokenizer, model, chat_history, streamer):
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # check input length; drop the last user message if it exceeds the limit
    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return
    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )


def show_chat_message(container, chat_message):
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])


class ResponseStreamer:
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history
        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()  # placeholder for the streamed message

    def put(self, new_token):
        # the first call carries the input tokens; do not echo them
        if self.first_call_to_put:
            self.first_call_to_put = False
            return
        # decode the current token and accumulate current_response
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded
        # display the streamed message so far
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )

    def end(self):
        # save the completed assistant message
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )
        # clean up state (not strictly needed, as the instance gets recreated)
        self.first_call_to_put = True
        self.current_response = ""
        # rerun to unfreeze the chat_input
        st.rerun()
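# For reference: transformers' generate() drives a streamer by calling put()
# once with the prompt token IDs, then once per newly sampled token, and
# finally end() when generation completes. ResponseStreamer mirrors the
# library's BaseStreamer put()/end() interface, which is why the first call
# to put() is skipped above: it carries the prompt, not generated text.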
tokenizer, model = load_model()

chat_messages_container = st.container()
for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)

input_placeholder = st.empty()  # use a placeholder as a hack to disable input
user_input = input_placeholder.chat_input(key="user_input_original")

if user_input:
    # disable chat_input while generating
    input_placeholder.chat_input(key="user_input_disabled", disabled=True)
    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)
    # Non-streaming call, kept for reference:
    # assistant_message = chat_func(tokenizer, model, st.session_state.chat_history)
    # assistant_message = {"role": "assistant", "content": assistant_message}
    # st.session_state.chat_history.append(assistant_message)
    # show_chat_message(chat_messages_container, assistant_message)
    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)
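# To try this locally (assuming the script is saved as app.py and that
# streamlit, transformers, torch, and accelerate are installed):
#     streamlit run app.py
# Note: device_map="auto" relies on the accelerate package, and bfloat16
# support varies by hardware; float32 is a safer fallback on older machines.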