import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

TOKEN_LIMIT = 2048  # maximum prompt length, in tokens, accepted from the user
TEMPERATURE = 0.3  # sampling temperature for generation
REPETITION_PENALTY = 1.05  # discourage the model from repeating itself
MAX_NEW_TOKENS = 500  # cap on the length of each generated reply
MODEL_NAME = "ericzzz/falcon-rw-1b-chat"

# fmt: off
st.write("**💬Tiny Chat with [Falcon-RW-1B-Chat](https://huggingface.co/ericzzz/falcon-rw-1b-chat)**" ) 
st.write("*The model operates on free-tier hardware, which may lead to slower performance during periods of high demand.*")

# fmt: on
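# Streamlit re-executes this script on every interaction, so the conversation
# is kept in st.session_state to survive reruns.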
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

torch.set_grad_enabled(False)  # inference only; gradients are never needed


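# Cache the tokenizer and model as a shared resource so they are loaded only once
# and reused across reruns and user sessions.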
@st.cache_resource()
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model


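# Original non-streaming version, kept for reference; chat_func_stream below is used instead.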
# def chat_func(tokenizer, model, chat_history):
#     input_ids = tokenizer.apply_chat_template(
#         chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
#     ).to(model.device)
#     output_tokens = model.generate(
#         input_ids,
#         do_sample=True,
#         temperature=TEMPERATURE,
#         repetition_penalty=REPETITION_PENALTY,
#         max_new_tokens=MAX_NEW_TOKENS,
#     )
#     output_text = tokenizer.decode(
#         output_tokens[0][len(input_ids[0]) :], skip_special_tokens=True
#     )
#     return output_text


def chat_func_stream(tokenizer, model, chat_history, streamer):
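    """Generate a streamed reply for the current chat history.

    The prompt is built with the tokenizer's chat template; generated tokens
    are pushed to `streamer`, which renders them incrementally in the UI.
    """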
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # reject prompts that exceed the token budget
    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return
    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )
    return


# Render a single chat message inside the given Streamlit container.
def show_chat_message(container, chat_message):
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])


class ResponseStreamer:
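    """Minimal streamer object accepted by `model.generate(streamer=...)`.

    During generation, transformers calls `put()` first with the prompt tokens
    and then with each newly generated token, and `end()` when generation
    finishes; tokens are decoded and rendered incrementally into a placeholder.
    """
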
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history

        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()  # placeholder that will hold the streamed message

    def put(self, new_token):
        # the first call delivers the prompt tokens; skip it so only new output is shown
        if self.first_call_to_put:
            self.first_call_to_put = False
            return
        # decode current token and accumulate current_response
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded
        # display the streamed message so far
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )

    def end(self):
        # save assistant message
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )
        # reset state (not strictly needed; a fresh instance is created for each response)
        self.first_call_to_put = True
        self.current_response = ""
        # rerun the script so the disabled chat_input is re-enabled
        st.rerun()


tokenizer, model = load_model()
chat_messages_container = st.container()

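# replay the stored conversation on every rerun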
for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)

input_placeholder = st.empty()  # placeholder trick: lets us swap in a disabled chat_input while generating
user_input = input_placeholder.chat_input(key="user_input_original")

if user_input:
    # disable chat_input while generating
    input_placeholder.chat_input(key="user_input_disabled", disabled=True)

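    # record the new user turn and render it immediately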
    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)

    # assistant_message = chat_func(tokenizer, model, st.session_state.chat_history)
    # assistant_message = {"role": "assistant", "content": assistant_message}
    # st.session_state.chat_history.append(assistant_message)
    # show_chat_message(chat_messages_container, assistant_message)

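    # stream the assistant reply token by token into the chat container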
    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)
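
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py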