CreatedNull committed on
Commit
b180c22
·
verified ·
1 Parent(s): 9d848aa

Delete ml_tinygpt.py

Browse files
Files changed (1) hide show
  1. ml_tinygpt.py +0 -66
ml_tinygpt.py DELETED
@@ -1,66 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from model import MiniGPT
4
- from dataset import MiniBPETokenizr,SimpleTokenizr
5
- import json
6
- import os
7
-
8
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load tokenizer
tokenizer = SimpleTokenizr()
tokenizer.load("./customchatbot-v1/trained-mini-gpt/tokenizer.json")

# Load model
# The final weights file is preferred; when it is absent, the weights are
# taken from the "model_state_dict" entry of the training checkpoint.
_WEIGHTS_PATH = "./customchatbot-v1/trained-mini-gpt/mini-gpt.pth"
_CHECKPOINT_PATH = "./customchatbot-v1/trained-mini-gpt/checkpoint-mini-gpt.pth"

model = MiniGPT(vocab_size=len(tokenizer))
if os.path.exists(_WEIGHTS_PATH):
    _state_dict = torch.load(_WEIGHTS_PATH, map_location=device)
else:
    _state_dict = torch.load(_CHECKPOINT_PATH, map_location=device)["model_state_dict"]
model.load_state_dict(_state_dict)
model.eval().to(device)

# Report the parameter count once at start-up.
totalparams = sum(weight.numel() for weight in model.parameters())
print(f"Model total params: {totalparams:,}")
20
-
21
def sample_token(logits, temperature=1.0):
    """Pick one token id from a single row of next-token logits.

    Args:
        logits: Tensor of unnormalised scores whose last dimension indexes
            the vocabulary. Must hold exactly one row (1-D, or a leading
            dimension of size 1), because the result is read with ``.item()``.
        temperature: Softmax temperature. ``0`` selects the highest-scoring
            token deterministically (greedy decoding); positive values divide
            the logits before sampling.

    Returns:
        The chosen token id as a Python ``int``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    if temperature == 0:
        # Dividing by zero would turn every positive logit into +inf, which
        # nan_to_num then maps to the same finite maximum -- i.e. a uniform
        # draw over all positive logits. Greedy argmax is what is meant.
        return torch.nan_to_num(logits, nan=-1e9).argmax(dim=-1).item()

    # NaN scores are pushed to a very negative value so softmax gives them
    # (effectively) zero probability instead of poisoning the whole row.
    scaled = torch.nan_to_num(logits / temperature, nan=-1e9)
    probs = F.softmax(scaled, dim=-1)

    # Defensive guard: torch.multinomial rejects NaN or negative weights.
    if torch.any(torch.isnan(probs)) or torch.any(probs < 0):
        print("⚠️ Invalid probs detected. Using uniform fallback.")
        probs = torch.ones_like(probs) / probs.size(-1)

    return torch.multinomial(probs, num_samples=1).item()
31
-
32
def generate_reply(prompt, max_tokens=100, temperature=1.0):
    """Stream a sampled continuation of ``prompt`` to stdout.

    The prompt is encoded with the module-level ``tokenizer`` and fed to the
    module-level ``model``. One token is sampled per step with
    ``sample_token`` and printed immediately, followed by a space. Generation
    stops after ``max_tokens`` steps, or as soon as the sampled token's
    string is ``"<END>"``; the end marker itself is not printed.

    Args:
        prompt: Text to condition the model on.
        max_tokens: Maximum number of tokens to sample.
        temperature: Forwarded to ``sample_token``.

    Returns:
        None. All output goes to stdout; if the prompt encodes to no tokens
        a warning is printed and nothing is generated.
    """
    tokens = tokenizer.encode(prompt)
    if not tokens:
        print("⚠️ Empty prompt after encoding.")
        return

    # Shape (1, len(tokens)): a batch holding the single prompt.
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    # NOTE(review): input_ids grows by one token per step with no cropping;
    # confirm MiniGPT has no maximum context length that this could exceed.
    with torch.no_grad():
        for _ in range(max_tokens):
            # Only the scores at the last position predict the next token.
            logits = model(input_ids)[:, -1, :]
            next_token = sample_token(logits, temperature)
            next_str = tokenizer.itos.get(next_token, "")

            # Stop before printing so the end marker is never shown.
            if next_str == "<END>":
                break

            # NOTE(review): the encode -> decode round trip is kept from the
            # original; presumably it normalises the token text -- confirm
            # whether printing next_str directly would be equivalent.
            print(tokenizer.decode(tokenizer.encode(next_str)), end=" ", flush=True)

            # Build the new column directly on the target device.
            step = torch.tensor([[next_token]], dtype=torch.long, device=device)
            input_ids = torch.cat([input_ids, step], dim=1)
    print()
57
-
58
# Chat loop
# Reads one line per turn, wraps it in the "^User: ...\nMiniGPT:" prompt
# template and streams the model's reply. Typing "exit" (any case,
# surrounding whitespace ignored), Ctrl-D or Ctrl-C at the prompt ends the
# session.
print("🧠 MiniGPT Chat (type 'exit' to quit)")
while True:
    try:
        user_input = input("User: ")
    except (EOFError, KeyboardInterrupt):
        # Leave the terminal on a fresh line instead of dumping a traceback.
        print()
        break
    if user_input.strip().lower() == "exit":
        break
    prompt = f"^User: {user_input}\nMiniGPT:"
    print("MiniGPT: ", end="", flush=True)
    generate_reply(prompt)