|
import gradio as gr |
|
import json |
|
import pandas as pd |
|
import config |
|
from vocab import load_tokener |
|
from utils.zh_util import iter_vocab |
|
from utils.log_util import logger |
|
from utils.compress_rate_util import tokenize_corpus, unit_convertor |
|
from functools import lru_cache |
|
|
|
|
|
def _token_str_and_bytes(token, token_id, idx, text, tokenizer_type):
    """Return ``(token_str, token_bytes)`` for one raw token from ``convert_ids_to_tokens``.

    * ``bytes`` tokens are decoded as UTF-8. If that fails, the undecodable bytes are
      dropped and the failure is logged. The original bytes are returned as
      ``token_bytes``.
    * ``str`` tokens are returned as-is, together with their UTF-8 encoding.
    * Any other type is logged as an error and returned unchanged in both slots.
    """
    if isinstance(token, bytes):
        try:
            token_str = token.decode("utf-8")
        except UnicodeDecodeError:
            # Presumably a byte-level token holding only part of a multi-byte
            # character; keep whatever decodes and record the failure.
            token_str = token.decode("utf-8", errors="ignore")
            logger.error(f"{idx}: decode_error: " + json.dumps(
                {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                ensure_ascii=False))
        return token_str, token
    if isinstance(token, str):
        return token, bytes(token, "utf-8")
    logger.error(f"{idx}: wrong type for token {token_id} {type(token)} " + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    return token, token


@lru_cache
def tokenize(text, tokenizer_type, color_num=5):
    """Tokenize ``text`` and build a highlighted token view plus a per-token table.

    Results are memoized by ``functools.lru_cache``. Repeated calls with the same
    arguments therefore return the *same* objects and skip the logging below, so
    callers should treat the returned DataFrame as read-only. All arguments must be
    hashable.

    Args:
        text: Input text to tokenize.
        tokenizer_type: Name handed to ``load_tokener`` to obtain the tokenizer.
        color_num: Number of colour classes cycled through for highlighting.
            Token ``idx`` is labelled ``str(idx % color_num)``, so this must be
            non-zero whenever ``text`` produces at least one token.

    Returns:
        A pair ``(update, table_df)``:

        * ``update`` is a ``gr.update`` whose value is a list of
          ``(decoded_text, colour_label)`` tuples and whose label reports the
          token count.
        * ``table_df`` is a DataFrame with columns "TokenID", "Token", "Text" and
          "UTF8 Bytes", one row per token id.
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    tokenizer = load_tokener(tokenizer_type)
    # config.ADD_SPECIAL_TOKEN controls whether the tokenizer inserts its special tokens.
    encoding = tokenizer.encode(text, add_special_tokens=bool(config.ADD_SPECIAL_TOKEN))

    pos_tokens = []
    table = []
    for idx, token_id in enumerate(encoding):
        # Decode each id on its own so every highlighted span maps to exactly one token.
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        token_str, token_bytes = _token_str_and_bytes(token, token_id, idx, text, tokenizer_type)

        table.append({
            "TokenID": token_id,
            "Token": token_str,
            "Text": decode_text,
            # Stringified so the column shows the b'...' literal form of the bytes.
            "UTF8 Bytes": str(token_bytes),
        })

    table_df = pd.DataFrame(table)
    logger.info(f"tokenizer_type={tokenizer_type}, Tokens={table[:4]}")

    return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df
|
|
|
|
|
@lru_cache
def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize ``text`` with two tokenizers for side-by-side comparison.

    Wired to the input text's change event (the original docstring read
    "input_text.change"). Each half of the result is exactly what ``tokenize``
    returns for the corresponding tokenizer, using its default ``color_num``.

    Returns:
        ``(tokens_1, table_1, tokens_2, table_2)``.
    """
    # Tokenizer 1 is processed before tokenizer 2, matching the output order.
    per_tokenizer = [tokenize(text, kind) for kind in (tokenizer_type_1, tokenizer_type_2)]
    (tokens_1, table_1), (tokens_2, table_2) = per_tokenizer
    return tokens_1, table_1, tokens_2, table_2
|
|
|
|
|
@lru_cache
def basic_count(tokenizer_type):
    """Return ``(vocab_size, chinese_token_count)`` for the named tokenizer.

    The second element is the ``"中文token数"`` ("number of Chinese tokens") entry
    of the statistics produced by ``iter_vocab``, formatted as a string.
    Results are memoized per ``tokenizer_type``.
    """
    tok = load_tokener(tokenizer_type)
    vocab_stats = iter_vocab(tok)
    size = tok.vocab_size
    zh_token_count = format(vocab_stats["中文token数"])
    return size, zh_token_count
|
|
|
|
|
def get_compress_rate(tokenizer_type, all_corpus, unit):
    """Compute one tokenizer's compression rate over ``all_corpus``, expressed in ``unit``.

    ``tokenize_corpus`` gathers the raw statistics, and ``unit_convertor`` turns
    them into the requested unit. The result is not cached.
    """
    return unit_convertor(
        tokenize_corpus(load_tokener(tokenizer_type), all_corpus),
        unit,
    )
|
|
|
|
|
@lru_cache
def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count the vocabulary entries shared by two tokenizers.

    Some tokenizers expose ``str`` vocab keys and others ``bytes`` keys. When the two
    vocabularies differ in key type (judged by their first key), every ``str`` key
    is UTF-8 encoded so both sides are compared as ``bytes``. An empty vocabulary
    simply contributes no overlap.

    Returns:
        The overlap count twice: ``(overlap_token_size, overlap_token_size)``.
        Presumably this is one value per UI output; confirm against the caller.
    """
    from itertools import islice

    tokenizer1 = load_tokener(tokenizer_type_1)
    tokenizer2 = load_tokener(tokenizer_type_2)

    vocab_set_1 = tokenizer1.get_vocab().keys()
    vocab_set_2 = tokenizer2.get_vocab().keys()

    # The None default keeps an empty vocabulary from raising StopIteration.
    token1 = next(iter(vocab_set_1), None)
    token2 = next(iter(vocab_set_2), None)
    if type(token1) is not type(token2):
        def as_bytes(vocab):
            # Per-key conversion, so a vocab mixing str and bytes keys doesn't crash.
            return {tok.encode("utf-8") if isinstance(tok, str) else tok for tok in vocab}

        vocab_set_1 = as_bytes(vocab_set_1)
        vocab_set_2 = as_bytes(vocab_set_2)

    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    # islice avoids materializing the whole overlap just to log a 10-item sample.
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}: {list(islice(overlap_tokens, 10))}")
    return overlap_token_size, overlap_token_size
|
|
|
|
|
|
|
def on_load(url_params, request: gr.Request):
    """Resolve the initial UI state when the page loads.

    Args:
        url_params: JSON text, presumably the page's URL query parameters as
            serialized by the front end (confirm against the load-event wiring).
            The recognised keys are "tokenizer1", "tokenizer2" and "text"; missing
            keys fall back to the ``config.default_*`` values. Input that is not a
            JSON object (None, malformed JSON, a list, a number, ...) is treated as
            "no parameters".
        request: Incoming Gradio request. It is used only to log the headers and
            the client address, and may be falsy/None.

    Returns:
        ``(text, tokenizer_type_1, tokenizer_type_2)``.
    """
    try:
        url_params = json.loads(url_params)
    except (TypeError, ValueError):  # non-string input, or malformed JSON
        url_params = {}
    if not isinstance(url_params, dict):
        # Valid JSON that is not an object has no .get(); ignore it.
        url_params = {}

    # Bound up front: the final log line runs even when there is no request.
    client_ip = None
    if request:
        logger.info(str(request.headers))
        client = getattr(request, "client", None)
        client_ip = getattr(client, "host", None)

    tokenizer_type_1 = url_params.get("tokenizer1", config.default_tokenizer_type_1)
    tokenizer_type_2 = url_params.get("tokenizer2", config.default_tokenizer_type_2)
    text = url_params.get("text", config.default_user_input)
    logger.info(f"client_ip: {client_ip}; params: {url_params}")
    return text, tokenizer_type_1, tokenizer_type_2
|
|
|
|
|
def compress_rate_unit_change(unit):
    """Return two ``gr.update`` objects relabelled ``"Compress Rate: <unit>"``.

    Presumably these are for the two tokenizers' compression-rate outputs after
    the unit selection changes.
    """
    new_label = f"Compress Rate: {unit}"
    return (gr.update(label=new_label), gr.update(label=new_label))
|
|
|
def test_coding():
    """Scratch check: print the UTF-8 bytes literal of "中" (U+4E2D)."""
    print(b'\xe4\xb8\xad')
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: vocabulary overlap between gpt_35_turbo and gpt_4.
    overlap = get_overlap_token_size("gpt_35_turbo", "gpt_4")
    print(overlap)
|
|
|
|