# Author: CISCai
# Last commit: 0d5ccb3 (verified) "reindent json to make code editor happier"
# File size: 41.8 kB
import gradio as gr
import json
from difflib import Differ, unified_diff
from itertools import groupby
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi, CommitOperationAdd
from transformers import PreTrainedTokenizerBase
from enum import StrEnum
from copy import deepcopy
# Shared Hugging Face Hub client. Every call below passes the user's OAuth
# token explicitly (or False when not signed in).
hfapi = HfApi()
class ModelFiles(StrEnum):
TOKENIZER_CHAT_TEMPLATE = "tokenizer_chat_template.jinja"
TOKENIZER_CONFIG = "tokenizer_config.json"
TOKENIZER_INVERSE_TEMPLATE = "inverse_template.jinja"
# Labels shown for the built-in examples; example_labels[i] names example_values[i].
example_labels = [
    "Single user message",
    "Single user message with system prompt",
    "Longer conversation",
    "Tool call",
    "Tool call with response",
    "Tool call with multiple responses",
    "Tool call with complex tool definition",
    "RAG call",
]
# Built-in examples. Each entry is a pair of JSON strings:
# [template settings (kwargs such as "tools" or "documents"), chat messages].
# update_examples() merges the loaded model's settings into the first string.
example_values = [
    # Single user message
    [
        "{}",
        """[
{
"role": "user",
"content": "What is the capital of Norway?"
}
]""",
    ],
    # Single user message with system prompt
    [
        "{}",
        """[
{
"role": "system",
"content": "You are a somewhat helpful AI."
},
{
"role": "user",
"content": "What is the capital of Norway?"
}
]""",
    ],
    # Longer conversation
    [
        "{}",
        """[
{
"role": "user",
"content": "What is the capital of Norway?"
},
{
"role": "assistant",
"content": "Oslo is the capital of Norway."
},
{
"role": "user",
"content": "What is the world famous sculpture park there called?"
},
{
"role": "assistant",
"content": "The world famous sculpture park in Oslo is called Vigelandsparken."
},
{
"role": "user",
"content": "What is the most famous sculpture in the park?"
}
]""",
    ],
    # Tool call: a single tool definition passed via the "tools" setting
    [
        """{
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [ "celsius", "fahrenheit" ]
}
},
"required": [ "location" ]
}
}
}
]
}""",
        """[
{
"role": "user",
"content": "What's the weather like in Oslo?"
}
]""",
    ],
    # Tool call with response: assistant tool_calls followed by a "tool" message
    [
        """{
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [ "celsius", "fahrenheit" ]
}
},
"required": [ "location" ]
}
}
}
]
}""",
        """[
{
"role": "user",
"content": "What's the weather like in Oslo?"
},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "toolcall1",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Oslo, Norway",
"unit": "celsius"
}
}
}
]
},
{
"role": "tool",
"content": "20",
"tool_call_id": "toolcall1"
}
]""",
    ],
    # Tool call with multiple responses: two parallel calls, matched by tool_call_id
    [
        """{
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [ "celsius", "fahrenheit" ]
}
},
"required": [ "location" ]
}
}
}
]
}""",
        """[
{
"role": "user",
"content": "What's the weather like in Oslo and Stockholm?"
},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "toolcall1",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Oslo, Norway",
"unit": "celsius"
}
}
},
{
"id": "toolcall2",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Stockholm, Sweden",
"unit": "celsius"
}
}
}
]
},
{
"role": "tool",
"content": "20",
"tool_call_id": "toolcall1"
},
{
"role": "tool",
"content": "22",
"tool_call_id": "toolcall2"
}
]""",
    ],
    # Tool call with complex tool definition: nested object schema
    # NOTE(review): the "birthday" pattern expects day-month-year and contains
    # literal spaces, while its description's example is 2022-01-01 — confirm
    # this mismatch is intentional test data.
    [
        """{
"tools": [
{
"type": "function",
"function": {
"name": "create_user",
"description": "creates a user",
"parameters": {
"type": "object",
"properties": {
"user": {
"title": "User",
"type": "object",
"properties": {
"user_id": {
"title": "User Id",
"description": "The unique identifier for a user",
"default": 0,
"type": "integer"
},
"name": {
"title": "Name",
"description": "The name of the user",
"type": "string"
},
"birthday": {
"type": "string",
"description": "The birthday of the user, e.g. 2022-01-01",
"pattern": "^([1-9] |1[0-9]| 2[0-9]|3[0-1])(.|-)([1-9] |1[0-2])(.|-|)20[0-9][0-9]$"
},
"email": {
"title": "Email",
"description": "The email address of the user",
"type": "string"
},
"friends": {
"title": "Friends",
"description": "List of friends of the user",
"type": "array",
"items": {"type": "string"}
}
},
"required": ["name", "email"]
}
},
"required": ["user"],
"definitions": {}
}
}
}
]
}""",
        """[
{
"role": "user",
"content": "Create a user for Test User (test@user.org), born January 1st 2000, with some random friends."
}
]""",
    ],
    # RAG call: documents passed via the "documents" setting
    [
        """{
"documents": [
{
"title": "Much ado about nothing",
"content": "Dolor sit amet..."
},
{
"title": "Less ado about something",
"content": "Lorem ipsum..."
}
]
}""",
        """[
{
"role": "user",
"content": "Write a brief summary of the following documents."
}
]""",
    ],
]
# Markdown pre-filled into the pull-request description editor.
pr_description_default = (
    "### Changes\n"
    "* \n"
    "\n"
    "**Updated using [Chat Template Editor](https://huggingface.co/spaces/CISCai/chat-template-editor)**"
)
class TokenizerConfig():
def __init__(self, tokenizer_config: dict):
self._data = deepcopy(tokenizer_config)
self.chat_template = self._data.get("chat_template")
@property
def chat_template(self) -> str | list | None:
templates = [
{
"name": k,
"template": v,
}
for k, v in self.chat_templates.items() if v
]
if not templates:
return None
elif len(templates) == 1 and templates[0]["name"] == "default":
return templates[0]["template"]
else:
return templates
@chat_template.setter
def chat_template(self, value: str | list | None):
if not value:
self.chat_templates.clear()
elif isinstance(value, str):
self.chat_templates = {
"default": value,
}
else:
self.chat_templates = {
t["name"]: t["template"]
for t in value
}
# @property
# def inverse_template(self) -> str | None:
# return self._data.get("inverse_template")
# @inverse_template.setter
# def inverse_template(self, value: str | None):
# if value:
# self._data["inverse_template"] = value
# elif "inverse_template" in self._data:
# del self._data["inverse_template"]
def json(self, indent: int | str | None = 4) -> str:
self._data["chat_template"] = self.chat_template
return json.dumps(self._data, ensure_ascii = False, indent = indent)
def get_json_indent(
json: str,
) -> int | str | None:
nonl = json.replace('\r', '').replace('\n', '')
start = nonl.find("{")
first = nonl.find('"')
return "\t" if start >= 0 and nonl[start + 1] == "\t" else None if first == json.find('"') else min(max(first - start - 1, 2), 4)
def character_diff(
diff_title: str | None,
str_original: str,
str_updated: str,
):
d = Differ()
title = [] if diff_title is None else [ (f"\n@@ {diff_title} @@\n", "@") ]
diffs = [
("".join(map(lambda x: x[2:].replace("\t", "\u21e5").replace("\r", "\u240d\r").replace("\n", "\u240a\n") if x[0] != " " else x[2:], tokens)), group if group != " " else None) # .replace(" ", "\u2423")
for group, tokens in groupby(d.compare(str_updated, str_original), lambda x: x[0])
]
return title + ([("No changes", "?")] if len(diffs) == 1 and diffs[0][1] is None else diffs)
# UI layout. Components are created in display order inside this Blocks context,
# and event listeners are registered further down in the same context.
with gr.Blocks(
) as blocks:
    # Top bar: repository search, branch/PR picker and OAuth login.
    with gr.Row(
        equal_height = True,
    ):
        hf_search = HuggingfaceHubSearch(
            label = "Search Huggingface Hub",
            placeholder = "Search for models on Huggingface",
            search_type = "model",
            # NOTE(review): spelling kept as-is; presumably it matches the component's keyword — confirm.
            sumbit_on_select = True,
            scale = 2,
        )
        # Filled by get_branches with branch names followed by open PR refs.
        hf_branch = gr.Dropdown(
            None,
            label = "Branch",
            scale = 1,
        )
        gr.LoginButton(
            "Sign in for write access or gated/private repos",
            scale = 1,
        )
    gr.Markdown(
        """# Chat Template Editor
Any model repository with chat template(s) is supported (including GGUFs), however do note that all the model info is extracted using the Hugging Face API.
For GGUFs in particular this means that the chat template may deviate from the actual content in any given GGUF file as only the default template from an arbitrary GGUF file is returned.
If you sign in and grant this editor write access you will get the option to create a pull request of your changes (provided you have access to the repository).
You can freely edit and test GGUF chat template(s) (and are encouraged to do so), but you cannot commit any changes, it is recommended to use the [GGUF Editor](https://huggingface.co/spaces/CISCai/gguf-editor) to save the final result to a GGUF.
""",
    )
    # Commit panel. Starts hidden (visible = False); template_data_from_model_files
    # shows it when write access is available.
    with gr.Accordion("Commit Changes", open = False, visible = False) as pr_group:
        with gr.Tabs() as pr_tabs:
            with gr.Tab("Edit", id = "edit") as pr_edit_tab:
                pr_title = gr.Textbox(
                    placeholder = "Title",
                    show_label = False,
                    max_lines = 1,
                    interactive = True,
                )
                pr_description = gr.Code(
                    label = "Description",
                    language = "markdown",
                    lines = 10,
                    max_lines = 10,
                    interactive = True,
                )
            with gr.Tab("Preview (with diffs)", id = "preview") as pr_preview_tab:
                pr_preview_title = gr.Textbox(
                    show_label = False,
                    max_lines = 1,
                    interactive = False,
                )
                pr_preview_description = gr.Markdown(
                    label = "Description",
                    height = "13rem",
                    container = True,
                )
                # Diffs are computed edited -> original, so "-" marks added
                # text (green) and "+" marks removed text (red).
                pr_preview_diff = gr.HighlightedText(
                    label = "Diff",
                    combine_adjacent = True,
                    color_map = { "+": "red", "-": "green", "@": "blue", "?": "blue" },
                    interactive = False,
                    show_legend = False,
                    show_inline_category = False,
                )
        # Disabled until a title is entered (see enable_pr_submit).
        pr_submit = gr.Button(
            "Create Pull Request",
            variant = "huggingface",
            interactive = False,
        )
        # Disable the button as soon as it is clicked (presumably to avoid double submits).
        pr_submit.click(
            lambda: gr.Button(
                interactive = False,
            ),
            outputs = [
                pr_submit,
            ],
            show_api = False,
        )
    with gr.Tabs() as template_tabs:
        with gr.Tab("Edit", id = "edit") as edit_tab:
            with gr.Accordion("Template Input", open = False):
                # Created unrendered so gr.Examples can reference them; they are
                # rendered after the examples, below them in the layout.
                chat_settings = gr.Code(
                    label = "Template Settings (kwargs)",
                    language = "json",
                    interactive = True,
                    render = False,
                )
                chat_messages = gr.Code(
                    label = "Template Messages",
                    language = "json",
                    interactive = True,
                    render = False,
                )
                example_input = gr.Examples(
                    examples = example_values,
                    example_labels = example_labels,
                    inputs = [
                        chat_settings,
                        chat_messages,
                    ],
                )
                chat_settings.render()
                chat_messages.render()
            chat_template = gr.Code(
                label = "Chat Template (default)",
                language = "jinja2",
                interactive = True,
            )
            with gr.Accordion("Additional Templates", open = False):
                # Hidden (visible = False); loading of the inverse template is
                # commented out elsewhere in this file.
                inverse_template = gr.Code(
                    label = "Inverse Template",
                    language = "jinja2",
                    interactive = True,
                    visible = False,
                )
                chat_template_tool_use = gr.Code(
                    label = "Chat Template (tool_use)",
                    language = "jinja2",
                    interactive = True,
                )
                chat_template_rag = gr.Code(
                    label = "Chat Template (rag)",
                    language = "jinja2",
                    interactive = True,
                )
        # Read-only outputs filled by render_chat_templates when this tab is selected.
        with gr.Tab("Render", id = "render") as render_tab:
            rendered_chat_template = gr.Textbox(
                label = "Chat Prompt (default)",
                interactive = False,
                lines = 20,
                show_copy_button = True,
            )
            with gr.Accordion("Additional Output", open = False):
                rendered_inverse_template = gr.Code(
                    label = "Inverse Chat Messages",
                    language = "json",
                    interactive = False,
                    visible = False,
                )
                rendered_chat_template_tool_use = gr.Textbox(
                    label = "Chat Prompt (tool_use)",
                    interactive = False,
                    lines = 20,
                    show_copy_button = True,
                )
                rendered_chat_template_rag = gr.Textbox(
                    label = "Chat Prompt (rag)",
                    interactive = False,
                    lines = 20,
                    show_copy_button = True,
                )
    # Per-session state: repo flags from the model info, later extended with the
    # downloaded file contents and the parent commit.
    model_info = gr.State(
        value = {},
    )
@gr.on(
    triggers = [hf_search.submit],
    inputs = [hf_search],
    outputs = [hf_branch],
    show_api = False,
)
def get_branches(
    repo: str,
    oauth_token: gr.OAuthToken | None = None,
):
    """Fill the branch dropdown with the repo's branches followed by its open PR refs.

    Hub errors are deliberately ignored: the dropdown then offers whatever was
    listed before the failure, or nothing. "main" is pre-selected when present.
    """
    token = oauth_token.token if oauth_token else False
    choices = []
    try:
        choices = [ref.name for ref in hfapi.list_repo_refs(repo, token = token).branches]
        pull_requests = hfapi.get_repo_discussions(
            repo,
            discussion_type = "pull_request",
            discussion_status = "open",
            token = token,
        )
        choices = choices + [pr.git_reference for pr in pull_requests]
    except Exception:
        pass
    selected = "main" if "main" in choices else None
    return {
        hf_branch: gr.Dropdown(choices or None, value = selected),
    }
@gr.on(
    triggers = [pr_title.input],
    inputs = [pr_title],
    outputs = [pr_submit],
    show_api = False,
)
def enable_pr_submit(
    title: str,
):
    """Make the submit button clickable only while the title is non-empty."""
    has_title = bool(title)
    return gr.Button(interactive = has_title)
@gr.on(
    triggers = [
        pr_preview_tab.select,
    ],
    inputs = [
        model_info,
        pr_title,
        pr_description,
        chat_template,
        chat_template_tool_use,
        chat_template_rag,
        inverse_template,
    ],
    outputs = [
        pr_preview_title,
        pr_preview_description,
        pr_preview_diff,
    ],
    show_api = False,
)
def render_pr_preview(
    info: dict,
    title: str,
    description: str,
    template: str,
    template_tool_use: str,
    template_rag: str,
    template_inverse: str,
):
    """Build the pull-request preview: title, description and a combined diff.

    The diff begins with a line-level unified diff of ``tokenizer_config.json``
    (when it was cached in *info*) rewritten with the edited templates. It is
    followed by a character-level diff of every template that is non-empty on
    either side. Diffs run edited -> original, so ``"-"`` marks added text and
    ``"+"`` removed text.

    Returns ``(title, description, changes)`` for the preview components.
    """
    def show_whitespace(text: str) -> str:
        # Each marker is followed by the original character, so a CRLF line
        # ending renders as exactly one "␍␊" pair.
        return text.replace("\t", "\u21e5").replace("\r", "\u240d\r").replace("\n", "\u240a\n")

    # An empty editor may arrive as None; the diff helpers need strings.
    template = template or ""
    template_tool_use = template_tool_use or ""
    template_rag = template_rag or ""
    template_inverse = template_inverse or ""
    changes = []
    org_template = ""
    org_template_inverse = ""
    org_template_tool_use = ""
    org_template_rag = ""
    tokenizer_file = info.get(ModelFiles.TOKENIZER_CONFIG, {})
    org_config = tokenizer_file.get("data")
    if org_config:
        tokenizer_config = TokenizerConfig(tokenizer_file.get("content"))
        org_template = tokenizer_config.chat_templates.get("default") or ""
        org_template_tool_use = tokenizer_config.chat_templates.get("tool_use") or ""
        org_template_rag = tokenizer_config.chat_templates.get("rag") or ""
        # org_template_inverse = tokenizer_config.inverse_template or ""
        tokenizer_config.chat_templates["default"] = template
        tokenizer_config.chat_templates["tool_use"] = template_tool_use
        tokenizer_config.chat_templates["rag"] = template_rag
        # tokenizer_config.inverse_template = template_inverse
        # Reuse the original indentation and trailing newline so only real edits show up.
        new_config = tokenizer_config.json(get_json_indent(org_config))
        if org_config.endswith("\n"):
            new_config += "\n"
        changes += [
            (
                # "---", "+++" and "@@" header lines are shown verbatim.
                token if token[1] in ("-", "+", "@") else show_whitespace(token[1:]),
                token[0] if token[0] != " " else None,
            )
            for token in unified_diff(
                new_config.splitlines(keepends = True),
                org_config.splitlines(keepends = True),
                fromfile = ModelFiles.TOKENIZER_CONFIG,
                tofile = ModelFiles.TOKENIZER_CONFIG,
            )
        ]
    # Standalone template files, when present, are the originals for their templates.
    tokenizer_chat_template = info.get(ModelFiles.TOKENIZER_CHAT_TEMPLATE, {})
    org_template = tokenizer_chat_template.get("data", org_template)
    tokenizer_inverse_template = info.get(ModelFiles.TOKENIZER_INVERSE_TEMPLATE, {})
    org_template_inverse = tokenizer_inverse_template.get("data", org_template_inverse)
    if org_template or template:
        changes += character_diff(f"Default Template{f' ({ModelFiles.TOKENIZER_CHAT_TEMPLATE})' if tokenizer_chat_template else ''}", org_template, template)
    if org_template_inverse or template_inverse:
        changes += character_diff(f"Inverse Template{f' ({ModelFiles.TOKENIZER_INVERSE_TEMPLATE})' if tokenizer_inverse_template else ''}", org_template_inverse, template_inverse)
    if org_template_tool_use or template_tool_use:
        changes += character_diff("Tool Use Template", org_template_tool_use, template_tool_use)
    if org_template_rag or template_rag:
        changes += character_diff("RAG Template", org_template_rag, template_rag)
    return title, description, changes
@gr.on(
    triggers = [
        pr_submit.click,
    ],
    inputs = [
        hf_search,
        hf_branch,
        model_info,
        pr_title,
        pr_description,
        chat_template,
        chat_template_tool_use,
        chat_template_rag,
        inverse_template,
    ],
    outputs = [
        model_info,
        hf_branch,
        pr_title,
        pr_preview_title,
        pr_description,
        pr_submit,
    ],
    show_api = False,
)
def submit_pull_request(
    repo: str,
    branch: str | None,
    info: dict,
    title: str,
    description: str,
    template: str,
    template_tool_use: str,
    template_rag: str,
    template_inverse: str,
    progress = gr.Progress(track_tqdm = True),
    oauth_token: gr.OAuthToken | None = None,
):
    """Commit the edited templates, opening a new PR or appending to the selected one.

    Only files cached in *info* (by template_data_from_model_files) whose
    content actually changed are committed. A ``refs/pr/N`` branch receives a
    plain commit; any other branch gets a new pull request. The commit is
    pinned to the cached ``parent_commit``.

    On success the cached file contents and parent commit are updated and the
    branch/PR widgets are refreshed. With nothing to commit, or on a Hub
    error, a message is shown and all outputs are skipped.
    """
    token = oauth_token.token if oauth_token else False
    # branch is None when the repository has no "main" branch to pre-select.
    pr_branch = branch if branch and branch.startswith("refs/pr/") else None
    updated_files = {}  # repo path -> new text, in commit order
    tokenizer_file = info.get(ModelFiles.TOKENIZER_CONFIG, {})
    if org_config := tokenizer_file.get("data"):
        tokenizer_config = TokenizerConfig(tokenizer_file.get("content"))
        tokenizer_config.chat_templates["default"] = template
        tokenizer_config.chat_templates["tool_use"] = template_tool_use
        tokenizer_config.chat_templates["rag"] = template_rag
        # tokenizer_config.inverse_template = template_inverse
        new_config = tokenizer_config.json(get_json_indent(org_config))
        if org_config.endswith("\n"):
            new_config += "\n"
        if org_config != new_config:
            updated_files[ModelFiles.TOKENIZER_CONFIG] = new_config
    for filename, new_text in (
        (ModelFiles.TOKENIZER_CHAT_TEMPLATE, template),
        (ModelFiles.TOKENIZER_INVERSE_TEMPLATE, template_inverse),
    ):
        # Only files that exist in the repo were cached; an empty file ("") is
        # still an original. A None editor value means nothing was loaded, so
        # there is nothing to write.
        org_text = info.get(filename, {}).get("data")
        if org_text is not None and new_text is not None and org_text != new_text:
            updated_files[filename] = new_text
    if not updated_files:
        gr.Info("No changes to commit...")
        return gr.skip()
    try:
        commit = hfapi.create_commit(
            repo,
            [CommitOperationAdd(path, text.encode("utf-8")) for path, text in updated_files.items()],
            revision = branch,
            commit_message = title,
            commit_description = description,
            create_pr = not pr_branch,
            parent_commit = info.get("parent_commit"),
            token = token,
        )
    except Exception as e:
        gr.Warning(
            message = str(e),
            duration = None,
            title = "Error committing changes",
        )
        return gr.skip()
    # Keep the cached originals in sync with what was just committed.
    info["parent_commit"] = commit.oid
    for path, text in updated_files.items():
        info[path]["data"] = text
        if path == ModelFiles.TOKENIZER_CONFIG:
            info[path]["content"] = json.loads(text)
    branches = []
    try:
        refs = hfapi.list_repo_refs(
            repo,
            token = token,
        )
        branches = [b.name for b in refs.branches]
        open_prs = hfapi.get_repo_discussions(
            repo,
            discussion_type = "pull_request",
            discussion_status = "open",
            token = token,
        )
        branches += [pr.git_reference for pr in open_prs]
    except Exception:
        # Best effort: the commit already succeeded, keep whatever refs were listed.
        pass
    pr_created = commit.pr_revision if commit.pr_revision in branches else None
    return {
        model_info: info,
        hf_branch: gr.skip() if pr_branch else gr.Dropdown(
            branches or None,
            value = pr_created or branch,
        ),
        pr_title: gr.skip() if pr_branch else gr.Textbox(
            value = None,
            placeholder = "Message" if pr_created else "Title",
            label = commit.commit_message if pr_created else None,
            show_label = True if pr_created else False,
        ),
        pr_preview_title: gr.skip() if pr_branch else gr.Textbox(
            label = commit.commit_message if pr_created else None,
            show_label = True if pr_created else False,
        ),
        pr_description: gr.Code(
            value = pr_description_default,
        ),
        pr_submit: gr.skip() if pr_branch else gr.Button(
            value = f"Commit to PR #{commit.pr_num}" if pr_created else "Create Pull Request",
        ),
    }
@gr.on(
    triggers = [hf_search.submit, hf_branch.change],
    outputs = [pr_tabs, template_tabs],
    show_api = False,
)
def switch_to_edit_tabs():
    """Select the "edit" tab in both the commit panel and the template panel."""
    edit_selected = [gr.Tabs(selected = "edit") for _ in range(2)]
    return tuple(edit_selected)
@gr.on(
    triggers = [
        chat_template.focus,
        chat_template_tool_use.focus,
        chat_template_rag.focus,
        inverse_template.focus,
    ],
    outputs = [pr_tabs],
    show_api = False,
)
def switch_to_edit_tab():
    """Return the commit panel to its "edit" tab whenever a template editor gains focus."""
    edit_tab_update = gr.Tabs(selected = "edit")
    return edit_tab_update
def template_data_from_model_info(
    repo: str,
    branch: str | None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Load editor contents from the Hub model info of *repo* at *branch*.

    Templates come from the GGUF metadata when present, otherwise from the
    config's ``tokenizer_config``. The settings JSON starts from built-in
    defaults overlaid with every tokenizer_config key except
    ``chat_template``. The messages come from the first widget example that
    has ``messages``, else from the first built-in example.

    Returns ``(flags, settings_json, messages_json, default_template,
    tool_use_template, rag_template, inverse_template)``. On a Hub error a
    warning is shown and ``({}, None, None, None, None, None, None)`` is
    returned.
    """
    try:
        hub_info = hfapi.model_info(
            repo,
            revision = branch,
            expand = ["config", "disabled", "gated", "gguf", "private", "widgetData"],
            token = oauth_token.token if oauth_token else False,
        )
    except Exception as e:
        gr.Warning(
            message = str(e),
            title = "Error loading model info",
        )
        return {}, None, None, None, None, None, None

    tokenizer_config = hub_info.config.get("tokenizer_config", {}) if hub_info.config else {}
    if hub_info.gguf:
        templates = hub_info.gguf.get("chat_template")
    elif hub_info.config:
        templates = tokenizer_config.get("chat_template")
    else:
        templates = None

    settings = {
        "add_generation_prompt": True,
        "clean_up_tokenization_spaces": False,
        "bos_token": "<|startoftext|>",
        "eos_token": "<|im_end|>",
    }
    if hub_info.config:
        settings.update((key, value) for key, value in tokenizer_config.items() if key != "chat_template")

    messages = next(
        (
            json.dumps(entry["messages"], ensure_ascii = False, indent = 2)
            for entry in hub_info.widget_data or []
            if "messages" in entry
        ),
        example_values[0][1],
    )

    # A list holds named templates; the editor shows default, tool_use and rag separately.
    tool_use_template = rag_template = None
    if isinstance(templates, list):
        named = {entry["name"]: entry["template"] for entry in templates}
        tool_use_template = named.get("tool_use")
        rag_template = named.get("rag")
        templates = named.get("default")

    repo_flags = {
        "gguf": bool(hub_info.gguf),
        "disabled": hub_info.disabled,
        "gated": hub_info.gated,
        "private": hub_info.private,
    }
    # The inverse template is not read from the model info (that path is disabled).
    return repo_flags, json.dumps(settings, ensure_ascii = False, indent = 2), messages, templates, tool_use_template, rag_template, None
def template_data_from_model_files(
    repo: str,
    branch: str | None,
    info: dict,
    progress = gr.Progress(track_tqdm = True),
    oauth_token: gr.OAuthToken | None = None,
):
    """Check write access and cache the repository files needed for commits.

    *info* is the flags dict produced by template_data_from_model_info.
    Committing is offered only to signed-in users on non-GGUF, non-disabled
    repositories they can access that contain ``tokenizer_config.json``. In
    that case the following happens:

    * the config and any standalone template files are downloaded at the
      branch head;
    * *info* gains ``parent_commit``;
    * *info* gains one ``{"data": text}`` entry per downloaded file, and the
      config entry also gets its parsed ``"content"``.

    Hub or parsing errors are shown as warnings and disable committing rather
    than aborting the event.

    Returns a dict of updates for the model-info state, the commit panel's
    visibility and the pull-request title, description and submit button.
    """
    token = oauth_token.token if oauth_token else False
    write_access = False
    if info and oauth_token:
        if info.get("gguf"):
            gr.Warning("Repository contains GGUFs, use GGUF Editor if you want to commit changes...")
        elif info.get("disabled"):
            gr.Warning("Repository is disabled, committing changes is not possible...")
        elif (gated := info.get("gated")) or (private := info.get("private")):
            try:
                hfapi.auth_check(
                    repo,
                    token = token,
                )
            except Exception as e:
                if gated:
                    gr.Warning(f"Repository is gated with {gated} approval, you must request access to be able to make changes...")
                elif private:
                    gr.Warning("Repository is private, you must use proper credentials to be able to make changes...")
                gr.Warning(str(e))
            else:
                write_access = True
        else:
            write_access = True
    if write_access:
        has_config = False
        try:
            has_config = hfapi.file_exists(
                repo,
                ModelFiles.TOKENIZER_CONFIG,
                revision = branch,
                token = token,
            )
            if has_config:
                # Head commit of the branch (a single request, instead of paging
                # through the whole commit history). Downloads are pinned to it
                # and it is later passed as parent_commit when committing.
                parent_commit = hfapi.model_info(
                    repo,
                    revision = branch,
                    token = token,
                ).sha
                files = {}
                for filename in (
                    ModelFiles.TOKENIZER_CONFIG,
                    ModelFiles.TOKENIZER_CHAT_TEMPLATE,
                    ModelFiles.TOKENIZER_INVERSE_TEMPLATE,
                ):
                    # The config is known to exist; the template files are optional.
                    if filename != ModelFiles.TOKENIZER_CONFIG and not hfapi.file_exists(
                        repo,
                        filename,
                        revision = branch,
                        token = token,
                    ):
                        continue
                    local_path = hfapi.hf_hub_download(
                        repo,
                        filename,
                        revision = parent_commit or branch,
                        token = token,
                    )
                    with open(local_path, "r", encoding = "utf-8") as fp:
                        files[filename] = {
                            "data": fp.read(),
                        }
                config_file = files[ModelFiles.TOKENIZER_CONFIG]
                config_file["content"] = json.loads(config_file["data"])
        except Exception as e:
            # Without the cached originals nothing can be committed safely.
            write_access = False
            gr.Warning(
                message = str(e),
                title = "Error downloading template files",
            )
        else:
            if has_config:
                info["parent_commit"] = parent_commit
                info.update(files)
            else:
                write_access = False
                gr.Warning(f"No {ModelFiles.TOKENIZER_CONFIG} found in repository...")
    pr_details = None
    if branch and branch.startswith("refs/pr/"):
        pr_num = branch.split("/")[-1]
        if pr_num.isdigit():
            try:
                pr_details = hfapi.get_discussion_details(
                    repo,
                    int(pr_num),
                    token = token,
                )
            except Exception as e:
                gr.Warning(
                    message = str(e),
                    title = "Error loading pull request details",
                )
    return {
        model_info: info,
        pr_group: gr.Accordion(
            visible = write_access,
        ),
        pr_title: gr.Textbox(
            value = None,
            placeholder = "Message" if pr_details else "Title",
            label = pr_details.title if pr_details else None,
            show_label = True if pr_details else False,
        ),
        pr_preview_title: gr.Textbox(
            label = pr_details.title if pr_details else None,
            show_label = True if pr_details else False,
        ),
        pr_description: gr.Code(
            value = pr_description_default,
        ),
        pr_submit: gr.Button(
            value = f"Commit to PR #{pr_details.num}" if pr_details else "Create Pull Request",
        ),
        # chat_template: gr.skip() if ModelFiles.TOKENIZER_CHAT_TEMPLATE not in info else gr.Code(
        #     value = info[ModelFiles.TOKENIZER_CHAT_TEMPLATE]["data"],
        # ),
        # inverse_template: gr.skip() if ModelFiles.TOKENIZER_INVERSE_TEMPLATE not in info else gr.Code(
        #     value = info[ModelFiles.TOKENIZER_INVERSE_TEMPLATE]["data"],
        # ),
    }
def update_examples(
    settings: str,
):
    """Merge the loaded template settings into the settings of every built-in example.

    *settings* is the JSON object text from the settings editor. It is
    ``None`` when loading the model info failed; the examples are then reset
    to their own settings instead of raising ``TypeError`` from
    ``json.loads(None)``.

    Returns a ``gr.Dataset`` update carrying the merged samples.
    """
    overrides = json.loads(settings) if settings else {}
    examples = []
    for example_settings, *rest in example_values:
        merged = json.loads(example_settings)
        merged.update(overrides)
        examples.append([json.dumps(merged, ensure_ascii = False, indent = 2), *rest])
    return gr.Dataset(
        samples = examples,
    )
# Loading a model (search result, or a branch picked by the user) runs in three steps:
# 1. read templates and settings from the Hub model info;
# 2. on success, merge those settings into the example inputs;
# 3. then check write access and cache the repository files needed for commits.
model_info_loaded = gr.on(
    fn = template_data_from_model_info,
    triggers = [hf_search.submit, hf_branch.input],
    inputs = [hf_search, hf_branch],
    outputs = [
        model_info,
        chat_settings,
        chat_messages,
        chat_template,
        chat_template_tool_use,
        chat_template_rag,
        inverse_template,
    ],
)
examples_updated = model_info_loaded.success(
    fn = update_examples,
    inputs = [chat_settings],
    outputs = [example_input.dataset],
    show_api = False,
)
examples_updated.then(
    fn = template_data_from_model_files,
    inputs = [hf_search, hf_branch, model_info],
    outputs = [
        model_info,
        pr_group,
        pr_title,
        pr_preview_title,
        pr_description,
        pr_submit,
        # chat_template,
        # inverse_template,
    ],
    show_api = False,
)
@gr.on(
    triggers = [
        render_tab.select,
    ],
    inputs = [
        chat_settings,
        chat_messages,
        chat_template,
        chat_template_tool_use,
        chat_template_rag,
        inverse_template,
    ],
    outputs = [
        rendered_chat_template,
        rendered_chat_template_tool_use,
        rendered_chat_template_rag,
        rendered_inverse_template,
    ],
)
def render_chat_templates(
    settings: str,
    messages: str,
    template: str,
    template_tool_use: str | None = None,
    template_rag: str | None = None,
    template_inverse: str | None = None,
):
    """Render every template against the example settings and messages.

    *settings* and *messages* are JSON text from the input editors.

    * Tokenizer-level keys (``*_side``, ``*_token``, ``*_tokens`` and
      ``clean_up_tokenization_spaces``) configure a bare
      ``PreTrainedTokenizerBase``.
    * ``tools``, ``documents`` and ``add_generation_prompt`` are passed to
      the templates.

    Each template is rendered independently. A failure is shown as a warning
    and leaves that output empty.

    Returns the four output values, or ``gr.skip()`` when the inputs or the
    tokenizer settings are unusable.
    """
    def warn(title: str, error: Exception):
        gr.Warning(
            message = str(error),
            duration = None,
            title = title,
        )

    def token_text(value):
        # Serialised AddedToken entries look like {"content": "<s>", ...}; the
        # tokenizer constructor expects str/AddedToken values, not plain dicts.
        if isinstance(value, dict) and "content" in value:
            return value["content"]
        if isinstance(value, list):
            return [token_text(item) for item in value]
        return value

    try:
        settings = json.loads(settings) if settings else {}
    except Exception as e:
        warn("Template Settings Error", e)
        return gr.skip()
    try:
        messages = json.loads(messages) if messages else []
    except Exception as e:
        warn("Template Messages Error", e)
        return gr.skip()
    if not isinstance(settings, dict):
        gr.Warning("Invalid Template Settings!")
        return gr.skip()
    if not messages or not isinstance(messages, list) or not isinstance(messages[0], dict) or "role" not in messages[0]:
        gr.Warning("No Template Messages!")
        return gr.skip()
    tools = settings.get("tools")
    documents = settings.get("documents")
    add_generation_prompt = settings.get("add_generation_prompt")
    tokenizer_settings = {
        key: token_text(value)
        for key, value in settings.items()
        if key.endswith(("_side", "_token", "_tokens")) or key == "clean_up_tokenization_spaces"
    }
    try:
        tokenizer = PreTrainedTokenizerBase(**tokenizer_settings)
    except Exception as e:
        # The settings are user-editable, e.g. a special token given as a number.
        warn("Template Settings Error", e)
        return gr.skip()
    chat_output = None
    chat_tool_use_output = None
    chat_rag_output = None
    inverse_output = None
    try:
        chat_output = tokenizer.apply_chat_template(messages, tools = tools, documents = documents, chat_template = template, add_generation_prompt = add_generation_prompt, tokenize = False)
    except Exception as e:
        warn("Chat Template Error", e)
    try:
        chat_tool_use_output = tokenizer.apply_chat_template(messages, tools = tools or [], chat_template = template_tool_use, add_generation_prompt = add_generation_prompt, tokenize = False) if template_tool_use else None
    except Exception as e:
        warn("Tool Use Template Error", e)
    try:
        chat_rag_output = tokenizer.apply_chat_template(messages, documents = documents or [], chat_template = template_rag, add_generation_prompt = add_generation_prompt, tokenize = False) if template_rag else None
    except Exception as e:
        warn("RAG Template Error", e)
    try:
        inverse_output = tokenizer.apply_inverse_template(messages, inverse_template = template_inverse) if template_inverse else None
    except Exception as e:
        warn("Inverse Template Error", e)
    return chat_output, chat_tool_use_output, chat_rag_output, json.dumps(inverse_output, ensure_ascii = False, indent = 2) if inverse_output is not None else None
if __name__ == "__main__":
    # Blocks.queue() returns the Blocks instance, so launch() chains onto it.
    blocks.queue(
        max_size = 10,
        default_concurrency_limit = 10,
    ).launch(ssr_mode = False)