# kykybeepbopboop — commit 9bc5631 "Update app.py" (verified).
# (Header text captured from the hosting page; kept as a comment so the file parses.)
import functools
import json
import os

import gradio as gr
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

from model import VacuumInspiredRNN  # Import your class!
class WordTokenizer:
    """Whitespace word-level tokenizer with a fixed-size vocabulary.

    Ids ``pad_id`` (0) and ``unk_id`` (1) are reserved for ``<pad>`` and
    ``<unk>``; the remaining ids go to the most frequent lower-cased words
    seen by :meth:`build_vocab`.

    Same as in train.py — copied here for self-containment; keep the two in sync.
    """

    def __init__(self, vocab_size=768):
        self.pad_id = 0
        self.unk_id = 1
        self.word_to_idx = {'<pad>': self.pad_id, '<unk>': self.unk_id}
        self.idx_to_word = {self.pad_id: '<pad>', self.unk_id: '<unk>'}
        self.vocab_size = vocab_size

    def build_vocab(self, texts):
        """Add the most frequent words of *texts* to the vocabulary.

        Words are lower-cased and split on whitespace. At most
        ``vocab_size - 2`` words are considered (``Counter.most_common``), and
        ids are assigned sequentially after the entries already present,
        never reaching ``vocab_size``.
        """
        from collections import Counter
        counter = Counter(w for text in texts for w in text.lower().split())
        for word, _ in counter.most_common(self.vocab_size - 2):
            idx = len(self.word_to_idx)
            if idx >= self.vocab_size:
                # The dict never shrinks, so every later idx would be too large too.
                break
            # NOTE(review): a word already in the vocab (e.g. a literal '<unk>'
            # in the corpus) is re-assigned a new id here; kept to match train.py.
            self.word_to_idx[word] = idx
            self.idx_to_word[idx] = word

    def encode(self, text):
        """Return the ids of the lower-cased words of *text* (unknown -> ``unk_id``)."""
        return [self.word_to_idx.get(w, self.unk_id) for w in text.lower().split()]

    def decode(self, tokens):
        """Join the words for integer ids *tokens*, skipping ``pad_id``."""
        return ' '.join(self.idx_to_word.get(t, '<unk>') for t in tokens if t != self.pad_id)

    @classmethod
    def load(cls, path):
        """Restore a tokenizer from the JSON file at *path*.

        The file must contain ``vocab_size``, ``word_to_idx``, ``idx_to_word``,
        ``pad_id`` and ``unk_id``. JSON object keys are always strings, so
        ``idx_to_word`` is read back keyed by ``"0"``, ``"1"``, ...; the keys are
        converted to ``int`` here because :meth:`decode` looks up integer ids
        (without the conversion every token would decode to ``<unk>``).
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        tokenizer = cls(data['vocab_size'])
        tokenizer.word_to_idx = {w: int(i) for w, i in data['word_to_idx'].items()}
        tokenizer.idx_to_word = {int(i): w for i, w in data['idx_to_word'].items()}
        tokenizer.pad_id = data['pad_id']
        tokenizer.unk_id = data['unk_id']
        return tokenizer
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the model and tokenizer once; later calls return the cached tuple.

    A local ``trained/`` checkpoint (``model.pth`` + ``tokenizer.json``) is
    used when both files exist. Otherwise the Hub repo named by the
    ``MODEL_REPO`` environment variable (default
    ``"your-username/vacuum-rnn-llm"``) is downloaded into ``cache/``.

    Returns:
        ``(model, tok, device)``: the model in eval mode on *device*, its
        :class:`WordTokenizer`, and ``'cuda'`` or ``'cpu'``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    local_pth = 'trained/model.pth'
    local_tok = 'trained/tokenizer.json'
    if os.path.exists(local_pth) and os.path.exists(local_tok):
        model_path, tok_path = local_pth, local_tok
    else:
        repo = os.environ.get('MODEL_REPO', 'your-username/vacuum-rnn-llm')  # Update!
        # snapshot_download returns the folder it populated.
        folder = snapshot_download(repo_id=repo, local_dir='cache')
        model_path = os.path.join(folder, 'model.pth')
        tok_path = os.path.join(folder, 'tokenizer.json')
    # torch.load unpickles its input; the checkpoint may come from the network,
    # so restrict it to tensors/containers (all a state dict needs).
    sd = torch.load(model_path, map_location=device, weights_only=True)
    tok = WordTokenizer.load(tok_path)
    model = VacuumInspiredRNN(vocab_size=tok.vocab_size).to(device)
    model.load_state_dict(sd)
    model.eval()
    return model, tok, device
def generate(prompt, max_new=50, temp=0.8):
    """Continue *prompt* with up to *max_new* sampled words.

    Args:
        prompt: seed text; the tokenizer lower-cases and whitespace-splits it.
        max_new: maximum number of words to sample. Gradio sliders may deliver
            a float, so it is coerced to ``int``.
        temp: softmax temperature; values <= 0 fall back to greedy (argmax).

    Returns:
        The decoded prompt plus continuation (sampling stops early if the
        pad id is drawn), or a short message for an empty prompt.
    """
    if not prompt or not prompt.strip():
        return "Add a prompt!"
    model, tok, device = load_model()
    ptoks = tok.encode(prompt)
    if not ptoks:
        return "Invalid prompt."
    max_new = int(max_new)
    gen_toks = list(ptoks)
    with torch.no_grad():
        inp = torch.tensor([ptoks], device=device)
        # One pass over the whole prompt: `hidden` has consumed every prompt
        # token and the logits at the last position predict the first new word,
        # so the last prompt token must not be fed a second time.
        logits, hidden = model(inp, add_fluctuation=True)
        for step in range(max_new):
            next_log = logits[0, -1]
            if temp > 0:
                probs = F.softmax(next_log / temp, dim=0)
                next_t = torch.multinomial(probs, 1).item()
            else:
                next_t = int(torch.argmax(next_log).item())
            gen_toks.append(next_t)
            if next_t == tok.pad_id:
                break
            if step + 1 < max_new:
                # Feed only the freshly sampled token; the RNN state carries the rest.
                last = torch.tensor([[next_t]], device=device)
                logits, hidden = model(last, hidden)
    return tok.decode(gen_toks)
# Gradio UI. The components' first positional parameter is `value`, so labels
# are passed by keyword; otherwise they would show up as text inside the box.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="The quick brown fox..."),
        # step=1 keeps the word count integral.
        gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max New Words"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Output", lines=8),
    title="Your Vacuum RNN LLM",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)