context-probing / app.py
cifkao's picture
Add app files
8443315
raw
history blame
2.86 kB
from pathlib import Path
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
# Absolute path of the directory that contains this script; the custom
# component's built frontend is located relative to it.
root_dir = Path(__file__).resolve().parent

# Declare the "highlighted_text" custom component, served from the
# highlighted_text/build directory next to this script.
_highlighted_text_fn = components.declare_component(
    "highlighted_text",
    path=root_dir.joinpath("highlighted_text", "build"),
)
def highlighted_text(tokens, scores, key=None):
    """Render the ``highlighted_text`` custom component.

    ``tokens`` and ``scores`` are forwarded to the frontend as component
    arguments and ``key`` is the Streamlit widget key. The component is called
    with ``default=0``, so that value is returned until the frontend reports
    one of its own.
    """
    component_args = {"tokens": tokens, "scores": scores}
    return _highlighted_text_fn(**component_args, key=key, default=0)
def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0) -> BatchEncoding:
    """Slice every tokenized example into fixed-length sliding windows.

    For each example, a window starts at positions 0, ``stride``,
    ``2 * stride``, ... up to (but excluding) the example's last token. The
    start positions are derived from ``examples["input_ids"]`` and applied to
    every key alike. Each window takes up to ``window_len`` items and is then
    right-padded with ``pad_id`` for the ``"input_ids"`` key and with ``0`` for
    every other key (e.g. an attention mask).

    Returns:
        A ``BatchEncoding`` with the same keys. Each key holds the flat list of
        windows of all examples: example by example, then by start position.
    """
    def _windows(seqs, fill):
        # Collect the padded windows of one key across all examples.
        out = []
        for row, ids in enumerate(examples["input_ids"]):
            seq = seqs[row]
            for start in range(0, len(ids) - 1, stride):
                # A negative count yields an empty list, i.e. no padding.
                padding = [fill] * (start + window_len - len(seq))
                out.append(seq[start:start + window_len] + padding)
        return out

    return BatchEncoding({
        key: _windows(seqs, pad_id if key == "input_ids" else 0)
        for key, seqs in examples.items()
    })
# U+FFFD REPLACEMENT CHARACTER. It shows up in decoded text when a token
# sequence ends partway through a multi-byte UTF-8 character (assuming the
# tokenizer decodes with a "replace" error policy).
BAD_CHAR = chr(0xfffd)


def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    """Decode ``ids`` into one display string per token.

    Consecutive ids are decoded together until the decoded text no longer ends
    with ``BAD_CHAR``, i.e. until a character whose bytes were split across
    several tokens is complete. The decoded text of such a group is attributed
    to its last token and the earlier tokens of the group get ``""``. Trailing
    ids that never decode cleanly also get ``""``. The result therefore always
    has exactly ``len(ids)`` entries.

    Args:
        tokenizer: Object with a ``decode(list_of_ids) -> str`` method
            (e.g. a Hugging Face tokenizer).
        ids: Token ids to convert.
        strip_whitespace: If true, strip surrounding whitespace from each
            non-empty emitted string.

    Returns:
        List of strings, one per id.
    """
    pending = []
    result = []
    for idx in ids:
        pending.append(idx)
        decoded = tokenizer.decode(pending)
        # Only a *trailing* replacement character can still be completed by
        # the following tokens. A U+FFFD anywhere else is either literally
        # present in the text or comes from bytes that will never form a valid
        # character; waiting for more tokens in that case would blank out the
        # entire rest of the output.
        if decoded.endswith(BAD_CHAR):
            result.append("")
            continue
        result.append(decoded.strip() if strip_whitespace else decoded)
        pending.clear()
    return result
def target_probs_by_context(logits, input_ids):
    """Arrange per-window target-token probabilities by context start and target.

    Args:
        logits: Tensor of shape ``(n_windows, window_len, vocab)``, assumed to
            come from running the model on stride-1 windows, so that window
            ``j`` starts at token ``j`` and its position ``p`` predicts
            ``input_ids[j + p + 1]``.
        input_ids: Token ids of the whole text.

    Returns:
        Float32 tensor ``probs`` of shape ``(n_windows, len(input_ids) - 1)``.
        ``probs[r, c]`` is the probability of ``input_ids[c + 1]`` given the
        context ``input_ids[r : c + 1]``. Entries are NaN where ``c < r`` or
        where that context is longer than a window.
    """
    n_windows, window_len, _ = logits.shape
    n_targets = len(input_ids) - 1
    ids = torch.as_tensor(input_ids)
    # offsets[j, p] = j + p, the index of the last context token of (j, p).
    offsets = torch.arange(n_windows).unsqueeze(1) + torch.arange(window_len).unsqueeze(0)
    # Positions in the right padding of a window have no target token.
    valid = offsets < n_targets
    # Clamp so the lookup stays in range; invalid entries are masked below.
    target_ids = ids[(offsets + 1).clamp(max=n_targets)]
    logits = logits.float()
    # log p(target) = logit(target) - logsumexp(logits). This avoids building
    # a full softmax over the vocabulary.
    target_logits = logits.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    target_probs = (target_logits - logits.logsumexp(dim=-1)).exp()
    # Skew: window j, position p goes to row j, column j + p.
    probs = torch.full((n_windows, n_targets), float("nan"))
    rows = torch.arange(n_windows).unsqueeze(1).expand_as(offsets)
    probs[rows[valid], offsets[valid]] = target_probs[valid]
    return probs


def context_scores(probs):
    """Turn a ``target_probs_by_context`` matrix into per-target context scores.

    Returns a tensor of shape ``(probs.shape[1], probs.shape[0] - 1)``. Entry
    ``[c, r]`` is ``probs[r + 1, c] - probs[r, c]``, the change in the
    probability of token ``c + 1`` when context token ``r`` is dropped. NaNs
    are replaced by 0, and each row is divided by its maximum absolute value
    (plus 1e-9).
    """
    scores = probs.diff(dim=0).transpose(0, 1).nan_to_num()
    # With a 2-token text there is no column; max() over an empty dim raises.
    if scores.shape[1] > 0:
        scores = scores / (scores.abs().max(dim=1, keepdim=True).values + 1e-9)
    return scores


def main():
    """Streamlit app: score how each context token affects later-token probabilities."""
    model_name = st.selectbox("Model", ["distilgpt2"])
    # NOTE(review): the model and tokenizer are reloaded on every Streamlit
    # rerun; consider caching them (e.g. st.cache_resource where available).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    window_len = st.select_slider("Window size", options=[8, 16, 32, 64, 128, 256, 512, 1024], value=512)
    text = st.text_area("Input text", "The complex houses married and single soldiers and their families.")

    inputs = tokenizer([text])
    [input_ids] = inputs["input_ids"]
    # With fewer than two tokens get_windows_batched produces no windows, and
    # the model would be called on an empty batch.
    if len(input_ids) < 2:
        st.info("Please enter a text of at least two tokens.")
        st.stop()
    window_len = min(window_len, len(input_ids))
    tokens = ids_to_readable_tokens(tokenizer, input_ids)

    inputs_sliding = get_windows_batched(
        inputs,
        window_len=window_len,
        pad_id=tokenizer.eos_token_id,
    )
    with torch.inference_mode():
        logits = model(**inputs_sliding.convert_to_tensors("pt")).logits
        probs = target_probs_by_context(logits, input_ids)
        scores = context_scores(probs).to(torch.float16)

    st.markdown("---")
    highlighted_text(tokens=tokens, scores=scores.tolist())


if __name__ == "__main__":
    main()