# Streamlit app: Gemini token-probability viewer.
import os

import streamlit as st
import sentencepiece as spm
import pandas as pd
import numpy as np

# Page chrome: wide layout so the token table has room, plus the app title.
st.set_page_config(layout="wide")
st.title("Gemini Token Probabilities")
# --- Load the SentencePiece tokenizer once ---
@st.cache_resource
def load_tokenizer() -> spm.SentencePieceProcessor:
    """Load the Gemini SentencePiece model that ships next to this script.

    Returns:
        The loaded ``SentencePieceProcessor``.

    Side effects:
        Shows a Streamlit error and stops the script run if the model
        file is missing. Cached with ``st.cache_resource`` so the model
        is loaded once per process, not on every Streamlit rerun.
    """
    # Determine the directory that this script lives in (i.e. src/)
    here = os.path.dirname(__file__)
    # Build the absolute path to the gemini-1.5-pro-002 folder inside src/
    model_dir = os.path.join(here, "gemini-1.5-pro-002")
    model_path = os.path.join(model_dir, "gemini-1.5-pro-002.spm.model")
    if not os.path.isfile(model_path):
        st.error(f"Cannot find model at:\n{model_path}")
        st.stop()
    sp = spm.SentencePieceProcessor()
    sp.Load(model_path)
    return sp


sp = load_tokenizer()
# --- Precompute global min and max raw log-probs over the entire vocab ---
@st.cache_data
def compute_vocab_min_max(_sp: spm.SentencePieceProcessor) -> tuple[float, float]:
    """Return ``(min, max)`` of the raw log-prob scores over the full vocab.

    The leading underscore on ``_sp`` follows the Streamlit convention that
    tells ``st.cache_data`` not to hash this (unhashable) argument; caching
    makes the O(vocab-size) scan run once instead of on every rerun.
    """
    scores = np.array(
        [_sp.GetScore(i) for i in range(_sp.GetPieceSize())], dtype=float
    )
    return float(scores.min()), float(scores.max())


global_min, global_max = compute_vocab_min_max(sp)
# --- User input area ---
text = st.text_area("Enter text to tokenize:", "")
if st.button("Tokenize"):
    if not text.strip():
        st.warning("Enter some text first.")
    else:
        # 1) Tokenize into subword pieces and IDs
        pieces = sp.EncodeAsPieces(text)
        ids = sp.EncodeAsIds(text)
        # 2) Retrieve raw log-probability for each input piece
        raw_scores = np.array([sp.GetScore(tid) for tid in ids], dtype=float)
        # 3) Normalize each raw score against [global_min, global_max] -> [0, 1]
        if global_max != global_min:
            normalized_0_1 = (raw_scores - global_min) / (global_max - global_min)
        else:
            # Degenerate case (all vocab scores equal): avoid division by zero.
            normalized_0_1 = np.zeros_like(raw_scores)
        # 4) Build the per-token table
        df = pd.DataFrame({
            "Token": pieces,
            # Pass the 0-1 values into the "Global Normalized" column
            "Global Normalized": normalized_0_1,
        })
        # 5) Display with progress bars (as percentages)
        st.dataframe(
            df,
            use_container_width=True,
            column_config={
                "Global Normalized": st.column_config.ProgressColumn(
                    "Score (percent)",
                    help="Raw log-prob min–max normalized over full vocab, shown as %",
                    # NOTE(review): the "percent" preset requires a recent
                    # Streamlit; older versions expect a printf string like
                    # "%.1f%%" — confirm against the pinned version.
                    format="percent",
                    min_value=0.0,
                    max_value=1.0,
                )
            },
        )