import gradio as gr
import torch
import torch.nn.functional as F
import json
import os
from collections import Counter
from functools import lru_cache
from huggingface_hub import snapshot_download
from model import VacuumInspiredRNN  # app.py expects model.py to define this class
class WordTokenizer:
    # Same as in train.py -- copied here so app.py is self-contained.
    def __init__(self, vocab_size=768):
        self.pad_id = 0
        self.unk_id = 1
        self.word_to_idx = {'<pad>': self.pad_id, '<unk>': self.unk_id}
        self.idx_to_word = {self.pad_id: '<pad>', self.unk_id: '<unk>'}
        self.vocab_size = vocab_size

    def build_vocab(self, texts):
        words = [w for text in texts for w in text.lower().split()]
        counter = Counter(words)
        # Reserve two slots for <pad> and <unk>.
        most_common = counter.most_common(self.vocab_size - 2)
        for word, _ in most_common:
            idx = len(self.word_to_idx)
            if idx < self.vocab_size:
                self.word_to_idx[word] = idx
                self.idx_to_word[idx] = word

    def encode(self, text):
        return [self.word_to_idx.get(w, self.unk_id) for w in text.lower().split()]

    def decode(self, tokens):
        return ' '.join(self.idx_to_word.get(t, '<unk>') for t in tokens if t != self.pad_id)

    @classmethod
    def load(cls, path):
        with open(path, 'r') as f:
            data = json.load(f)
        tokenizer = cls(data['vocab_size'])
        tokenizer.word_to_idx = data['word_to_idx']
        # JSON serializes int keys as strings; convert back so decode() can look them up.
        tokenizer.idx_to_word = {int(k): v for k, v in data['idx_to_word'].items()}
        tokenizer.pad_id = data['pad_id']
        tokenizer.unk_id = data['unk_id']
        return tokenizer
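# Quick sanity check for the tokenizer (illustrative only; the two sample
# sentences below are made up, not taken from train.py's corpus):
#
#   tok = WordTokenizer(vocab_size=16)
#   tok.build_vocab(["the quick brown fox", "the lazy dog"])
#   tok.decode(tok.encode("the quick dog"))   # -> "the quick dog"
#   tok.decode(tok.encode("the purple fox"))  # -> "the <unk> fox"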
@lru_cache(maxsize=1)  # load weights once per process, not on every request
def load_model():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    local_pth = 'trained/model.pth'
    local_tok = 'trained/tokenizer.json'
    if os.path.exists(local_pth):
        sd = torch.load(local_pth, map_location=device)
        tok = WordTokenizer.load(local_tok)
    else:
        repo = "your-username/vacuum-rnn-llm"  # Update to your Hub repo!
        snapshot_download(repo_id=repo, local_dir='cache')
        sd = torch.load('cache/model.pth', map_location=device)
        tok = WordTokenizer.load('cache/tokenizer.json')
    model = VacuumInspiredRNN(vocab_size=tok.vocab_size).to(device)
    model.load_state_dict(sd)
    model.eval()
    return model, tok, device
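# Expected artifact layout (an assumption, mirroring what train.py presumably
# saves; adjust if your paths differ):
#   trained/model.pth       -> torch.save(model.state_dict(), ...)
#   trained/tokenizer.json  -> {"vocab_size", "word_to_idx", "idx_to_word",
#                               "pad_id", "unk_id"}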
def generate(prompt, max_new=50, temp=0.8):
    model, tok, device = load_model()
    if not prompt.strip():
        return "Add a prompt!"
    ptoks = tok.encode(prompt)
    if not ptoks:
        return "Invalid prompt."
    with torch.no_grad():
        # Run the full prompt through once to build up the hidden state.
        inp = torch.tensor([ptoks], device=device)
        _, hidden = model(inp, add_fluctuation=True)
        gen_toks = ptoks[:]
        for _ in range(int(max_new)):  # Gradio sliders may pass floats
            # Feed only the latest token; the hidden state carries the context.
            last = torch.tensor([[gen_toks[-1]]], device=device)
            logits, hidden = model(last, hidden)
            next_log = logits[0, -1] / temp
            probs = F.softmax(next_log, dim=0)
            next_t = torch.multinomial(probs, 1).item()
            gen_toks.append(next_t)
            if next_t == tok.pad_id:
                break
    return tok.decode(gen_toks)
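# Can also be called directly for a smoke test in a Python shell (values are
# placeholders): generate("the quick brown", max_new=20, temp=0.7).
# Temperature scales the logits before softmax: temp < 1 sharpens the
# distribution toward the argmax, temp > 1 flattens it toward uniform.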
gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="The quick brown fox..."),
        gr.Slider(10, 200, value=50, step=1, label="Max New Words"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Output", lines=8),
    title="Your Vacuum RNN LLM",
).launch()  # share=True is unnecessary on Spaces