| | |
| |
|
| | import random |
| | import time |
| | import urllib.request |
| |
|
| | import gradio as gr |
| | import spaces |
| | import torch |
| | import torch.nn.functional as F |
| | import triton |
| | import triton.language as tl |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
| |
|
# Hugging Face Hub repository holding the HARE encoder.
MODEL_ID = "SixOpen/HARE"

# trust_remote_code=True runs the repository's own modeling code; the span
# encoder below relies on its BiRWKV7Layer modules. start_demo moves the
# model to CUDA before use.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
| |
|
| |
|
# Forward WKV7 recurrence over T steps for one (batch, head) pair per program.
#
# Launched with a 1-D grid of B * H programs; program `pid` handles batch
# pid // H and head pid % H. R, K, V, DECAY, A and O share one layout:
# element (b, t, h, d) lives at b*stride_b + t*stride_t + h*stride_h + d.
# The state is a D x D matrix S (row i indexes the value dim, column j the
# key dim). STATE_IN / STATE_OUT are read/written as contiguous (B, H, D, D)
# buffers. BLOCK_D is the padded tile width; lanes >= D are masked off.
#
# Per step t:
#   sa  = -S @ k_t
#   S   = S * decay_t[None, :] + sab_scale * outer(sa, k_t * a_t) + outer(v_t, k_t)
#   S   = clip(S, -10, 10)
#   o_t = S @ r_t
# NOTE: the +/-10 clamp appears to be a stability guard; presumably it matches
# the model's reference implementation - confirm before changing.
@triton.jit
def _wkv7_fwd_kernel(
    R, K, V, DECAY, A, O,
    STATE_OUT, STATE_IN,
    sab_scale, T,
    stride_b, stride_t, stride_h,
    H: tl.constexpr, D: tl.constexpr, BLOCK_D: tl.constexpr,
    RETURN_STATE: tl.constexpr, HAS_INIT_STATE: tl.constexpr,
):
    # Decode which (batch, head) this program owns.
    pid = tl.program_id(0)
    b_idx = pid // H
    h_idx = pid % H
    base = b_idx * stride_b + h_idx * stride_h

    # Row (i) and column (j) lane indices, masked down to the real head size.
    di = tl.arange(0, BLOCK_D)
    dj = tl.arange(0, BLOCK_D)
    mask_i = di < D
    mask_j = dj < D

    # Start from the carried state when one is given, otherwise from zeros.
    if HAS_INIT_STATE:
        s_off = b_idx * (H * D * D) + h_idx * (D * D)
        state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
        state_mask = mask_i[:, None] & mask_j[None, :]
        state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
    else:
        state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)

    for t in range(T):
        t_off = base + t * stride_t
        kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
        vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
        rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
        # Padded decay lanes load 1.0 (identity) rather than 0.0.
        dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
        at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)

        # sa = -S k; the rank-1 term sa (k*a)^T is scaled by sab_scale.
        sa = tl.sum(state * (-kt)[None, :], axis=1)
        ka = kt * at
        sab = sa[:, None] * ka[None, :]
        # Per-column decay, scaled rank-1 correction, and v k^T write.
        state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
        state = tl.minimum(tl.maximum(state, -10.0), 10.0)

        # Read out o_t = S r_t and store it for this timestep.
        out_t = tl.sum(state * rt[None, :], axis=1)
        tl.store(O + t_off + di, out_t, mask=mask_i)

    # Optionally write the final state back for carrying into the next call.
    if RETURN_STATE:
        s_off = b_idx * (H * D * D) + h_idx * (D * D)
        state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
        state_mask = mask_i[:, None] & mask_j[None, :]
        tl.store(state_ptrs, state, mask=state_mask)
| |
|
| |
|
def wkv7_scan_triton(r, decay, k, v, a, sab_scale, return_state=False, init_state=None):
    """Run the WKV7 forward recurrence over the time axis with Triton.

    Args:
        r, decay, k, v, a: tensors of identical shape (B, T, H, D); they are
            made contiguous because the kernel computes offsets from that
            layout.
        sab_scale: scalar (Python number or one-element tensor) scaling the
            rank-1 state correction.
        return_state: when True, also return the final (B, H, D, D) float32
            state.
        init_state: optional (B, H, D, D) starting state; converted to
            contiguous float32.

    Returns:
        ``o`` (same shape/dtype as ``r``), or ``(o, state)`` when
        ``return_state`` is True.
    """
    batch, seq_len, n_heads, head_dim = r.shape
    r, k, v, decay, a = (t.contiguous() for t in (r, k, v, decay, a))
    out = torch.empty_like(r)

    if return_state:
        state_out = torch.empty(batch, n_heads, head_dim, head_dim,
                                dtype=torch.float32, device=r.device)
    else:
        state_out = None

    has_init = init_state is not None
    if has_init:
        init_state = init_state.contiguous().float()

    # One program per (batch, head); strides describe a contiguous (B, T, H, D) tensor.
    _wkv7_fwd_kernel[(batch * n_heads,)](
        r, k, v, decay, a, out,
        state_out, init_state,
        float(sab_scale), seq_len,
        seq_len * n_heads * head_dim, n_heads * head_dim, head_dim,
        H=n_heads, D=head_dim, BLOCK_D=triton.next_power_of_2(head_dim),
        RETURN_STATE=return_state,
        HAS_INIT_STATE=has_init,
    )
    return (out, state_out) if return_state else out
| |
|
| |
|
def find_birwkv_layers(model):
    """Collect every ``BiRWKV7Layer`` submodule of ``model``.

    Matching is by class name, so the remote-code class never has to be
    imported here.

    Returns:
        ``(layers, ids)`` where ``layers`` lists the matching modules in
        ``model.modules()`` order and ``ids`` maps ``id(layer)`` to its
        position in ``layers``.
    """
    layers = [m for m in model.modules() if type(m).__name__ == 'BiRWKV7Layer']
    return layers, {id(m): pos for pos, m in enumerate(layers)}
| |
|
| |
|
class SpanEncoder:
    """Encode growing text spans with HARE while carrying BiRWKV7 state.

    While encoding, every ``BiRWKV7Layer`` of ``model`` is temporarily
    re-routed (see ``_hook``) through a forward pass that starts its forward
    WKV scan from a carried per-layer state and records the final state and
    last input row. A span can therefore be extended by encoding only the new
    text instead of re-encoding the whole span.

    Encoded spans live in ``span_data`` keyed by a ``(start, end)`` range
    (the streaming demo uses piece indices). Each entry holds:

    * ``'layer_states'``: per-layer carried-state dicts (or None),
    * ``'chunk_embs'``: list of L2-normalised ``(1, hidden)`` chunk means,
    * ``'n_tokens'``: number of content tokens encoded for the span,
    * ``'residual_hidden'``: trailing hidden rows too short to form a chunk
      yet (or None).
    """

    def __init__(self, model, tokenizer, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        # Tokens per embedded chunk.
        self.chunk_size = chunk_size

        self.birwkv_layers, self.birwkv_ids = find_birwkv_layers(model)
        # id(layer) -> original bound ``forward`` while hooked.
        self._originals = {}
        self._hooked = False
        # Per-layer carried state for the encode in progress (None = zeros).
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        """Replace each BiRWKV7 layer's ``forward`` with a state-carrying one."""
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        """Restore the original ``forward`` of every hooked layer."""
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        """Build the replacement ``forward`` for ``layer``.

        The returned function reads ``self._active_states[idx]`` (if any) to
        seed the token shift and the forward scan, runs a backward scan over
        the current input only, averages the two directions, and writes the
        new carried state back into ``self._active_states[idx]``.
        """
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            # Token shift: the first position sees the previous call's last
            # input row when continuing, zeros otherwise.
            if prev is not None:
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev else None

            r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
            a_f = torch.sigmoid(a.float())
            decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
            # Forward direction continues from the carried state.
            out_fwd, wkv_state = wkv7_scan_triton(
                r_f, decay, k_f, v_f, a_f, sab_scale,
                return_state=True, init_state=init_st)
            # Backward direction covers only the current input.
            out_bwd = wkv7_scan_triton(
                r_f.flip(1), decay.flip(1), k_f.flip(1),
                v_f.flip(1), a_f.flip(1), sab_scale,
                return_state=False).flip(1)

            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            # NOTE(review): presumably mirrors the original layer's
            # (output, state) return convention - confirm.
            return out, None
        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        """Run the hooked model over ``text``.

        Args:
            text: text to encode (truncated to ``max_length`` tokens).
            init_states: optional per-layer states to continue from; they are
                cloned so the caller's copies are never mutated.
            max_length: tokenizer truncation limit.

        Returns:
            ``(content, n_content, final_states)``: the hidden states without
            the first and last token positions (on CPU), their count, and
            cloned per-layer states after the pass.

        The layers are always unhooked, even if tokenization or the model
        raises, so later ``encode_query`` calls see the unmodified model.
        """
        self._hook()
        try:
            if init_states is not None:
                self._active_states = [
                    {k: v.clone() for k, v in s.items()} if s else None
                    for s in init_states
                ]
            else:
                self._active_states = [None] * len(self.birwkv_layers)

            enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                                 max_length=max_length)
            ids = enc['input_ids'].to(self.device)
            mask = enc['attention_mask'].to(self.device)

            h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
            # Drop the first and last positions (tokenizer special tokens).
            content = h[0, 1:-1, :].cpu()
            n_content = content.shape[0]

            final_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in self._active_states
            ]
        finally:
            self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        """Mean-pool ``content`` rows into L2-normalised chunk embeddings.

        A trailing piece shorter than 32 rows is not embedded. It is returned
        as the residual (when ``return_residual``) so a later extension can
        complete it. If no full chunk fits but ``content`` is non-empty, the
        whole content becomes a single chunk and there is no residual.
        """
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        """Return the L2-normalised, mask-weighted mean embedding of ``query`` (CPU)."""
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        """Encode ``text`` from scratch, store it under ``key``, and return its token count."""
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def _continue_span(self, text, prev):
        """Encode ``text`` continuing from the stored span entry ``prev``.

        ``prev``'s residual rows are prepended before chunking. ``prev`` is
        not modified.

        Returns:
            ``(entry, n_new)``: the new span entry and the number of newly
            encoded tokens.
        """
        content, n_new, states = self._forward_encode_raw(
            text, init_states=prev['layer_states'])
        if prev.get('residual_hidden') is not None:
            content = torch.cat([prev['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(content, return_residual=True)
        entry = {
            'layer_states': states,
            'chunk_embs': prev['chunk_embs'] + new_chunks,
            'n_tokens': prev['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return entry, n_new

    def extend_right(self, piece_text, old_key, new_key):
        """Append ``piece_text`` to span ``old_key``, re-keying it as ``new_key``.

        Returns the number of newly encoded tokens. ``old_key`` is only
        removed after encoding succeeds.
        """
        old = self.span_data[old_key]
        entry, n_new = self._continue_span(piece_text, old)
        del self.span_data[old_key]
        self.span_data[new_key] = entry
        return n_new

    def smart_merge(self, new_text, left_key, new_key):
        """Extend span ``left_key`` with ``new_text`` to cover ``new_key``.

        Spans contained in ``new_key`` (e.g. a right neighbour that is now
        re-encoded as part of ``new_text``) are dropped. Returns the number of
        newly encoded tokens. Existing data is only removed after encoding
        succeeds.
        """
        left = self.span_data[left_key]
        entry, n_new = self._continue_span(new_text, left)
        del self.span_data[left_key]
        self.remove_old(new_key)
        self.span_data[new_key] = entry
        return n_new

    def remove_old(self, new_key):
        """Delete every stored span whose range lies inside ``new_key``."""
        s, e = new_key
        for old in list(self.span_data.keys()):
            if old[0] >= s and old[1] <= e:
                del self.span_data[old]

    def search(self, q_emb, spans, top_k=5):
        """Score each stored span in ``spans`` by its best chunk similarity.

        Args:
            q_emb: ``(1, hidden)`` normalised query embedding.
            spans: iterable of ``(start, end, text)``; spans without stored
                chunks are skipped.
            top_k: number of results to keep.

        Returns:
            List of ``(start, end, score, preview, n_tokens, n_chunks)``
            sorted by descending score. The preview is located by assuming
            chunks cover equal character lengths of ``text`` (approximate).
        """
        results = []
        for s, e, text in spans:
            key = (s, e)
            data = self.span_data.get(key)
            if not data or not data['chunk_embs']:
                continue
            chunk_mat = torch.cat(data['chunk_embs'], dim=0)
            sims = (q_emb @ chunk_mat.T).squeeze(0)
            if sims.dim() == 0:
                sims = sims.unsqueeze(0)
            max_sim = sims.max().item()
            best_idx = sims.argmax().item()
            n_chunks = len(data['chunk_embs'])
            chars_per_chunk = len(text) // max(n_chunks, 1)
            offset = min(best_idx * chars_per_chunk, len(text) - 1)
            # Back up to the start of the word so the preview is readable.
            while offset > 0 and text[offset - 1] not in ' \n\t':
                offset -= 1
            preview = text[offset:offset + 300].replace('\n', ' ').strip()
            results.append((s, e, max_sim, preview, data['n_tokens'], n_chunks))
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]
| |
|
| |
|
class TextProvider:
    """Simulate a text arriving as fixed-size pieces in a shuffled order.

    The text is cut into ``ceil(len(text) / piece_size)`` character pieces.
    ``poll_pieces`` releases them one at a time in an order fixed by ``seed``.
    """

    def __init__(self, text, piece_size=4096, seed=42):
        """
        Args:
            text: full text to stream.
            piece_size: characters per piece; must be positive.
            seed: seed for the deterministic arrival order.

        Raises:
            ValueError: if ``piece_size`` is not positive.
        """
        if piece_size <= 0:
            raise ValueError(f"piece_size must be positive, got {piece_size}")
        self.text = text
        self.piece_size = piece_size
        # Ceiling division: the last piece may be shorter than piece_size.
        self.n_pieces = (len(text) + piece_size - 1) // piece_size
        self.received = [False] * self.n_pieces
        rng = random.Random(seed)
        self.arrival_order = list(range(self.n_pieces))
        rng.shuffle(self.arrival_order)
        self.next_idx = 0

    def poll_pieces(self):
        """Release the next piece; return ``[index]``, or ``[]`` when exhausted."""
        if self.next_idx >= self.n_pieces:
            return []
        idx = self.arrival_order[self.next_idx]
        self.received[idx] = True
        self.next_idx += 1
        return [idx]

    def get_spans(self):
        """Return maximal runs of received pieces as ``(start, end, text)``.

        ``start``/``end`` are piece indices (end exclusive).
        """
        spans = []
        i = 0
        while i < self.n_pieces:
            if self.received[i]:
                j = i
                while j < self.n_pieces and self.received[j]:
                    j += 1
                spans.append((i, j, self.span_text(i, j)))
                i = j
            else:
                i += 1
        return spans

    def piece_text(self, idx):
        """Return the text of piece ``idx``."""
        s = idx * self.piece_size
        return self.text[s:min(s + self.piece_size, len(self.text))]

    def span_text(self, start_piece, end_piece):
        """Return the text covered by pieces ``start_piece`` .. ``end_piece - 1``."""
        s = start_piece * self.piece_size
        e = min(end_piece * self.piece_size, len(self.text))
        return self.text[s:e]

    def progress(self):
        """Return the fraction of pieces polled so far (1.0 for an empty text)."""
        if self.n_pieces == 0:
            return 1.0
        return self.next_idx / self.n_pieces

    def is_complete(self):
        """Return True once every piece has been polled."""
        return self.next_idx >= self.n_pieces
| |
|
| |
|
# Built-in sample text for the "Quick Demo" mode: passages from Mary
# Shelley's *Frankenstein* (Project Gutenberg). Each paragraph is written on
# several source lines joined with backslash continuations, so every
# paragraph is a single line in the resulting string and paragraphs are
# separated by blank lines.
FRANKENSTEIN_EXCERPT = """\
I am by birth a Genevese; and my family is one of the most distinguished \
of that republic. My ancestors had been for many years counsellors and \
syndics; and my father had filled several public situations with honour \
and reputation.

When I was thirteen years of age, we all went on a party of pleasure to \
the baths near Thonon: the inclemency of the weather obliged us to remain \
a day confined to the inn. In this house I found a volume of the works of \
Cornelius Agrippa. I opened it with apathy; the theory which he attempts \
to demonstrate, and the wonderful facts which he relates, soon changed \
this feeling into enthusiasm. A new light seemed to dawn upon my mind.

When I returned home, my first care was to procure the whole works of \
this author. My father was not scientific, and I was left to struggle \
with a child's blindness, added to a student's thirst for knowledge. \
Under the guidance of my new preceptors, I entered with the greatest \
diligence into the search of the philosopher's stone and the elixir \
of life. What glory would attend the discovery, if I could banish \
disease from the human frame, and render man invulnerable to any but \
a violent death!

It was on a dreary night of November that I beheld the accomplishment \
of my toils. With an anxiety that almost amounted to agony, I collected \
the instruments of life around me, that I might infuse a spark of being \
into the lifeless thing that lay at my feet. It was already one in the \
morning; the rain pattered dismally against the panes, and my candle was \
nearly burnt out, when, by the glimmer of the half-extinguished light, \
I saw the dull yellow eye of the creature open; it breathed hard, and \
a convulsive motion agitated its limbs.

How can I describe my emotions at this catastrophe, or how delineate the \
wretch whom with such infinite pains and care I had endeavoured to form? \
I had selected his features as beautiful. Beautiful!--Great God! His \
yellow skin scarcely covered the work of muscles and arteries beneath; \
his hair was of a lustrous black, and flowing; his teeth of a pearly \
whiteness; but these luxuriances only formed a more horrid contrast with \
his watery eyes, that seemed almost of the same colour as the dun white \
sockets in which they were set, his shrivelled complexion, and straight \
black lips.

I had worked hard for nearly two years, for the sole purpose of infusing \
life into an inanimate body. For this I had deprived myself of rest and \
health. I had desired it with an ardour that far exceeded moderation; but \
now that I had finished, the beauty of the dream vanished, and breathless \
horror and disgust filled my heart.

I did not dare return to the apartment which I inhabited, but felt \
impelled to hurry on, although drenched by the rain which poured from a \
black and comfortless sky. I passed the night wretchedly. Morning, \
dismal and wet, at length dawned, and discovered to my sleepless and \
aching eyes the church of Ingolstadt, its white steeple and clock, \
which indicated the sixth hour.

"I shall satiate my ardour for destruction," the creature said, "and \
make you so wretched that the light of day will be hateful to you. I \
will be with you on your wedding-night." I started forward, and \
exclaimed, "Villain! before you sign my death-warrant, be sure that \
you are yourself safe." My rage was without bounds; I would have seized \
him; but he eluded me, and quitted the house with precipitation.

Great God! why did I not then expire! But I am a wretch, and none ever \
conceived of the horrors of my secret toil, whilst I dabbled among the \
unhallowed damps of the grave, or tortured the living animal to animate \
the lifeless clay.

I was soon borne away by the waves, and lost in darkness and distance. \
Immense and rugged mountains of ice often barred up my passage, and I \
heard the thunder of the ground sea beneath. The cold is excessive, and \
many of my unfortunate comrades have already found a grave amidst this \
scene of desolation. Frankenstein! he is not here: I will not rest; I \
pursue him still over the untrodden snow and frozen ocean.
"""
| |
|
# Built-in demos for "Quick Demo" mode, keyed by dropdown label. Each entry
# gives the text to stream, the queries to run, the piece size in characters
# (passed to TextProvider) and the pause in seconds between pieces.
QUICK_DEMOS = {
    "Frankenstein (excerpt)": dict(
        text=FRANKENSTEIN_EXCERPT,
        queries=[
            "the creature opens its eyes for the first time",
            "playing god with science",
            "a threat on the wedding night",
            "a frozen arctic wasteland",
        ],
        piece_size=512,
        sleep=0.3,
    ),
}
| |
|
| |
|
def render_grid(received, n_pieces, highlight=None, max_width=60):
    """Render piece-arrival progress as an HTML strip of coloured cells.

    Args:
        received: sequence of booleans; ``received[i]`` is true once piece
            ``i`` has arrived.
        n_pieces: total number of pieces.
        highlight: index of the piece that just arrived (drawn bright green),
            or None.
        max_width: maximum number of cells (must be >= 1). With more pieces
            than this, each cell summarises a contiguous range of pieces and
            is coloured by the fraction received.

    Returns:
        HTML string with the cell strip and a
        "Piece received/total (pct%)" caption.
    """
    def _cell(bg):
        return (f'<span style="display:inline-block;width:14px;height:22px;'
                f'background:{bg};margin:1px;border-radius:2px"></span>')

    cells = []
    if n_pieces <= max_width:
        for i in range(n_pieces):
            if i == highlight:
                bg = '#00ff41'
            elif received[i]:
                bg = '#28a745'
            else:
                bg = '#3a3a3a'
            cells.append(_cell(bg))
    else:
        for col in range(max_width):
            # Pieces [s, e) are summarised by this column.
            s = col * n_pieces // max_width
            e = (col + 1) * n_pieces // max_width
            ratio = sum(received[s:e]) / max(1, e - s)
            if highlight is not None and s <= highlight < e:
                bg = '#00ff41'
            elif ratio > 0.8:
                bg = '#28a745'
            elif ratio > 0.3:
                bg = '#17a2b8'
            elif ratio > 0:
                bg = '#6c757d'
            else:
                bg = '#3a3a3a'
            cells.append(_cell(bg))

    n_recv = sum(received)
    pct = n_recv / max(n_pieces, 1) * 100
    grid = ''.join(cells)
    return (
        f'<div style="font-family:monospace;line-height:1.4;padding:8px 0">'
        f'<div style="display:flex;flex-wrap:wrap;gap:0">{grid}</div>'
        f'<div style="margin-top:8px;color:#aaa">'
        f'Piece {n_recv}/{n_pieces} ({pct:.0f}%)</div></div>'
    )
| |
|
| |
|
def render_search(results_dict, peak_scores=None):
    """Render per-query search results as HTML.

    Args:
        results_dict: mapping query -> list of result dicts with keys
            ``'score'``, ``'preview'``, ``'span'`` and ``'n_chunks'``,
            best first.
        peak_scores: optional mapping query -> ``{'score', 'preview'}`` for
            the best hit seen so far. It is shown above the current results
            when it beats the current best by more than 0.01.

    Returns:
        HTML string. Queries and previews are HTML-escaped because they come
        from user input and downloaded documents.
    """
    import html  # only this renderer needs escaping

    if not results_dict:
        return '<p style="color:#888">Waiting for data...</p>'

    def _score_color(score):
        if score > 0.5:
            return '#28a745'
        elif score > 0.4:
            return '#ffc107'
        return '#aaa'

    def _safe_preview(text):
        # Truncate first (as before), then escape so <, > and & render literally.
        return html.escape(text[:300])

    parts = []
    for query, results in results_dict.items():
        peak = peak_scores.get(query) if peak_scores else None
        header = f'"{html.escape(query)}"'
        if peak:
            header += (f' <span style="color:#888;font-size:0.85em">'
                       f'(peak: {peak["score"]:.3f})</span>')
        parts.append(
            f'<div style="margin-bottom:16px">'
            f'<div style="font-weight:bold;color:#58a6ff;margin-bottom:6px">'
            f'{header}</div>'
        )

        cur_best = results[0]['score'] if results else 0
        if peak and peak['score'] > cur_best + 0.01:
            psc = _score_color(peak['score'])
            pp = _safe_preview(peak['preview'])
            parts.append(
                f'<div style="padding:4px 0 4px 12px;border-left:3px solid {psc};'
                f'background:rgba(40,167,69,0.08);margin-bottom:2px">'
                f'<span style="color:{psc};font-weight:bold">{peak["score"]:.3f}</span> '
                f'<span style="color:#888;font-size:0.85em">peak</span><br>'
                f'<span style="color:#ccc;font-size:0.9em">{pp}...</span>'
                f'</div>'
            )

        if not results:
            parts.append('<div style="color:#888;padding-left:12px">No results yet</div>')
        else:
            for rank, r in enumerate(results[:3], 1):
                sc = _score_color(r['score'])
                preview = _safe_preview(r['preview'])
                parts.append(
                    f'<div style="padding:4px 0 4px 12px;border-left:3px solid {sc}">'
                    f'<span style="color:{sc};font-weight:bold">{r["score"]:.3f}</span> '
                    f'<span style="color:#888">[{r["span"][0]}-{r["span"][1]}]'
                    f' ({r["n_chunks"]}ch)</span><br>'
                    f'<span style="color:#ccc;font-size:0.9em">{preview}...</span>'
                    f'</div>'
                )
        parts.append('</div>')
    return ''.join(parts)
| |
|
| |
|
def _state_color(intensity):
    """Map an intensity (0 = low, 1 = high) to a CSS ``hsl(...)`` colour.

    Hue moves from 220 down to 50 while saturation rises 20->90% and
    lightness 12->50%. Each component is truncated to an int.
    """
    hue, sat, light = (
        int(base + intensity * slope)
        for base, slope in ((220, -170), (20, 70), (12, 38))
    )
    return f'hsl({hue},{sat}%,{light}%)'
| |
|
| |
|
def render_state_viz(state_history, n_layers=14):
    """Render per-layer recurrent-state magnitudes over time as an HTML heatmap.

    Args:
        state_history: one entry per processed step; each entry is a list of
            per-layer state norms.
        n_layers: number of layer rows to draw.

    Returns:
        HTML string: one row per layer (each normalised by that layer's own
        maximum), plus a caption with the latest average magnitude and the
        layer whose magnitude changed most in the last step.
    """
    if not state_history:
        return ('<p style="color:#888">Recurrent state evolution will appear '
                'as pieces are processed...</p>')

    n_steps = len(state_history)
    cell_w = max(4, min(14, 600 // max(n_steps, 1)))

    # Per-row maximum so every layer uses the full colour range.
    layer_maxes = []
    for li in range(n_layers):
        seen = [step[li] for step in state_history if li < len(step)]
        layer_maxes.append(max(seen) if seen else 1.0)

    rows = []
    for li, layer_max in enumerate(layer_maxes):
        denom = max(layer_max, 1e-6)
        cells = ''.join(
            f'<span style="display:inline-block;width:{cell_w}px;'
            f'height:12px;background:{_state_color(min(step[li] / denom, 1.0))};'
            f'margin:0 1px"></span>'
            for step in state_history if li < len(step))
        rows.append(
            f'<div style="display:flex;align-items:center;margin:0">'
            f'<span style="width:24px;color:#666;font-size:9px;'
            f'text-align:right;margin-right:3px;flex-shrink:0">R{li+1}</span>'
            f'<div style="display:flex">{cells}</div>'
            f'</div>')

    latest = state_history[-1]
    avg_norm = sum(latest) / len(latest) if latest else 0

    # Layer with the largest absolute change between the last two steps.
    most_active, max_delta = 0, 0
    if n_steps >= 2:
        for li, (now, before) in enumerate(zip(latest, state_history[-2])):
            delta = abs(now - before)
            if delta > max_delta:
                max_delta, most_active = delta, li

    legend = ''.join(
        f'<span style="display:inline-block;width:16px;height:8px;'
        f'background:{_state_color(i / 4)};margin:0 1px"></span>'
        for i in range(5))
    active_note = f" | Most active: R{most_active + 1}" if n_steps >= 2 else ""

    return (
        f'<div style="font-family:monospace;line-height:1.1">'
        f'{"".join(rows)}'
        f'<div style="color:#777;font-size:10px;margin-top:6px">'
        f'{n_layers} RWKV layers \u00d7 {n_steps} pieces | '
        f'Avg state magnitude: {avg_norm:.1f}'
        f'{active_note}'
        f'</div>'
        f'<div style="color:#666;font-size:9px;margin-top:2px">'
        f'{legend} low \u2192 high state magnitude'
        f'</div></div>')
| |
|
| |
|
def load_text(url):
    """Download a text over HTTP(S) and strip Project Gutenberg boilerplate.

    Everything up to and including the line containing ``*** START OF`` and
    everything from ``*** END OF`` onward are removed when those markers are
    present. Bytes are decoded as UTF-8 with replacement characters.

    Args:
        url: ``http://`` or ``https://`` URL (surrounding whitespace ignored).

    Returns:
        The decoded (and possibly trimmed) text.

    Raises:
        ValueError: if ``url`` is not an http(s) URL.
        OSError: ``urllib.error.URLError`` or a timeout from the download.
    """
    url = url.strip()
    # The URL comes from a public textbox; urlopen also honours file:// (and
    # ftp://), which would let visitors read files from the server.
    if not url.lower().startswith(('http://', 'https://')):
        raise ValueError(f'Only http(s) URLs are supported, got: {url!r}')
    with urllib.request.urlopen(url, timeout=30) as resp:
        text = resp.read().decode('utf-8', errors='replace')
    start = text.find('*** START OF')
    if start != -1:
        text = text[text.find('\n', start) + 1:]
    end = text.find('*** END OF')
    if end != -1:
        text = text[:end]
    return text
| |
|
| |
|
def _span_search_results(encoder, queries, q_embs, spans, peak_scores=None):
    """Search each query against ``spans`` and shape the results for render_search.

    When ``peak_scores`` is given, it is updated in place with the best top-1
    score (and its preview) seen so far for each query.
    """
    search_results = {}
    for q in queries:
        results = encoder.search(q_embs[q], spans, top_k=3)
        search_results[q] = [
            {'span': (s, e), 'score': sc, 'preview': pv,
             'n_chunks': nc, 'n_tokens': nt}
            for s, e, sc, pv, nt, nc in results
        ]
        if peak_scores is not None and results:
            sc_top, pv_top = results[0][2], results[0][3]
            if q not in peak_scores or sc_top > peak_scores[q]['score']:
                peak_scores[q] = {'score': sc_top, 'preview': pv_top}
    return search_results


def _status_markdown(stats, total_chunks, complete=False):
    """Format the efficiency, token and strategy status lines from ``stats``."""
    eff = stats['baseline'] / max(stats['hare'], 1)
    saved = stats['baseline'] - stats['hare']
    eff_md = f"**Efficiency: {eff:.1f}x** | {total_chunks} chunks"
    if complete:
        eff_md += " | COMPLETE"
    tok_md = (f"Tokens: {stats['hare']:,} processed | "
              f"{saved:,} saved via state carry")
    strat_md = (f"Right-ext: {stats['right']} | Smart-merge: {stats['smart']} | "
                f"Full: {stats['full']} | Merges: {stats['merges']}")
    return eff_md, tok_md, strat_md


def _largest_span_state_norms(encoder):
    """Return per-layer WKV state norms of the widest stored span, or None if none."""
    if not encoder.span_data:
        return None
    largest_key = max(encoder.span_data.keys(), key=lambda k: k[1] - k[0])
    states = encoder.span_data[largest_key].get('layer_states', [])
    return [
        st['wkv_state'].norm().item()
        if st is not None and 'wkv_state' in st else 0.0
        for st in states
    ]


def _encode_span_incrementally(provider, encoder, s, e, span_text_val, stats):
    """Store span (s, e) in the encoder using the cheapest strategy; update ``stats``.

    ``stats['hare']`` counts tokens actually encoded. ``stats['baseline']``
    counts what re-encoding the whole span from scratch would have cost.
    """
    key = (s, e)

    # One new piece appended on the right: carry the old state through it.
    right_key = (s, e - 1)
    if right_key in encoder.span_data:
        stats['hare'] += encoder.extend_right(provider.piece_text(e - 1), right_key, key)
        stats['right'] += 1
        stats['baseline'] += encoder.span_data[key]['n_tokens']
        return

    # A stored span sharing the left edge: continue from the longest one.
    lefts = [k for k in encoder.span_data if k[0] == s and k[1] < e]
    if lefts:
        best_left = max(lefts, key=lambda k: k[1])
        new_portion = provider.span_text(best_left[1], e)
        stats['hare'] += encoder.smart_merge(new_portion, best_left, key)
        stats['smart'] += 1
        stats['baseline'] += encoder.span_data[key]['n_tokens']
        return

    # Nothing reusable: encode from scratch, dropping spans it subsumes.
    encoder.remove_old(key)
    n = encoder.encode_span(span_text_val, key)
    stats['hare'] += n
    stats['full'] += 1
    stats['baseline'] += n


def streaming_loop(provider, encoder, queries, q_embs, sleep_time=0):
    """Drive the streaming demo, yielding UI updates after every piece.

    Args:
        provider: a TextProvider supplying pieces.
        encoder: a SpanEncoder holding encoded spans.
        queries: query strings to search for.
        q_embs: mapping query -> query embedding.
        sleep_time: seconds to pause after each update (0 for none).

    Yields:
        6-tuples ``(grid_html, efficiency_md, tokens_md, strategy_md,
        search_html, state_html)``: one per processed piece, plus a final
        "COMPLETE" update.
    """
    stats = {'hare': 0, 'baseline': 0, 'right': 0, 'smart': 0,
             'full': 0, 'merges': 0}
    prev_span_keys = set()
    pieces_processed = 0
    piece_queue = []
    peak_scores = {}
    state_history = []
    n_rwkv_layers = len(encoder.birwkv_layers)

    # Keep going while pieces are queued, so nothing polled is dropped once
    # the provider reports completion.
    while piece_queue or not provider.is_complete():
        new_pieces = provider.poll_pieces()
        if new_pieces:
            piece_queue.extend(new_pieces)
            random.shuffle(piece_queue)

        if not piece_queue:
            continue

        idx = piece_queue.pop(0)
        provider.received[idx] = True
        pieces_processed += 1

        new_spans = provider.get_spans()
        new_keys = {(s, e) for s, e, _ in new_spans}
        for s, e, span_text_val in new_spans:
            if (s, e) not in prev_span_keys:
                _encode_span_incrementally(provider, encoder, s, e,
                                           span_text_val, stats)

        # Fewer spans than before means the new piece joined existing spans.
        if len(new_keys) < len(prev_span_keys) and pieces_processed > 1:
            stats['merges'] += 1
        prev_span_keys = new_keys

        total_chunks = sum(len(d['chunk_embs']) for d in encoder.span_data.values())

        norms = _largest_span_state_norms(encoder)
        if norms is not None:
            state_history.append(norms)

        search_results = _span_search_results(encoder, queries, q_embs,
                                              new_spans, peak_scores)

        grid_html = render_grid(provider.received, provider.n_pieces, highlight=idx)
        eff_md, tok_md, strat_md = _status_markdown(stats, total_chunks)
        search_html = render_search(search_results, peak_scores)
        state_html = render_state_viz(state_history, n_rwkv_layers)

        yield grid_html, eff_md, tok_md, strat_md, search_html, state_html

        if sleep_time > 0:
            time.sleep(sleep_time)

    total_chunks = sum(len(d['chunk_embs']) for d in encoder.span_data.values())
    grid_html = render_grid(provider.received, provider.n_pieces)
    eff_md, tok_md, strat_md = _status_markdown(stats, total_chunks, complete=True)
    search_results = _span_search_results(encoder, queries, q_embs,
                                          provider.get_spans())
    search_html = render_search(search_results, peak_scores)
    state_html = render_state_viz(state_history, n_rwkv_layers)
    yield grid_html, eff_md, tok_md, strat_md, search_html, state_html
| |
|
| |
|
@spaces.GPU
def start_demo(source_mode, demo_choice, url_input, queries_text, chunk_size):
    """Gradio handler: set up the text source and stream the demo's UI updates.

    Args:
        source_mode: "Quick Demo" or "URL".
        demo_choice: key into QUICK_DEMOS (Quick Demo mode).
        url_input: text URL (URL mode).
        queries_text: comma-separated queries (URL mode).
        chunk_size: slider value; cast to int because it is used as a range
            step when chunking.

    Yields:
        6-tuples for the piece grid, three status markdowns, search HTML and
        state HTML. Invalid input or a failed download yields one message.
    """
    import html  # for escaping error text shown in the HTML panel

    model.cuda()
    encoder = SpanEncoder(model, tokenizer, chunk_size=int(chunk_size))

    if source_mode == "Quick Demo":
        config = QUICK_DEMOS[demo_choice]
        provider = TextProvider(config['text'],
                                piece_size=config['piece_size'], seed=42)
        queries = config['queries']
        sleep_time = config['sleep']
    elif source_mode == "URL":
        if not url_input or not url_input.strip():
            yield ('<p style="color:#ffc107">Enter a URL to a text file.</p>',
                   '', '', '', '', '')
            return
        try:
            text = load_text(url=url_input)
        except (ValueError, OSError) as err:
            # Bad scheme or network/HTTP failure: report it in the UI.
            yield (f'<p style="color:#dc3545">Could not load URL: '
                   f'{html.escape(str(err))}</p>',
                   '', '', '', '', '')
            return
        provider = TextProvider(text, piece_size=4096, seed=42)
        queries = [q.strip() for q in (queries_text or '').split(',') if q.strip()]
        sleep_time = 0
    else:
        return

    if not queries:
        queries = ["search query"]

    q_embs = {q: encoder.encode_query(q) for q in queries}

    yield from streaming_loop(provider, encoder, queries, q_embs, sleep_time)
| |
|
| |
|
def toggle_inputs(source_mode):
    """Show or hide inputs for the selected source mode.

    Returns updates for (demo dropdown, URL textbox, queries textbox). The
    queries box is shown except in Quick Demo mode and is reset to the
    default Frankenstein queries.
    """
    is_quick = source_mode == "Quick Demo"
    is_url = source_mode == "URL"
    frankenstein_q = "on a dreary night the creature first opened its eyes, an innocent woman is wrongly executed, playing god with science"
    demo_update = gr.update(visible=is_quick)
    url_update = gr.update(visible=is_url)
    queries_update = gr.update(visible=not is_quick, value=frankenstein_q)
    return demo_update, url_update, queries_update
| |
|
| |
|
def update_queries(demo_choice):
    """Return the chosen quick demo's queries joined by ", " ("" if unknown)."""
    cfg = QUICK_DEMOS.get(demo_choice, {})
    return ', '.join(cfg.get('queries', []))
| |
|
| |
|
def build_demo():
    """Build the Gradio Blocks UI for the HARE streaming-search demo.

    Returns:
        The configured ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="HARE Streaming Demo") as demo:
        gr.Markdown(
            "# HARE: Streaming Semantic Search",
        )
        gr.Markdown(
            "Watch [HARE](https://huggingface.co/SixOpen/HARE) build a "
            "semantic search index in real-time as content streams in "
            "piece by piece. Unlike standard embedding models, HARE's "
            "recurrent state carries forward full context without "
            "re-encoding, allowing for search over live transcripts, "
            "distributed content, and streaming files without "
            "needing to download them in full.",
        )

        with gr.Row():
            # Left column: source selection, queries, settings, start button.
            with gr.Column(scale=1, min_width=280):
                source_mode = gr.Radio(
                    ["URL", "Quick Demo"],
                    value="URL",
                    label="Source",
                )
                # Only visible in "Quick Demo" mode (see toggle_inputs).
                demo_choice = gr.Dropdown(
                    list(QUICK_DEMOS.keys()),
                    value=list(QUICK_DEMOS.keys())[0],
                    label="Demo Content",
                    visible=False,
                )
                url_input = gr.Textbox(
                    label="Text URL",
                    value="https://gutenberg.org/files/84/84-0.txt",
                    placeholder="https://gutenberg.org/files/84/84-0.txt",
                    visible=True,
                )
                queries_input = gr.Textbox(
                    label="Search Queries (comma-separated)",
                    value="on a dreary night the creature first opened its eyes, an innocent woman is wrongly executed, playing god with science",
                    visible=True,
                )

                with gr.Accordion("Settings", open=False):
                    # Passed to SpanEncoder as chunk_size by start_demo.
                    chunk_size = gr.Slider(
                        128, 1024, value=512, step=64,
                        label="Chunk Size (tokens)",
                    )

                start_btn = gr.Button("Start Demo", variant="primary", size="lg")

            # Right column: live outputs, refreshed on every yield of start_demo.
            with gr.Column(scale=2):
                gr.Markdown("### Download Progress")
                piece_grid = gr.HTML(
                    '<div style="padding:20px;color:#666;text-align:center">'
                    'Click "Start Demo" to begin</div>'
                )

                gr.Markdown("### Encoding Efficiency")
                with gr.Row():
                    efficiency_md = gr.Markdown("**Efficiency: --**")
                with gr.Row():
                    tokens_md = gr.Markdown("Tokens: --")
                    strategy_md = gr.Markdown("Right-ext: -- | Smart-merge: -- | Full: --")

                gr.Markdown("### Search Results")
                search_html = gr.HTML(
                    '<p style="color:#888">Results will appear here as '
                    'pieces are processed...</p>'
                )

                gr.Markdown("### Recurrent State Evolution")
                state_viz = gr.HTML(
                    '<p style="color:#888">State heatmap will appear as '
                    'pieces are processed...</p>'
                )

        # Show/hide the mode-specific inputs when the source changes.
        source_mode.change(
            toggle_inputs,
            inputs=[source_mode],
            outputs=[demo_choice, url_input, queries_input],
        )
        demo_choice.change(
            update_queries,
            inputs=[demo_choice],
            outputs=[queries_input],
        )
        # Output order matches the 6-tuples yielded by start_demo.
        start_btn.click(
            start_demo,
            inputs=[source_mode, demo_choice, url_input, queries_input,
                    chunk_size],
            outputs=[piece_grid, efficiency_md, tokens_md, strategy_md,
                     search_html, state_viz],
        )

    return demo
| |
|
| |
|
# Build the UI at import time so anything importing this module can reuse
# ``demo``; only start the server when the file is run as a script.
demo = build_demo()

if __name__ == "__main__":
    demo.queue().launch()
| |
|