File size: 9,864 Bytes
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4eae
 
 
 
 
 
7a664fd
 
 
 
 
 
 
 
25f2efc
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4eae
7a664fd
 
 
 
 
c7e10bf
7a664fd
c7e10bf
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
c7e10bf
 
 
 
 
 
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4eae
 
 
 
7a664fd
dcd4eae
7a664fd
 
 
dcd4eae
7a664fd
dcd4eae
7a664fd
 
 
 
dcd4eae
 
 
 
 
 
 
 
 
 
 
 
 
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import streamlit as st
from transformers import AutoTokenizer
import json
import tempfile
import os
import uuid
import copy

# Use the full browser width so the side-by-side st.columns layouts below have room.
st.set_page_config(layout="wide")

def sanitize_jinja2(jinja_lines):
    """Collapse a multi-line Jinja template into a single line.

    Each line has its leading space characters (``" "`` only, not tabs) and
    its trailing ``"\\n"`` characters removed, and the pieces are concatenated
    with no separator.

    NOTE: leading spaces are dropped even when they are literal template text,
    and a line break between two tokens disappears entirely.

    Args:
        jinja_lines: Iterable of template lines, e.g. from ``file.readlines()``.

    Returns:
        The concatenated one-line template (``""`` for no lines).
    """
    # str.join is linear; repeated ``+=`` on a str can be quadratic.
    return "".join(line.lstrip(" ").rstrip("\n") for line in jinja_lines)

@st.cache_resource
def get_existing_templates():
    """Return the selectable chat templates: ``None`` followed by template file names.

    ``None`` is the "no predefined template" choice. The names are the regular
    files in ``./templates``, sorted so the order is stable (``os.listdir``
    returns entries in arbitrary order). Sub-directories are skipped because
    callers ``open()`` the selected entry. A missing ``./templates`` directory
    yields just ``[None]`` instead of crashing the app.

    The result is cached with ``st.cache_resource``, so templates added later
    are not seen until the cache is cleared or the app restarts.
    """
    templates_dir = "./templates"
    try:
        names = os.listdir(templates_dir)
    except FileNotFoundError:
        return [None]
    return [None] + sorted(
        name for name in names if os.path.isfile(os.path.join(templates_dir, name))
    )



# Initialization: per-session defaults. Each key is set only when missing, so
# Streamlit reruns within the same session keep their current values.
_SESSION_DEFAULTS = {
    'tokenizer_json': None,
    'tokenizer': None,
    'repo_normalized_name': None,
    'repo_id': None,
    'input_jinja_template': "",
    'successful_template': '',
    'generated_prompt_w_add_generation_prompt': '',
    'generated_prompt_wo_add_generation_prompt': '',
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Shared scratch root. Created up front and with exist_ok so the setup is
# idempotent (no check-then-create window between concurrent sessions).
os.makedirs("./tmp", exist_ok=True)

if 'uuid' not in st.session_state:
    # Random id for this session; used to name the session's scratch paths.
    st.session_state['uuid'] = uuid.uuid4()
    os.makedirs(f"./tmp/{st.session_state['uuid']}", exist_ok=True)

title_description = """
Chat Template Generation: Make Chat Easier with Huggingface Tokenizer
"""

st.title(title_description)
st.markdown('This streamlit app is to serve as an easier way to check and push the chat template to your/existing huggingface repo')

# List the predefined templates; entry 0 is the `None` ("no template") choice.
list_of_templates = get_existing_templates()
with st.expander("Current predefined templates"):
    for template_name in list_of_templates[1:]:
        st.markdown(f"- {template_name}")
    st.info('More templates will be predefined for easier setup of chat template.', icon="ℹ️")

st.divider()
hf_model_repo_name = st.text_input("Hugging Face Model Repository To Update", value="tiiuae/falcon-7b", max_chars=None, key=None, type="default", 
                        help=None, autocomplete=None, label_visibility="visible")

gen_button = st.button("Get Tokenizer Config")

if gen_button:
    fetched_repo_id = hf_model_repo_name.strip()
    if not fetched_repo_id:
        st.error("Please enter a Hugging Face model repository id.")
    else:
        with st.spinner(text="In progress..."):
            try:
                # NOTE(review): from_pretrained also accepts local directory
                # paths; on a shared deployment consider validating the id format.
                fetched_tokenizer = AutoTokenizer.from_pretrained(fetched_repo_id)
            except Exception as e:  # UI boundary: show the failure instead of a traceback
                st.error(f"Could not load a tokenizer from `{fetched_repo_id}`: {e}")
            else:
                # "/" is replaced so the save directory is a single flat folder
                # under ./tmp rather than a nested path built from user input.
                normalized_name = fetched_repo_id.replace("/", "_")
                save_dir = f"./tmp/{st.session_state['uuid']}_{normalized_name}"
                fetched_tokenizer.save_pretrained(save_dir)

                # Session state is only updated after a successful load and save.
                st.session_state['repo_id'] = fetched_repo_id
                st.session_state['tokenizer'] = fetched_tokenizer
                st.session_state['repo_normalized_name'] = normalized_name
                st.session_state['tokenizer_json'] = save_dir

                # Results generated for a previously fetched repository no longer apply.
                st.session_state['successful_template'] = ''
                st.session_state['generated_prompt_wo_add_generation_prompt'] = ''
                st.session_state['generated_prompt_w_add_generation_prompt'] = ''
    
def _set_chat_template(tokenizer, template_text):
    """Install *template_text* on *tokenizer* as its one-line chat template.

    Line endings are normalized as if the text had been read back from a
    text-mode file (``\\r\\n`` and lone ``\\r`` become ``\\n``). The lines are
    then collapsed with ``sanitize_jinja2``. Doing this in memory avoids a
    scratch-file round trip and that file's locale-dependent default encoding.
    """
    normalized = template_text.replace("\r\n", "\n").replace("\r", "\n")
    tokenizer.chat_template = sanitize_jinja2(normalized.split("\n"))


if st.session_state['tokenizer_json'] is not None:
    # Directory the fetched tokenizer was saved to. It is reused below instead
    # of rebuilding the path from the repository text box, so the download stays
    # tied to the repository that was actually fetched even if the box was edited.
    tokenizer_dir = st.session_state['tokenizer_json']
    active_tokenizer = st.session_state['tokenizer']

    with open(f"{tokenizer_dir}/tokenizer_config.json", "rb") as f:
        tokenizer_json = json.load(f)

    json_spec, col2 = st.columns(spec=[0.3, 0.7])

    with json_spec:
        st.markdown(f"### Tokenizer Config from {st.session_state['repo_normalized_name']}")
        st.json(json.dumps(tokenizer_json, indent=4))

    with col2:
        # Fixed sample conversation used to preview the template.
        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
            {"role": "user", "content": "I'd like to show off how chat templating works!"},
        ]
        st.markdown("### Example Conversation")
        st.json(json.dumps(chat, indent=4), expanded=False)

        prompt_template_col, prompt_template_output_col = st.columns(spec=[0.3, 0.7])

        with prompt_template_col:
            list_of_templates = get_existing_templates()
            selected_template = st.selectbox("Choose Existing Template or Leave Blank. (If template is None, it will check current tokenizer's `chat_template` and `default_chat_template` fields)", 
                options=list_of_templates, 
                index=0, placeholder="Choose a template (If template is None, it will check current tokenizer `chat_template` and `default_chat_template` fields)", disabled=False, label_visibility="visible")
            generate_prompt_example_button = st.button("Generate Prompt", key="generate_prompt_example_button")

            if selected_template is not None:
                with open(f"./templates/{selected_template}", "r", encoding="utf-8") as f:
                    st.session_state['input_jinja_template'] = f.read()
            else:
                # `default_chat_template` may not exist in newer transformers
                # releases, so it is looked up defensively. The final "" ensures
                # the text area below always receives a string.
                st.session_state['input_jinja_template'] = (
                    active_tokenizer.chat_template
                    or getattr(active_tokenizer, "default_chat_template", None)
                    or ""
                )

            st.session_state['input_jinja_template'] = st.text_area(
                "Jinja Chat Template", value=st.session_state['input_jinja_template'], 
                height=500, placeholder=None, disabled=False, label_visibility="visible")

        with prompt_template_output_col:

            if generate_prompt_example_button:
                candidate_template = st.session_state['input_jinja_template'] or ""
                previous_template = active_tokenizer.chat_template
                try:
                    _set_chat_template(active_tokenizer, candidate_template)
                    prompt_wo = active_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                    prompt_w = active_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
                except Exception as e:  # UI boundary: report any template/render error
                    # Restore the last working template so the download below
                    # does not carry a template that failed to render.
                    active_tokenizer.chat_template = previous_template
                    st.error(f"Could not render the chat template: {e}")
                else:
                    # Stored under the keys of the read-only text areas below;
                    # those widgets display these session-state values.
                    st.session_state['generated_prompt_wo_add_generation_prompt'] = prompt_wo
                    st.session_state['generated_prompt_w_add_generation_prompt'] = prompt_w
                    st.session_state['successful_template'] = candidate_template

            if len(st.session_state['successful_template']) > 0:
                # Keyed widgets show st.session_state[key]. No `value=` is passed,
                # so the value is not set both ways (which Streamlit warns about).
                st.text_area(
                    "Generate Prompt with `add_generation_prompt=False`", 
                    height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_wo_add_generation_prompt")

                st.text_area(
                    "Generate Prompt with `add_generation_prompt=True`", 
                    height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_w_add_generation_prompt")

                access_token_no_cache = st.text_input("HuggingFace Access Token API with Write Access", type="password", key="access_token_no_cache")
                commit_message_text_input = st.text_input("Commit Message", key="commit_message_text_input")
                to_private_checkbox = st.checkbox("To Private Repo", key="to_private_checkbox")
                create_pr_checkbox = st.checkbox("Create PR (For Contribution 🤗)", key="create_pr_checkbox")
                push_to_hub_button = st.button("Push to Hub", key="push_to_hub_button", use_container_width=True)

                # Re-save so the downloadable config carries the tokenizer's current chat template.
                active_tokenizer.save_pretrained(tokenizer_dir)
                with open(f"{tokenizer_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
                    tokenizer_config_content = json.load(f)

                st.download_button(
                    label="Download tokenizer_config.json",
                    data=json.dumps(tokenizer_config_content, indent=4),
                    file_name='tokenizer_config.json',
                    mime='application/json',
                    use_container_width=True
                )
                if push_to_hub_button:
                    if not access_token_no_cache:
                        st.error("Please provide a Hugging Face access token with write access.")
                    else:
                        # Push the last template that rendered successfully, not
                        # any unverified edits still in the template text area.
                        _set_chat_template(active_tokenizer, st.session_state['successful_template'])
                        try:
                            with st.spinner(text="Pushing to hub ..."):
                                active_tokenizer.push_to_hub(
                                    repo_id=st.session_state['repo_id'],
                                    # An empty commit message is rejected by the Hub;
                                    # None lets transformers use its default message.
                                    commit_message=commit_message_text_input or None,
                                    private=to_private_checkbox,
                                    token=access_token_no_cache,
                                    create_pr=create_pr_checkbox)
                        except Exception as e:  # UI boundary: surface Hub/auth errors
                            st.error(f"Repo id: {st.session_state['repo_id']}\n\n{e}")
                        else:
                            st.success(f"Push to {st.session_state['repo_id']} succeeded.")