# gemini-tokens / src/streamlit_app.py
# Last commit by dejanseo: "Update src/streamlit_app.py" (b9c9f51, verified)
import os
import streamlit as st
import sentencepiece as spm
import pandas as pd
import numpy as np
# --- Page setup ---
# One title shared by the browser tab (page_title) and the page heading.
APP_TITLE = "Gemini Token Probabilities"

# set_page_config must remain the first Streamlit call in the script.
st.set_page_config(page_title=APP_TITLE, layout="wide")
st.title(APP_TITLE)
# --- Load the SentencePiece tokenizer once ---
@st.cache_resource
def load_tokenizer():
    """Load and return the Gemini SentencePiece tokenizer.

    The model is read from
    ``<directory of this script>/gemini-1.5-pro-002/gemini-1.5-pro-002.spm.model``.
    The function is wrapped in ``st.cache_resource``, so the processor is not
    rebuilt on every script rerun.

    If the model file is missing, or SentencePiece fails to load it, an error
    is shown in the app and the current run is halted with ``st.stop()``.
    """
    # Resolve __file__ first: it may be relative (dirname can even be ""),
    # which would make the model lookup depend on the working directory.
    here = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(here, "gemini-1.5-pro-002")
    model_path = os.path.join(model_dir, "gemini-1.5-pro-002.spm.model")
    if not os.path.isfile(model_path):
        st.error(f"Cannot find model at:\n{model_path}")
        st.stop()

    processor = spm.SentencePieceProcessor()
    try:
        processor.Load(model_path)
    except (OSError, RuntimeError) as err:
        # The file exists but is not a loadable model (e.g. corrupt or a
        # placeholder): surface it in the UI like the missing-file case.
        st.error(f"Failed to load SentencePiece model at:\n{model_path}\n\n{err}")
        st.stop()
    return processor


sp = load_tokenizer()
# --- Precompute the vocabulary-wide min and max piece scores ---
@st.cache_data
def compute_vocab_min_max(_sp: spm.SentencePieceProcessor):
    """Return ``(min, max)`` of the piece scores across the whole vocabulary.

    Every id in ``range(_sp.GetPieceSize())`` is scored with ``_sp.GetScore``;
    the two extremes are returned as plain Python floats.

    The leading underscore on ``_sp`` keeps ``st.cache_data`` from hashing the
    processor, so the cached result is not keyed on which tokenizer is passed.
    That is fine while the app only ever loads one model.

    NOTE(review): control/special pieces may carry a score of 0.0 and would
    then define the max for the whole scale -- confirm for this model.
    """
    vocab_size = _sp.GetPieceSize()
    all_scores = np.fromiter(
        (_sp.GetScore(piece_id) for piece_id in range(vocab_size)),
        dtype=float,
        count=vocab_size,
    )
    lowest = all_scores.min()
    highest = all_scores.max()
    return float(lowest), float(highest)


global_min, global_max = compute_vocab_min_max(sp)
# --- User input area ---
text = st.text_area("Enter text to tokenize:", "")


def _scale_to_unit_interval(scores):
    """Min-max scale raw piece scores into [0, 1] using the vocab extremes.

    Reads the module-level ``global_min`` / ``global_max``. When they are
    equal the range is degenerate, so an all-zero array shaped like
    ``scores`` is returned instead of dividing by zero.
    """
    if global_max == global_min:
        return np.zeros_like(scores)
    return (scores - global_min) / (global_max - global_min)


if st.button("Tokenize"):
    if text.strip():
        # The same text is encoded twice: surface pieces for display and
        # integer ids for looking up each piece's score.
        pieces = sp.EncodeAsPieces(text)
        ids = sp.EncodeAsIds(text)

        scores = np.array([sp.GetScore(token_id) for token_id in ids], dtype=float)

        # One row per token; the 0-1 values feed the progress-bar column.
        token_table = pd.DataFrame(
            {"Token": pieces, "Global Normalized": _scale_to_unit_interval(scores)}
        )

        st.dataframe(
            token_table,
            use_container_width=True,
            column_config={
                "Global Normalized": st.column_config.ProgressColumn(
                    "Score (percent)",
                    help="Raw log-prob min–max normalized over full vocab, shown as %",
                    format="percent",
                    min_value=0.0,
                    max_value=1.0,
                ),
            },
        )
    else:
        st.warning("Enter some text first.")