import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from models.model import Microformer
from tokenizers import Tokenizer
from config import VOCAB_SIZE, EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, MAX_SEQ_LEN, ADAPTER_DIM
import sqlite3
from datetime import datetime

# --- Load tokenizer and model ---
tokenizer = Tokenizer.from_file("data/tokenizer.json")
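# Prefer the tokenizer's actual vocab size over the config value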
VOCAB_SIZE = tokenizer.get_vocab_size()

model = Microformer(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_seq_len=MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,
    session_adapter_dim=ADAPTER_DIM
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.load_state_dict(torch.load("microformer.pt", map_location=device))
model.to(device)

# --- Freeze all but session adapters and output for online learning ---
model.freeze_except_adapters(session_only=True, include_output=True)

# Ignore <PAD> targets so padding never contributes to the online-update loss
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("<PAD>"))
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-2  # High LR for visible learning during teaching
)

# --- Memory DB setup ---
conn = sqlite3.connect("memory.db")
c = conn.cursor()
c.execute("""
CREATE TABLE IF NOT EXISTS memory (
    timestamp TEXT,
    prompt TEXT,
    response TEXT
)
""")
conn.commit()

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9):
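    """Sample a single token id from `logits` after combined top-k / nucleus (top-p) filtering."""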
    logits = logits.squeeze(0)  # [1, vocab] → [vocab]
    probs = torch.softmax(logits, dim=-1)

    # Sort probabilities
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Top-p (nucleus) mask: drop tokens once cumulative probability exceeds top_p.
    # Shifting the mask right by one keeps the first token that crosses the
    # threshold, so at least one token always survives.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[1:] = sorted_mask[:-1].clone()
    sorted_mask[0] = False

    # Top-k mask: additionally drop everything beyond the k most probable tokens
    if top_k < sorted_probs.size(0):
        sorted_mask[top_k:] = True

    # Zero out masked values
    sorted_probs[sorted_mask] = 0.0

    # Normalize and sample
    sorted_probs /= sorted_probs.sum()
    sampled_relative_index = torch.multinomial(sorted_probs, 1).item()
    sampled_token_id = sorted_indices[sampled_relative_index].item()

    return sampled_token_id

def generate(prompt, length=100, temperature=1.0, top_p=0.9, top_k=50):
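    """Autoregressively extend `prompt` by up to `length` tokens, stopping early at <EOS>."""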
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    eos_token_id = tokenizer.token_to_id("<EOS>")

    for _ in range(length):
        # Keep the context within the model's positional range
        input_tensor = input_tensor[:, -MAX_SEQ_LEN:]

        with torch.no_grad():
            logits = model(input_tensor)
            logits = logits[:, -1, :] / temperature

            # Repetition penalty: divide positive logits and multiply negative
            # ones, so previously seen tokens always become less likely (a plain
            # *= 0.8 would boost tokens whose logits are negative)
            for token_id in set(input_tensor[0].tolist()):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= 1.25
                else:
                    logits[0, token_id] *= 1.25

            next_token_id = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token_id]], device=device)], dim=1)

        if next_token_id == eos_token_id:
            break

    output_ids = input_tensor[0].tolist()
    decoded = tokenizer.decode(output_ids)

    if "<EOS>" in decoded:
        decoded = decoded.split("<EOS>")[0].strip()

    return decoded

def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
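    """Run one step of next-token training on `text`; returns the loss, or None if the text is too short."""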
    # Always called after freeze_except_adapters(session_only=True)
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None

    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
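    # Right-pad both sequences to max_len; <PAD> targets are ignored by the loss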
    input_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(input_ids))
    target_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(target_ids))
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    model.train()
    logits = model(input_tensor)
    logits = logits.view(-1, logits.size(-1))
    targets = target_tensor.view(-1)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()

# Optional: Reset session adapter weights between sessions
def reset_session_adapters(model):
    """Zero the session-adapter weights so fast memory starts fresh."""
    for layer in model.layers:
        adapter = getattr(layer, 'session_adapter', None)
        if adapter is not None:
            for param in adapter.parameters():
                nn.init.zeros_(param)

if __name__ == "__main__":
    while True:
        prompt = input("\nEnter a prompt (or 'exit' to quit): ")
        if prompt.lower() in {"exit", "quit"}:
            break
        try:
            temp = float(input("Temperature (e.g. 0.7, 1.0): "))
        except ValueError:
            temp = 1.0
            print("Invalid temperature, defaulting to 1.0")

        output = generate(prompt, length=100, temperature=temp, top_p=0.9, top_k=50)
        print("\nGenerated text:\n")
        print(output)

        # Online learning: always update session adapters only
        teach = input("\nDo you want to teach the model a better answer? (y/N): ").strip().lower()
        if teach == "y":
            your_answer = input("Type your ideal response for this prompt: ")
            model.freeze_except_adapters(session_only=True, include_output=True)
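            # Learn from the user's corrected answer for this prompt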
            online_text = prompt + " " + your_answer
            loss = online_unsupervised_update(
                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
            )
            print(f"[Online update loss: {loss:.4f}]")
        else:
            model.freeze_except_adapters(session_only=True, include_output=True)
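            # No correction given: reinforce the model's own output (self-training)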
            online_text = prompt + " " + output
            loss = online_unsupervised_update(
                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
            )
            print(f"[Online (self-improve) update loss: {loss:.4f}]")

        # Store the interaction in memory DB as before
        c.execute("INSERT INTO memory (timestamp, prompt, response) VALUES (?, ?, ?)",
                  (datetime.now().isoformat(timespec='seconds'), prompt, output))
        conn.commit()

        print("\nRecent memory:")
        for row in c.execute("SELECT * FROM memory ORDER BY timestamp DESC LIMIT 5"):
            print(f"[{row[0]}] {row[1]}{row[2]}")

        # Optional: Uncomment to reset fast-memory (session adapters) between users/sessions
        # reset_session_adapters(model)