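"""Interactive chat client for a trained MiniGPT checkpoint.

Loads the tokenizer and model weights from ./trained-mini-gpt/ and streams
sampled tokens to stdout until <END> is generated or max_tokens is reached.
"""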
import torch
import torch.nn.functional as F
from model import MiniGPT
from dataset import MiniBPETokenizr, SimpleTokenizr  # unused here; kept for the commented-out SimpleTokenizr loader below
import os
from tokenizers import Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer
#tokenizer = SimpleTokenizr()
#tokenizer.load("./customchatbot-v1/trained-mini-gpt/tokenizer.json")
tokenizer = Tokenizer.from_file("./trained-mini-gpt/tokenizer.json")

# Load model
model = MiniGPT(vocab_size=tokenizer.get_vocab_size())
# Alternative loader with checkpoint fallback (kept from development):
# ckpt = "./customchatbot-v1/trained-mini-gpt/mini-gpt.pth"
# if os.path.exists(ckpt):
#     model.load_state_dict(torch.load(ckpt, map_location=device))
# else:
#     model.load_state_dict(torch.load("./customchatbot-v1/trained-mini-gpt/checkpoint-mini-gpt.pth", map_location=device)["model_state_dict"])
checkpoint = torch.load("./trained-mini-gpt/mini-gpt.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval().to(device)
totalparams = sum(p.numel() for p in model.parameters())
print(f"Model total params: {totalparams:,}")

def sample_token(logits, temperature=1.0):
    # Scale logits by temperature, then sample from the softmax distribution.
    logits = logits / temperature
    logits = torch.nan_to_num(logits, nan=-1e9)  # guard against NaN logits
    probs = F.softmax(logits, dim=-1)

    # Belt-and-braces: fall back to a uniform distribution if probs are invalid.
    if torch.any(torch.isnan(probs)) or torch.any(probs < 0):
        print("⚠️ Invalid probs detected. Using uniform fallback.")
        probs = torch.ones_like(probs) / probs.size(-1)

    return torch.multinomial(probs, num_samples=1).item()
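
# Optional: a top-k sampling variant. This is a sketch, not part of the original
# script -- `sample_token_topk` and its `k` parameter are illustrative names.
# Restricting sampling to the k most likely tokens often trims the incoherent
# tail that pure temperature sampling can produce. To try it, swap it in for
# sample_token inside generate_reply.
def sample_token_topk(logits, temperature=1.0, k=40):
    logits = torch.nan_to_num(logits / temperature, nan=-1e9)
    # Keep only the k highest-scoring tokens (or fewer if the vocab is smaller).
    top_logits, top_idx = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
    probs = F.softmax(top_logits, dim=-1)
    # Sample a position within the top-k set, then map back to the vocab id.
    choice = torch.multinomial(probs, num_samples=1)
    return top_idx.gather(-1, choice).item()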

def generate_reply(prompt, max_tokens=100):
    tokens = tokenizer.encode(prompt).ids
    if not tokens:
        print("⚠️ Empty prompt after encoding.")
        return
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    generated = []

    with torch.no_grad():
        for _ in range(max_tokens):
            # No KV cache: the full, growing sequence is re-encoded every step.
            # If input_ids can exceed MiniGPT's context window, truncate it to
            # the last block of tokens before the forward pass.
            logits = model(input_ids)
            logits = logits[:, -1, :]  # keep only the final position's logits
            next_token = sample_token(logits)
            generated.append(next_token)

            # Round-trip through encode/decode so subword pieces print as readable text.
            next_str = tokenizer.id_to_token(next_token)
            encoded_text = tokenizer.encode(next_str).ids
            decoded_text = tokenizer.decode(encoded_text)
            print(decoded_text, end=" ", flush=True)

            if next_str == "<END>":
                break

            input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)
    print()

# Chat loop
print("🧠 MiniGPT Chat (type 'exit' to quit')")
while True:
    user_input = input("User: ")
    if user_input.lower() == "exit":
        break
    prompt = f"^User: {user_input}\nMiniGPT:"
    print("MiniGPT: ", end="", flush=True)
    generate_reply(prompt)