File size: 4,286 Bytes
57b690d
 
 
 
 
 
 
 
 
 
 
fbd4fd8
2acbb98
 
5942a55
 
fbd4fd8
2acbb98
fbd4fd8
2acbb98
 
fbd4fd8
2acbb98
 
 
45b4bd7
fbd4fd8
2acbb98
 
 
 
fbd4fd8
 
2acbb98
 
fbd4fd8
 
2acbb98
57b690d
2acbb98
57b690d
fbd4fd8
57b690d
 
 
 
 
d960dc2
57b690d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5942a55
 
57b690d
5942a55
57b690d
5942a55
57b690d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b4bd7
 
 
 
 
 
 
 
5942a55
d95387e
 
57b690d
fbd4fd8
57b690d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import io
import json
import re

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer


# Hub repository ids offered in the "Tokenizer" dropdown. Each entry pairs a
# short family name (kept for readability only) with the repo id; only the
# repo ids end up in ``tokenizers``, in the order listed here.
tokenizers = [
    repo_id
    for _family, repo_id in (
        ("bert", "google-bert/bert-base-uncased"),
        ("bge-en", "BAAI/bge-base-en-v1.5"),
        ("bge-zh", "BAAI/bge-base-zh-v1.5"),
        ("blenderbot", "facebook/blenderbot-3B"),
        ("bloom", "bigscience/bloom-560m"),
        ("bloomz", "bigscience/bloomz-7b1"),
        ("chatglm3", "THUDM/chatglm3-6b"),
        ("falcon", "tiiuae/falcon-7b"),
        ("gemma", "fxmarty/tiny-random-GemmaForCausalLM"),
        ("gpt-neox", "EleutherAI/gpt-neox-20b"),
        ("llama", "TinyLlama/TinyLlama-1.1B-Chat-v0.6"),
        ("magicoder", "ise-uiuc/Magicoder-S-DS-6.7B"),
        ("mistral", "echarlaix/tiny-random-mistral"),
        ("mpt", "mosaicml/mpt-7b"),
        ("opt", "facebook/opt-2.7b"),
        ("phi-2", "microsoft/phi-2"),
        ("pythia", "EleutherAI/pythia-1.4b-deduped"),
        ("qwen", "Qwen/Qwen1.5-7B-Chat"),
        ("redpajama", "togethercomputer/RedPajama-INCITE-Chat-3B-v1"),
        ("roberta", "FacebookAI/roberta-base"),
        ("starcoder", "bigcode/starcoder2-7b"),
        ("t5", "google-t5/t5-base"),
        ("vicuna", "lmsys/vicuna-7b-v1.5"),
        ("zephyr", "HuggingFaceH4/zephyr-7b-beta"),
    )
]


def plot_histogram(data):
    """Render a histogram of ``data`` and return it as a PIL image.

    Args:
        data: Sequence of numbers to bin (here: token counts per item).

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG of the plot.
    """
    # Draw on a dedicated figure instead of pyplot's implicit "current"
    # figure, so nothing else that touches pyplot can bleed into this plot.
    fig, ax = plt.subplots()
    try:
        ax.hist(data)
        ax.set_title("Histogram of number of tokens per dataset item")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _summarize_counts(tokencounter, wordcounter, charcounter):
    """Return rounded descriptive statistics, one row per count type."""
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    counts = pd.DataFrame(
        {"tokens": tokencounter, "words": wordcounter, "chars": charcounter},
        # Force a numeric dtype so describe() yields numeric statistics even
        # when the lists are empty.
        dtype="float64",
    )
    stats = counts.describe(percentiles=percentiles).T
    stats = stats.drop(columns="count").round(1)
    # Expose the row labels ("tokens", "words", "chars") as a leading
    # "type" column and fall back to a plain 0..2 row index.
    return stats.rename_axis("type").reset_index()


def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    """Compute token, word and character statistics for one dataset column.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        dataset_id: Identifier passed to ``load_dataset``.
        config: Dataset configuration name; an empty value means "no config".
        split: Dataset split to load.
        column: Name of the text column to measure in each item.
        add_special_tokens: Forwarded to the tokenizer call.

    Returns:
        A tuple ``(image, stats)``: ``image`` is the histogram of per-item
        token counts produced by ``plot_histogram``; ``stats`` is a DataFrame
        with one row each for tokens, words and chars (in that order), a
        ``type`` column, and mean/std/min/percentile/max columns rounded to
        one decimal.
    """
    # NOTE(security): trust_remote_code=True lets the selected tokenizer and
    # dataset repositories run their own code; both ids come from user input.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # An empty Config textbox arrives as ""; pass None to load_dataset instead.
    # (The original `config is None` was a comparison with no effect.)
    if not config:
        config = None
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)

    # Compiled once, outside the per-item loop.
    word_pattern = re.compile(r"[a-zA-Z]+")
    tokencounter = []
    wordcounter = []
    charcounter = []
    for item in dataset:
        text = item[column]
        tokens = tokenizer(text, add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(text))
        # not 100% accurate but good enough: only runs of ASCII letters count
        wordcounter.append(len(word_pattern.findall(text)))

    stats = _summarize_counts(tokencounter, wordcounter, charcounter)
    return plot_histogram(tokencounter), stats


# Gradio UI: six inputs feeding `count`, which returns a histogram image and
# a statistics table.
demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    # Each row gives tokenizer, dataset, config, split and column; no value is
    # listed for the "Add special tokens" checkbox.
    examples=[
        ["tiiuae/falcon-7b", "gsarti/flores_101", "eng", "dev", "sentence"],
        ["tiiuae/falcon-7b", "Muennighoff/flores200", "eng_Latn", "dev", "sentence"],
        ["tiiuae/falcon-7b", "hails/mmlu_no_train", "elementary_mathematics", "test", "question"],
        ["tiiuae/falcon-7b", "gsm8k", "main", "test", "question"],
        ["tiiuae/falcon-7b", "locuslab/TOFU", "world_facts", "train", "question"],
        ["tiiuae/falcon-7b", "imdb", "", "test", "text"],
        ["tiiuae/falcon-7b", "wikitext", "wikitext-2-v1", "validation", "text"],
        ["tiiuae/falcon-7b", "zeroshot/twitter-financial-news-sentiment", "", "validation", "text"],
        ["BAAI/bge-base-en-v1.5", "PolyAI/banking77", "", "test", "text"],
        # Fixed: the dataset id was missing its closing quote (syntax error).
        ["BAAI/bge-base-en-v1.5", "mteb/amazon_massive_intent", "en", "test", "text"],
        ["BAAI/bge-base-en-v1.5", "mteb/sts16-sts", "", "test", "sentence1"],
    ],
    cache_examples=True,
)

demo.launch()