| | import os |
| | import torch |
| | import torch.nn.functional as F |
| | import tiktoken |
| | from model import Crimson, MAX_SEQ_LEN |
| |
|
# Artifact paths and tokenizer: checkpoint weights, the compact-vocabulary
# remapping tables, and the tiktoken encoding the vocab was built from.
MODEL_PATH = "crimson_8.5M.pt"
VOCAB_PATH = "vocab_map.pt"
TOKENIZER_NAME = "gpt2"

# Special token ids occupying the first OFFSET slots of the compact vocab;
# real (remapped tiktoken) tokens start at OFFSET.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
| |
|
def load_model_and_vocab(device):
    """Load the trained Crimson model and its vocabulary remapping tables.

    Args:
        device: torch device string the model weights are mapped onto.

    Returns:
        (model, used_tokens, id2new) on success, where `used_tokens` maps a
        compact id (minus OFFSET) back to the raw tiktoken id and `id2new`
        maps raw tiktoken ids to compact ids; (None, None, None) when either
        artifact file is missing.
    """
    # Verify both artifacts up front: the original built (and discarded) the
    # model before discovering the checkpoint file was absent.
    if not (os.path.exists(VOCAB_PATH) and os.path.exists(MODEL_PATH)):
        return None, None, None
    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]
    id2new = vocab_data["id2new"]
    # First OFFSET slots are reserved for PAD/SEP/EOS.
    vocab_size = len(used_tokens) + OFFSET
    model = Crimson(vocab_size).to(device)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, used_tokens, id2new
| |
|
@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Autoregressively sample a continuation of `prompt` and decode it.

    Raw tokenizer ids are remapped into the compact training vocabulary
    (ids absent from `id2new` are dropped); sampling stops early on EOS.
    Returns the decoded text of the newly sampled tokens only.
    """
    model.eval()
    compact = [id2new[t] for t in tokenizer.encode(prompt) if t in id2new]
    # Fall back to a single PAD token when nothing from the prompt survives.
    seq = torch.tensor([compact or [PAD_ID]], dtype=torch.long, device=device)
    sampled = []
    for _ in range(max_new_tokens):
        # Clamp the context window to the model's maximum sequence length.
        window = seq if seq.size(1) <= MAX_SEQ_LEN else seq[:, -MAX_SEQ_LEN:]
        step_logits = model(window)[:, -1, :] / temperature
        if top_k is not None:
            k = min(top_k, step_logits.size(-1))
            kth = torch.topk(step_logits, k).values[:, [-1]]
            # Everything below the k-th largest logit is excluded from sampling.
            step_logits = step_logits.masked_fill(step_logits < kth, -float('Inf'))
        choice = torch.multinomial(F.softmax(step_logits, dim=-1), num_samples=1)
        token_id = choice.item()
        if token_id == EOS_ID:
            break
        seq = torch.cat((seq, choice), dim=1)
        sampled.append(token_id)
    # Special ids (< OFFSET) carry no text; map the rest back to raw tokenizer ids.
    return tokenizer.decode([used_tokens[i - OFFSET] for i in sampled if i >= OFFSET])
| |
|
if __name__ == "__main__":
    # Device preference: CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)
    if model is None:
        # `if model:` relied on nn.Module truthiness and made a failed load a
        # silent no-op; test identity explicitly and tell the user what happened.
        print(f"Could not load '{MODEL_PATH}' / '{VOCAB_PATH}'. Train the model first.")
    else:
        # Seed each sample with a newline so generation starts at a line boundary.
        # NOTE(review): the OFFSET fallback is an arbitrary real token, not a
        # special id — confirm that is intended when "\n" is not in the vocab.
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)
        temperature = 0.8
        top_k = 50
        max_tokens = 900
        while True:
            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
            with torch.no_grad():
                for _ in range(max_tokens):
                    # Clamp context to the model's maximum sequence length.
                    ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
                    logits = model(ctx)[:, -1, :] / temperature
                    # Top-k filter: mask everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    idx = next_token.item()
                    x = torch.cat((x, next_token), dim=1)
                    if idx == EOS_ID:
                        break
                    if idx >= OFFSET:
                        # Stream each decoded token as soon as it is sampled.
                        print(enc.decode([used_tokens[idx - OFFSET]]), end="", flush=True)
            if input("\nPress [Enter] to generate again, or type 'exit': ").lower() == 'exit':
                break
| |
|