# tiny-chat / app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
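
# Generation settings: TOKEN_LIMIT caps the prompt length for the free-tier
# hardware; the remaining values control sampling and reply length.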
TOKEN_LIMIT = 2048
TEMPERATURE = 0.3
REPETITION_PENALTY = 1.05
MAX_NEW_TOKENS = 500
MODEL_NAME = "ericzzz/falcon-rw-1b-chat"
# fmt: off
st.write("**💬Tiny Chat with [Falcon-RW-1B-Chat](https://huggingface.co/ericzzz/falcon-rw-1b-chat)**" )
st.write("*The model operates on free-tier hardware, which may lead to slower performance during periods of high demand.*")
# fmt: on
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
torch.set_grad_enabled(False)  # inference only; gradients are never needed
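
# Cache the tokenizer and model across Streamlit reruns so they are loaded only once.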
@st.cache_resource()
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model
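
# Non-streaming variant, kept for reference; the app uses chat_func_stream below.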
# def chat_func(tokenizer, model, chat_history):
#     input_ids = tokenizer.apply_chat_template(
#         chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
#     ).to(model.device)
#     output_tokens = model.generate(
#         input_ids,
#         do_sample=True,
#         temperature=TEMPERATURE,
#         repetition_penalty=REPETITION_PENALTY,
#         max_new_tokens=MAX_NEW_TOKENS,
#     )
#     output_text = tokenizer.decode(
#         output_tokens[0][len(input_ids[0]) :], skip_special_tokens=True
#     )
#     return output_text
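
# Build the prompt from the chat history and generate a reply, pushing tokens
# to the UI through `streamer` as they are produced.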
def chat_func_stream(tokenizer, model, chat_history, streamer):
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # reject over-long prompts and drop the offending user message
    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return
    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )
    return
def show_chat_message(container, chat_message):
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])
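
# Minimal streamer implementing the put()/end() interface that
# transformers' model.generate expects: put() receives new token ids,
# end() is called once generation finishes.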
class ResponseStreamer:
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history
        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()  # placeholder to hold the streamed message

    def put(self, new_token):
        # generate calls put() first with the prompt ids; skip them so only
        # newly generated tokens are displayed
        if self.first_call_to_put:
            self.first_call_to_put = False
            return
        # decode the current token and accumulate it into current_response
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded
        # display the streamed message so far
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )

    def end(self):
        # save the finished assistant message
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )
        # reset state (not strictly needed, as the instance is recreated each turn)
        self.first_call_to_put = True
        self.current_response = ""
        # rerun to re-enable the chat_input
        st.rerun()
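
# Main app flow: load the model, replay the chat history, then handle new input.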
tokenizer, model = load_model()
chat_messages_container = st.container()
for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)
input_placeholder = st.empty() # use placeholder as a hack to disable input
user_input = input_placeholder.chat_input(key="user_input_original")
if user_input:
    # disable chat_input while generating
    input_placeholder.chat_input(key="user_input_disabled", disabled=True)
    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)
    # assistant_message = chat_func(tokenizer, model, st.session_state.chat_history)
    # assistant_message = {"role": "assistant", "content": assistant_message}
    # st.session_state.chat_history.append(assistant_message)
    # show_chat_message(chat_messages_container, assistant_message)
    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)