prasanna kumar committed
Commit 0d3569b • 1 Parent(s): 6e43644

added OpenAI and Anthropic (Claude) model support along with token visualizations

Files changed (2)
  1. app.py +88 -24
  2. requirements.txt +4 -0
app.py CHANGED
@@ -4,11 +4,24 @@ import ast
 from collections import Counter
 import re
 import plotly.graph_objs as go
+import html
+import random
+import tiktoken
+import anthropic
 
 model_path = "models/"
 
 # Available models
-MODELS = ["Meta-Llama-3.1-8B", "gemma-2b"]
+MODELS = ["Meta-Llama-3.1-8B", "gemma-2b", "gpt-3.5-turbo", "gpt-4", "gpt-4o", "Claude-3-Sonnet"]
+openai_models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]
+# Color palette visible on both light and dark themes
+COLOR_PALETTE = [
+    "#e6194B", "#3cb44b", "#ffe119", "#4363d8",
+    "#f58231", "#911eb4", "#42d4f4", "#f032e6",
+    "#bfef45", "#fabed4", "#469990", "#dcbeff",
+    "#9A6324", "#fffac8", "#800000", "#aaffc3",
+    "#808000", "#ffd8b1", "#000075", "#a9a9a9"
+]
 
 def create_vertical_histogram(data, title):
     labels, values = zip(*data) if data else ([], [])
@@ -25,32 +38,80 @@ def create_vertical_histogram(data, title):
     )
     return fig
 
-def process_text(text:str,model_name):
-    tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
-    token_ids = tokenizer.encode(text, add_special_tokens=True)
-    tokens = tokenizer.convert_ids_to_tokens(token_ids)
-    return text,tokens,token_ids,
+def process_text(text: str, model_name: str, api_key: str = None):
+    # Dispatch on model family: local HF checkpoints, OpenAI via tiktoken, or Claude.
+    if model_name in ["Meta-Llama-3.1-8B", "gemma-2b"]:
+        tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
+        token_ids = tokenizer.encode(text, add_special_tokens=True)
+        tokens = tokenizer.convert_ids_to_tokens(token_ids)
+    elif model_name in openai_models:
+        encoding = tiktoken.encoding_for_model(model_name=model_name)
+        token_ids = encoding.encode(text)
+        tokens = [encoding.decode([id]) for id in token_ids]
+    elif model_name == "Claude-3-Sonnet":
+        if not api_key:
+            raise ValueError("API key is required for Claude models")
+        client = anthropic.Anthropic(api_key=api_key)
+        tokenizer = client.get_tokenizer()
+        token_ids = tokenizer.encode(text).ids
+        tokens = [tokenizer.decode([id]) for id in token_ids]
+    else:
+        raise ValueError(f"Unsupported model: {model_name}")
+
+    return text, tokens, token_ids
 
-def process_ids(ids:str,model_name):
-    tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
+def process_ids(ids: str, model_name: str, api_key: str = None):
     token_ids = ast.literal_eval(ids)
-    text = tokenizer.decode(token_ids)
-    tokens = tokenizer.convert_ids_to_tokens(token_ids)
-    return text,tokens,token_ids
+    if model_name in ["Meta-Llama-3.1-8B", "gemma-2b"]:
+        tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
+        text = tokenizer.decode(token_ids)
+        tokens = tokenizer.convert_ids_to_tokens(token_ids)
+    elif model_name in openai_models:
+        encoding = tiktoken.encoding_for_model(model_name=model_name)
+        text = encoding.decode(token_ids)
+        tokens = [encoding.decode([id]) for id in token_ids]
+    elif model_name == "Claude-3-Sonnet":
+        client = anthropic.Anthropic(api_key=api_key)
+        tokenizer = client.get_tokenizer()
+        text = tokenizer.decode(token_ids)
+        tokens = [tokenizer.decode([id]) for id in token_ids]
+    else:
+        raise ValueError(f"Unsupported model: {model_name}")
+
+    return text, tokens, token_ids
 
-def process_input(input_type, input_value, model_name):
-
+def get_token_color(token, token_colors):
+    # Fixed colors for structural token classes; a stable random palette color otherwise.
+    if token.startswith('<') and token.endswith('>'):
+        return "#42d4f4"  # Cyan for special tokens
+    elif token == '▁' or token == ' ':
+        return "#3cb44b"  # Green for space tokens
+    elif not token.isalnum():
+        return "#f032e6"  # Magenta for special characters
+    else:
+        if token not in token_colors:
+            token_colors[token] = random.choice(COLOR_PALETTE)
+        return token_colors[token]
+
+def create_html_tokens(tokens):
+    # Render each token as a colored inline <span> inside a wrapping <div>.
+    html_output = '<div style="font-family: monospace; border: 1px solid #ccc; padding: 10px; border-radius: 5px; background-color: #f9f9f9; white-space: pre-wrap; word-break: break-all;">'
+    token_colors = {}
+    for token in tokens:
+        color = get_token_color(token, token_colors)
+        escaped_token = html.escape(token)
+        html_output += f'<span style="background-color: {color}; color: black; padding: 2px 4px; margin: 1px; border-radius: 3px; display: inline-block;">{escaped_token}</span>'
+    html_output += '</div>'
+    return html_output
+
+def process_input(input_type, input_value, model_name, api_key):
     if input_type == "Text":
-        text,tokens,token_ids = process_text(text=input_value,model_name=model_name)
+        text, tokens, token_ids = process_text(text=input_value, model_name=model_name, api_key=api_key)
     elif input_type == "Token IDs":
-        text,tokens,token_ids = process_ids(ids=input_value,model_name=model_name)
+        text, tokens, token_ids = process_ids(ids=input_value, model_name=model_name, api_key=api_key)
 
     character_count = len(text)
     word_count = len(text.split())
 
-
-    space_count = sum(1 for token in tokens if token == '▁')
-    special_char_count = sum(1 for token in tokens if not token.isalnum() and token != '▁')
+    space_count = sum(1 for token in tokens if token in ['▁', ' '])
+    special_char_count = sum(1 for token in tokens if not token.isalnum() and token not in ['▁', ' '])
 
     words = re.findall(r'\b\w+\b', text.lower())
     special_chars = re.findall(r'[^\w\s]', text)
@@ -71,7 +132,9 @@ def process_input(input_type, input_value, model_name):
     analysis += f"Special character tokens: {special_char_count}\n"
     analysis += f"Other tokens: {len(tokens) - space_count - special_char_count}"
 
-    return analysis, text,tokens, str(token_ids), words_hist, special_chars_hist, numbers_hist
+    html_tokens = create_html_tokens(tokens)
+
+    return analysis, text, html_tokens, str(token_ids), words_hist, special_chars_hist, numbers_hist
 
 def text_example():
     return "Hello, world! This is an example text input for tokenization."
@@ -85,8 +148,9 @@ with gr.Blocks() as iface:
 
     with gr.Row():
         input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
-        model_name = gr.Dropdown(choices=MODELS, label="Select Model",value=MODELS[0])
+        model_name = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
 
+    api_key = gr.Textbox(label="API Key (Claude models)", type="password")
    input_text = gr.Textbox(lines=5, label="Input")
 
    with gr.Row():
@@ -96,8 +160,8 @@ with gr.Blocks() as iface:
     submit_button = gr.Button("Process")
 
     analysis_output = gr.Textbox(label="Analysis", lines=6)
-    text_output = gr.Textbox(label="Text",lines=6)
-    tokens_output = gr.Textbox(label="Tokens", lines=3)
+    text_output = gr.Textbox(label="Text", lines=6)
+    tokens_output = gr.HTML(label="Tokens")
     token_ids_output = gr.Textbox(label="Token IDs", lines=2)
 
     with gr.Row():
@@ -117,8 +181,8 @@ with gr.Blocks() as iface:
 
     submit_button.click(
         process_input,
-        inputs=[input_type, input_text, model_name],
-        outputs=[analysis_output,text_output ,tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot]
+        inputs=[input_type, input_text, model_name, api_key],
+        outputs=[analysis_output, text_output, tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot]
     )
 
 if __name__ == "__main__":
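
The OpenAI branch above leans on a tiktoken property: decoding each token id individually yields that token's text piece, and the pieces concatenate back to the original input. A minimal round-trip sketch (model choice and sample text are arbitrary; note that for multi-byte characters a single-id decode can yield a replacement character, so the per-token strings are best treated as display-only):

import tiktoken

# Encode, then decode each id on its own -- the same per-token decode
# that process_text uses for the OpenAI models.
encoding = tiktoken.encoding_for_model("gpt-4")
text = "Hello, world! This is an example text input for tokenization."

token_ids = encoding.encode(text)
tokens = [encoding.decode([tid]) for tid in token_ids]

assert "".join(tokens) == text  # exact round-trip for ASCII input
print(tokens)      # e.g. ['Hello', ',', ' world', '!', ...]
print(token_ids)
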
requirements.txt CHANGED
@@ -1,11 +1,13 @@
 aiofiles==23.2.1
 annotated-types==0.7.0
+anthropic==0.34.1
 anyio==4.4.0
 certifi==2024.7.4
 charset-normalizer==3.3.2
 click==8.1.7
 contourpy==1.2.1
 cycler==0.12.1
+distro==1.9.0
 fastapi==0.112.2
 ffmpy==0.4.0
 filelock==3.15.4
@@ -20,6 +22,7 @@ huggingface-hub==0.24.6
 idna==3.8
 importlib_resources==6.4.4
 Jinja2==3.1.4
+jiter==0.5.0
 kiwisolver==1.4.5
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
@@ -51,6 +54,7 @@ six==1.16.0
 sniffio==1.3.1
 starlette==0.38.2
 tenacity==9.0.0
+tiktoken==0.7.0
 tokenizers==0.19.1
 tomlkit==0.12.0
 tqdm==4.66.5
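
Of the four new pins, tiktoken and anthropic are imported directly by app.py; distro and jiter are transitive dependencies of the anthropic SDK. A quick sanity check that the pinned packages expose what app.py calls (expected values assume exactly this requirements file):

import anthropic
import tiktoken

print(anthropic.__version__)                        # "0.34.1" per the pin above
print(tiktoken.encoding_for_model("gpt-4o").name)   # "o200k_base", added in tiktoken 0.7.0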