|
import streamlit as st |
|
from datasets import load_dataset |
|
import os |
|
|
|
# Hugging Face access token read from the environment; None when the
# variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Page-level configuration: browser tab title and a full-width layout.
st.set_page_config(layout="wide", page_title="Logprobs inspection")

# Heading rendered at the top of the page.
st.markdown("# Logprobs inspection")
|
|
|
@st.cache_data
def load_data():
    """Return the "train" split of the HuggingFaceTB/sample_log_probs dataset.

    The dataset is fetched with ``load_dataset`` using the module-level
    ``HF_TOKEN`` as the access token. The function is wrapped in
    ``st.cache_data``, so the loaded result is cached by Streamlit.
    """
    return load_dataset(
        "HuggingFaceTB/sample_log_probs", split="train", token=HF_TOKEN
    )
|
|
|
ds = load_data()

# Read the logprob column once; its extremes bound both filter sliders.
logprobs = ds["logprobs"]
min_log = min(logprobs)
max_log = max(logprobs)

# Two side-by-side sliders select the lower and upper logprob bounds.
col_1, col_2 = st.columns(2)
with col_1:
    min_score = st.slider(
        "Select minimum logprob",
        min_value=min_log,
        max_value=max_log,
        value=min_log,
        step=0.2,
        key="min_score",
    )
with col_2:
    max_score = st.slider(
        "Select maximum logprob",
        min_value=min_log,
        max_value=max_log,
        value=max_log,
        step=0.2,
        key="max_score",
    )

# Keep only the rows whose logprob lies inside the selected (inclusive) range.
filtered_ds = ds.filter(lambda x: min_score <= x["logprobs"] <= max_score)

num_samples = len(filtered_ds)
if num_samples == 0:
    # Nothing to index into (e.g. the minimum was set above the maximum):
    # report it and end this script run instead of failing on filtered_ds[0].
    st.warning("No samples match the selected logprob range.")
    st.stop()

if num_samples > 1:
    # Valid indices are 0 .. num_samples - 1; using num_samples as the upper
    # bound would let the user pick an index one past the last row.
    index = st.slider("Select a sample", 0, num_samples - 1, 0)
else:
    # A slider over a single value offers no choice; show the only sample.
    index = 0

# Fetch the selected row once and reuse it for every field shown below.
sample = filtered_ds[index]

with st.expander("The prompt"):
    st.markdown(sample["prompt"])

st.markdown(
    f"**Metadata:** log_prob is {sample['logprobs']:.2f}, "
    f"seed: {sample['seed_data']}, "
    f"{sample['format']} for {sample['audience']}."
)
st.markdown(sample["text"])
|
|