|
from __future__ import annotations |
|
|
|
import glob |
|
import io |
|
import os |
|
import random |
|
import struct |
|
from contextlib import contextmanager |
|
from html import escape |
|
|
|
import msgpack |
|
import streamlit as st |
|
import torch |
|
import tqdm |
|
from huggingface_hub import HfFileSystem |
|
from transformers import AutoTokenizer |
|
|
|
# Render the app using the wide page layout.
st.set_page_config(layout="wide")

# Runtime configuration; each value can be overridden by an environment variable.
# Hub model id: used to locate the "<MODEL_NAME>-viewer-data" dataset and to
# load the tokenizer.
MODEL_NAME = os.getenv("MODEL_NAME", "MonetLLM/monet-vd-1.4B-100BT-hf")
# Half-width, in tokens, of the context shown around each activated token.
CONTEXT_WINDOW = int(os.getenv("CONTEXT_WINDOW", "12"))
# An expert needs more than this many stored examples to be offered by "Random".
CANDIDATE_THRESHOLD = int(os.getenv("CANDIDATE_THRESHOLD", "50"))

# Stylesheet injected by `st_horizontal`: it hides its own container and turns
# the vertical block that holds the `.horizontal-marker` span into a wrapping
# flex row.
HORIZONTAL_STYLE = """<style class="hide-element">
    /* Hides the style container and removes the extra spacing */
    .element-container:has(.hide-element) {
        display: none;
    }
    /*
        The selector for >.element-container is necessary to avoid selecting the whole
        body of the streamlit app, which is also a stVerticalBlock.
    */
    div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) {
        display: flex;
        flex-direction: row !important;
        flex-wrap: wrap;
        gap: 0.5rem;
        align-items: baseline;
    }
    /* Buttons and their parent container all have a width of 704px, which we need to override */
    div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div {
        width: max-content !important;
    }
    /* Just an example of how you would style buttons, if desired */
    /*
    div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) button {
        border-color: red;
    }
    */
</style>"""
|
|
|
|
|
@st.cache_resource
def prepare_routing_resources():
    """Fetch the viewer data files and build the lookup structures for the UI.

    Returns:
        A 5-tuple ``(input_tokens, examples_tables, routing_tables,
        candidates, tokenizer)`` where

        * ``input_tokens`` is whatever ``torch.load("inputs.pt")`` yields,
        * ``examples_tables[g]`` / ``routing_tables[g]`` are the tables stored
          at the end of ``examples-{g}.msgpack`` / ``routings-{g}.msgpack``
          (their entries are later used as ``seek`` offsets into those files),
        * ``candidates[g]`` lists the record indices of group ``g`` whose
          example list is longer than ``CANDIDATE_THRESHOLD``,
        * ``tokenizer`` is the ``AutoTokenizer`` for ``MODEL_NAME``.
    """

    def read_trailing_table(path):
        # The final 4 bytes are a big-endian uint32 giving the size of the
        # msgpack-encoded table located immediately before them.
        with open(path, "rb") as handle:
            handle.seek(-4, io.SEEK_END)
            (size,) = struct.unpack(">I", handle.read(4))
            handle.seek(-(size + 4), io.SEEK_END)
            return msgpack.Unpacker(handle).unpack()

    # Download every dataset file that is not yet in the working directory.
    hub = HfFileSystem()
    for remote_path in hub.glob(f"datasets/{MODEL_NAME}-viewer-data/*"):
        if os.path.exists(os.path.basename(remote_path)):
            continue
        print(f"[*] Download {remote_path}...")
        hub.download(remote_path, ".")

    # NOTE(review): torch.load deserializes a file fetched from the Hub;
    # confirm the dataset is trusted (or that safe loading is in effect).
    input_tokens = torch.load("inputs.pt")

    # Files are expected to be numbered 0..N-1, N being the number of matches.
    group_count = len(glob.glob("examples-*.msgpack"))
    examples_tables = [
        read_trailing_table(f"examples-{group}.msgpack")
        for group in tqdm.trange(group_count)
    ]

    # Read the records sequentially from the start of each examples file,
    # one per table entry, and keep the indices of the large ones.
    # (Presumes records are stored in table order - confirm with the writer.)
    candidates = []
    for group, table in enumerate(tqdm.tqdm(examples_tables)):
        with open(f"examples-{group}.msgpack", "rb") as handle:
            stream = msgpack.Unpacker(handle)
            candidates.append(
                [
                    record
                    for record in range(len(table))
                    if len(stream.unpack()) > CANDIDATE_THRESHOLD
                ]
            )

    routing_tables = [
        read_trailing_table(f"routings-{group}.msgpack")
        for group in tqdm.trange(len(examples_tables))
    ]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return input_tokens, examples_tables, routing_tables, candidates, tokenizer
|
|
|
|
|
# Build the shared resources (the function is wrapped in st.cache_resource)
# and expose them as module-level names for the rendering code below.
_resources = prepare_routing_resources()
input_tokens, examples_tables, routing_tables, candidates, tokenizer = _resources
|
|
|
|
|
def render_routing_examples_in_html(router_index: int, expert_id: int) -> str:
    """Render the stored activation examples of one expert as an HTML snippet.

    Args:
        router_index: Index of the routing group; selects
            ``examples-{router_index}.msgpack`` and
            ``routings-{router_index}.msgpack`` plus their offset tables.
        expert_id: Index of the expert inside that group.

    Returns:
        An HTML ``<div>`` containing a table with one row per example: the
        activated token with its routing score, the surrounding context with
        per-token highlighting, and the ``(i, j)`` pair read from the example
        record.
    """
    # The examples table gives the byte offset of this expert's record.
    with open(f"examples-{router_index}.msgpack", "rb") as fp:
        fp.seek(examples_tables[router_index][expert_id])
        examples = msgpack.Unpacker(fp).unpack()

    table = []
    with open(f"routings-{router_index}.msgpack", "rb") as fp:
        for i, j, _ in examples:
            # Token window [start, end) around position j, clipped to the row.
            start = max(j - CONTEXT_WINDOW, 0)
            end = min(j + CONTEXT_WINDOW, len(routing_tables[router_index][i]))

            # Jump to the entry for `start` and read one map per position;
            # experts missing from a map count as a score of 0.
            fp.seek(routing_tables[router_index][i][start])
            unpacker = msgpack.Unpacker(fp, strict_map_key=False)
            activated = [unpacker.unpack().get(expert_id, 0) for _ in range(start, end)]

            # Decode the row and tokenize the text again to obtain character
            # spans; `offset` shifts original positions onto the new encoding
            # by the difference in length.
            full_text = tokenizer.decode(input_tokens[i])
            encodings = tokenizer(full_text, add_special_tokens=False)
            offset = len(encodings.input_ids) - input_tokens.size(1)

            spans, lslice = [], None
            for k in range(start, end):
                if offset + k >= 0 and (sslice := encodings.token_to_chars(offset + k)):
                    span, score = full_text[slice(*sslice)], activated[k - start]
                    if lslice == sslice:
                        # Same character span as the previous token: replace
                        # the previous entry and keep the higher score.
                        score = max(spans.pop(-1)[1], score)
                    spans.append((escape(span), score))
                    lslice = sslice

            # The score doubles as the alpha channel of the highlight colour.
            # The rgba(...) call is now closed; the original omitted the ")".
            spans = [
                f"<span style='background-color: rgba(144, 238, 144, {score})' title='Routing: {score*100:.2f}%'>{span}</span>"
                for span, score in spans
            ]
            table.append(
                f"""
                <tr>
                    <td align='right'>
                        <span style='font-weight: bold'>
                            {escape(tokenizer.decode(input_tokens[i, j]))} ({activated[j - start] * 100:.2f}%)
                        </span>
                    </td>
                    <td align='left'>
                        (...) {"".join(spans)} (...)
                    </td>
                    <td align='right'>
                        ({i}, {j})
                    </td>
                </tr>
                """
            )

    return f"""
    <div style='background-color: white; color: black; padding: 1em 3em; font-size: 12pt'>
        <h2 style='font-size: 18pt'> Activated Examples of Group {router_index} / Expert {expert_id} </h2>
        <table>
            {"".join(table)}
        </table>
    </div>
    """
|
|
|
|
|
@contextmanager
def st_horizontal():
    """Context manager that places the widgets created inside it in a row.

    It injects ``HORIZONTAL_STYLE`` and opens a container whose first element
    is a hidden marker ``<span>``; the stylesheet finds the container through
    that marker and switches it to a wrapping flex row.
    """
    marker_html = '<span class="hide-element horizontal-marker"></span>'

    st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True)
    row = st.container()
    with row:
        st.markdown(marker_html, unsafe_allow_html=True)
        # Widgets created by the caller's `with` body land in `row`.
        yield
|
|
|
|
|
# Initial widget values chosen for the default model. They are clamped below
# so that a smaller model selected through MODEL_NAME does not start the
# widgets outside their valid range.
DEFAULT_GROUP_INDEX = 4
DEFAULT_EXPERT_ID = 54136

col1, col2 = st.columns(2)
with col1:
    router_groups = [f"Routing Group {i}" for i in range(len(examples_tables))]
    router_index = st.selectbox(
        "Expert Routing Group",
        router_groups,
        index=min(DEFAULT_GROUP_INDEX, len(router_groups) - 1),
    )
with col2:
    # The number_input bounds are inclusive, so the largest valid expert
    # index is len - 1 (the original allowed len, one past the end).
    # Presumes every group has as many experts as group 0 - TODO confirm.
    max_expert_id = len(examples_tables[0]) - 1
    expert_id = st.number_input(
        "Expert Index", 0, max_expert_id, min(DEFAULT_EXPERT_ID, max_expert_id)
    )

with st_horizontal():
    show_btn = st.button("Show")
    random_btn = st.button("Random")

if show_btn or random_btn:
    # Map the selected label back to its numeric group index.
    router_index = router_groups.index(router_index)
    if random_btn:
        pool = candidates[router_index]
        # random.choice raises IndexError on an empty sequence, so guard it.
        expert_id = random.choice(pool) if pool else None
    if expert_id is None:
        st.warning(
            f"No expert in this routing group has more than {CANDIDATE_THRESHOLD} "
            "examples to sample from."
        )
    else:
        st.html(render_routing_examples_in_html(router_index, expert_id))
|
|