File size: 8,285 Bytes
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25f2efc
7a664fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import streamlit as st
from transformers import AutoTokenizer
import json
import tempfile
import os
import uuid
import copy

# Full browser width: the tokenizer config and the template editor are laid out
# side by side in columns further down the page.
PAGE_LAYOUT = "wide"
st.set_page_config(layout=PAGE_LAYOUT)

def sanitize_jinja2(jinja_lines):
    """Collapse a multi-line Jinja template into a single-line string.

    Each line has its leading ASCII spaces and trailing ``"\\n"`` characters
    removed, then all lines are concatenated in order. Only spaces are
    stripped on the left (tabs are kept) and only ``"\\n"`` on the right
    (a ``"\\r"`` would be kept).

    Args:
        jinja_lines: Iterable of template lines, e.g. the result of
            ``file.readlines()``.

    Returns:
        The concatenated template as one string ("" for no lines).
    """
    # str.join is linear; repeated += on a str is quadratic in the worst case.
    return "".join(line.lstrip(" ").rstrip("\n") for line in jinja_lines)

@st.cache_resource
def get_existing_templates():
    """Return the choices for the "predefined template" selector.

    The first entry is ``None`` (no predefined template). It is followed by the
    names of the regular files in ``./templates``, sorted so the UI order is
    stable (``os.listdir`` returns entries in arbitrary order). Subdirectories
    are skipped because each entry is later opened as a file.

    Cached with ``st.cache_resource``: the same list object is returned on every
    call until the cache is cleared, so callers must not mutate it, and
    templates added while the app runs will not show up until then.
    """
    template_dir = "./templates"
    template_files = sorted(
        name
        for name in os.listdir(template_dir)
        if os.path.isfile(os.path.join(template_dir, name))
    )
    return [None] + template_files



# Initialization: Streamlit re-runs this whole script on every interaction, so each
# session_state key is seeded only when it is missing and is never overwritten here.
_SESSION_DEFAULTS = {
    'tokenizer_json': None,        # directory the fetched tokenizer was saved to
    'tokenizer': None,             # tokenizer loaded with AutoTokenizer
    'repo_normalized_name': None,  # repo id with "/" replaced by "_"
    'repo_id': None,               # Hub repo the tokenizer came from / is pushed to
    'input_jinja_template': "",    # current contents of the template editor
    'successful_template': '',     # last template that rendered without an error
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

if 'uuid' not in st.session_state:
    st.session_state['uuid'] = uuid.uuid4()

# Per-session scratch directory (os.makedirs also creates ./tmp itself). Done on
# every run with exist_ok=True so it is recreated if it disappeared between reruns,
# instead of only once when the uuid is first assigned.
os.makedirs(f"./tmp/{st.session_state['uuid']}", exist_ok=True)

title_description = """
Chat Template Generation: Make Chat Easier with Huggingface Tokenizer
"""

st.title(title_description)
st.markdown('This streamlit app is to serve as an easier way to check and push the chat template to your/existing huggingface repo')

# List the predefined templates. Entry 0 is the ``None`` placeholder used by the
# template selectbox, so it is skipped here.
list_of_templates = get_existing_templates()
with st.expander("Current predefined templates"):
    for template_name in list_of_templates[1:]:
        st.markdown(f"- {template_name}")
    st.info('More templates will be predefined for easier setup of chat template.', icon="ℹ️")

st.divider()

hf_model_repo_name = st.text_input("Hugging Face Model Repository To Update", value="tiiuae/falcon-7b")

gen_button = st.button("Get Tokenizer Config")

if gen_button:
    # Whitespace is easily picked up when pasting a repo id.
    requested_repo_id = hf_model_repo_name.strip()
    if not requested_repo_id:
        st.error("Please enter a Hugging Face model repository id.")
    else:
        # Use the "/"-free name for the on-disk path so the repo id does not create
        # nested directories (or otherwise steer the path) under ./tmp.
        normalized_name = requested_repo_id.replace("/", "_")
        save_dir = f"./tmp/{st.session_state['uuid']}_{normalized_name}"
        try:
            with st.spinner(text="In progress..."):
                loaded_tokenizer = AutoTokenizer.from_pretrained(requested_repo_id)
                loaded_tokenizer.save_pretrained(save_dir)
        except Exception as e:
            # Keep the previously loaded tokenizer and its repo id untouched, so a later
            # push never pairs one repo's tokenizer with another repo's id.
            st.error(f"Could not load the tokenizer for '{requested_repo_id}': {e}")
        else:
            # Commit all related state together, only after load and save succeeded.
            st.session_state['repo_id'] = requested_repo_id
            st.session_state['tokenizer'] = loaded_tokenizer
            st.session_state['repo_normalized_name'] = normalized_name
            st.session_state['tokenizer_json'] = save_dir
    
def _flatten_template(template_text):
    """Turn template editor text into the single-line string set as chat_template.

    CRLF and lone CR line endings are normalized to LF, the same way a
    universal-newline file read would. The lines are then collapsed with
    ``sanitize_jinja2``. This is done in memory, so it does not depend on the
    platform's default file encoding.
    """
    normalized = template_text.replace("\r\n", "\n").replace("\r", "\n")
    return sanitize_jinja2(normalized.split("\n"))


if st.session_state['tokenizer_json'] is not None:
    # 'tokenizer_json' holds the directory the tokenizer was saved to; show its config.
    with open(f"{st.session_state['tokenizer_json']}/tokenizer_config.json", "rb") as f:
        tokenizer_json = json.load(f)

    json_spec, col2 = st.columns(spec=[0.3, 0.7])

    with json_spec:
        st.markdown(f"### Tokenizer Config from {st.session_state['repo_normalized_name']}")
        st.json(tokenizer_json)

    with col2:
        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
            {"role": "user", "content": "I'd like to show off how chat templating works!"},
        ]
        st.markdown("### Example Conversation")
        st.json(chat)

        prompt_template_col, prompt_template_output_col = st.columns(spec=[0.3, 0.7])

        with prompt_template_col:
            list_of_templates = get_existing_templates()
            selected_template = st.selectbox("Choose Existing Template or Leave Blank.",
                options=list_of_templates,
                index=0, placeholder="Choose an option", disabled=False, label_visibility="visible")
            generate_prompt_example_button = st.button("Generate Prompt", key="generate_prompt_example_button")

            # While a predefined template is selected, its file contents are used as the
            # editor's value.
            if selected_template is not None:
                with open(f"./templates/{selected_template}", "r", encoding="utf-8") as f:
                    st.session_state['input_jinja_template'] = f.read()

            st.session_state['input_jinja_template'] = st.text_area(
                "Jinja Chat Template", value=st.session_state['input_jinja_template'],
                height=500, placeholder=None, disabled=False, label_visibility="visible")

        with prompt_template_output_col:

            if generate_prompt_example_button:
                tokenizer = st.session_state['tokenizer']
                tokenizer.chat_template = _flatten_template(st.session_state['input_jinja_template'])
                try:
                    generated_prompt_wo_add_generation_prompt = tokenizer.apply_chat_template(
                        chat, tokenize=False, add_generation_prompt=False)
                    generated_prompt_w_add_generation_prompt = tokenizer.apply_chat_template(
                        chat, tokenize=False, add_generation_prompt=True)
                except Exception as e:
                    # A broken template is the expected failure while editing: report it
                    # instead of aborting the run with a traceback.
                    st.error(f"The chat template could not be rendered: {e}")
                else:
                    st.text_area(
                        "Generate Prompt with `add_generation_prompt=False`", value=generated_prompt_wo_add_generation_prompt,
                        height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_wo_add_generation_prompt")

                    st.text_area(
                        "Generate Prompt with `add_generation_prompt=True`", value=generated_prompt_w_add_generation_prompt,
                        height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_w_add_generation_prompt")

                    # Only a template that rendered successfully becomes eligible for pushing.
                    st.session_state['successful_template'] = st.session_state['input_jinja_template']

            if st.session_state['successful_template']:
                access_token_no_cache = st.text_input("HuggingFace Access Token API with Write Access", type="password", key="access_token_no_cache")
                commit_message_text_input = st.text_input("Commit Message", key="commit_message_text_input")
                to_private_checkbox = st.checkbox("To Private Repo", key="to_private_checkbox")
                # Placed above the button so it is visible and set before pushing.
                create_pr_checkbox = st.checkbox("Create PR (For Contribution 🤗)", key="create_pr_checkbox")
                push_to_hub_button = st.button("Push to Hub", key="push_to_hub_button")
                if push_to_hub_button:
                    access_token = access_token_no_cache.strip()
                    target_repo_id = st.session_state['repo_id']
                    if not access_token:
                        # Require the user's own write token; never attempt a push without it.
                        st.error("Please provide a Hugging Face access token with write access.")
                    else:
                        tokenizer = st.session_state['tokenizer']
                        tokenizer.chat_template = _flatten_template(st.session_state['successful_template'])
                        try:
                            with st.spinner(text="Pushing to hub ..."):
                                tokenizer.push_to_hub(
                                    repo_id=target_repo_id,
                                    # An empty field becomes None so push_to_hub applies its default message.
                                    commit_message=commit_message_text_input or None,
                                    private=to_private_checkbox,
                                    token=access_token,
                                    create_pr=create_pr_checkbox)
                        except Exception as e:
                            st.error(f"Pushing to '{target_repo_id}' failed: {e}")
                        else:
                            pushed_as = " as a pull request" if create_pr_checkbox else ""
                            st.success(f"Pushed the chat template to '{target_repo_id}'{pushed_as}.")