import os
import torch
import torch.nn as nn
from transformers import BertModel, GPTNeoForCausalLM, AutoTokenizer

# ⚙️ Ensure temporary directory is writable
os.environ["TMPDIR"] = os.path.expanduser("~/tmp")
os.makedirs(os.environ["TMPDIR"], exist_ok=True)

# 💠 Optional modules (brain & heart, if available)
heart_module = None
brain_module = None

if os.path.isdir("heart"):
    try:
        from heart import heart
        heart_module = heart
    except Exception as e:
        print(f"[⚠️] Heart module error: {e}")

if os.path.isdir("brain"):
    try:
        from brain import brain
        brain_module = brain
    except Exception as e:
        print(f"[⚠️] Brain module error: {e}")

# 🤖 TARSQuantumHybrid: BERT encoder projected into a GPT-Neo decoder
class TARSQuantumHybrid(nn.Module):
    def __init__(self, bert_model="bert-base-uncased", gpt_model="EleutherAI/gpt-neo-125M"):
        super(TARSQuantumHybrid, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.gpt = GPTNeoForCausalLM.from_pretrained(gpt_model)

        gpt_hidden_dim = getattr(self.gpt.config, "hidden_size", None) or getattr(self.gpt.config, "n_embd", 768)
        self.embedding_proj = nn.Linear(self.bert.config.hidden_size, gpt_hidden_dim)

        self.tokenizer = AutoTokenizer.from_pretrained(gpt_model)

        # Ensure the tokenizer has a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.gpt.resize_token_embeddings(len(self.tokenizer))
            print("✅ Padding token added and model resized.")

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None):
        # Encode with BERT, take the [CLS] embedding, and project it into GPT-Neo's embedding space
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = bert_output.last_hidden_state[:, 0, :]
        gpt_input = self.embedding_proj(cls_embedding).unsqueeze(1)  # (batch, 1, hidden_gpt)
        # GPT-Neo is decoder-only and has no decoder_input_ids argument; embed any provided
        # decoder tokens and append them to the projected BERT embedding instead.
        if decoder_input_ids is not None:
            decoder_embeds = self.gpt.get_input_embeddings()(decoder_input_ids)
            gpt_input = torch.cat([gpt_input, decoder_embeds], dim=1)
        outputs = self.gpt(inputs_embeds=gpt_input)
        return outputs

    def chat(self, text, max_length=128):  # max_length is accepted but unused by the single-pass decoding below
        # 🧠 Clean and tokenize the input text
        cleaned_text = self.clean_input_text(text)
        if not cleaned_text.strip():
            return "🤖 Please provide a non-empty input."

        encoded_input = self.safe_tokenization(cleaned_text)
        
        # Extract input_ids and attention_mask
        input_ids = encoded_input["input_ids"]
        attention_mask = encoded_input["attention_mask"]

        # Debug: check the token IDs against the effective vocabulary
        print(f"Input Text: {cleaned_text}")
        print(f"Input IDs: {input_ids}")
        print(f"Vocabulary Size: {len(self.tokenizer)}")

        # Ensure token IDs are within bounds of the (possibly resized) embedding table;
        # len(tokenizer) includes added special tokens, tokenizer.vocab_size does not
        if input_ids.numel() > 0 and input_ids.max() >= len(self.tokenizer):
            raise ValueError(f"Token ID exceeds model's vocabulary size: {input_ids.max()}")

        decoder_input_ids = torch.tensor([[self.tokenizer.bos_token_id]])

        # 🧪 Single forward pass through the model; greedy argmax at each position (not autoregressive generation)
        with torch.no_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
            )
            generated_ids = torch.argmax(outputs.logits, dim=-1)

        # Debug: Check the generated token IDs
        print(f"Generated Token IDs: {generated_ids}")

        raw_response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        
        # 🧼 Clean model echo by removing the original input from the response
        cleaned = raw_response.replace(cleaned_text, "").strip()

        # 🧠 Add insights from optional modules (brain & heart)
        extra_thoughts = ""
        if brain_module and hasattr(brain_module, "get_brain_insight"):
            extra_thoughts += f"\n🧠 {brain_module.get_brain_insight()}"
        if heart_module and hasattr(heart_module, "get_heart_feeling"):
            extra_thoughts += f"\n❤️ {heart_module.get_heart_feeling()}"

        # 🪄 Return final response
        final_response = cleaned if cleaned else "🤖 ...processing quantum entanglement..."
        return final_response + extra_thoughts

    def clean_input_text(self, text):
        # Keep only alphanumeric characters and whitespace
        cleaned_text = ''.join(e for e in text if e.isalnum() or e.isspace())
        return cleaned_text

    def safe_tokenization(self, text):
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        # Clamp token ids to the size of the (possibly resized) embedding table;
        # len(tokenizer) covers added special tokens, unlike tokenizer.vocab_size
        token_ids = [min(i, len(self.tokenizer) - 1) for i in token_ids]
        return {
            "input_ids": torch.tensor(token_ids).unsqueeze(0),  # Adding batch dimension
            "attention_mask": torch.ones((1, len(token_ids)), dtype=torch.long)
        }
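
# Illustrative usage sketch: instantiating from scratch downloads the pretrained
# BERT and GPT-Neo weights rather than loading the tars_v1.pt checkpoint:
#   model = TARSQuantumHybrid()
#   print(model.chat("hello tars"))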

# ✅ Torch-compatible loader
def load_tars(path="tars_v1.pt"):
    # Allow-list the custom class for torch's safe unpickling; add_safe_globals
    # expects a list of classes (only strictly needed when weights_only=True)
    from torch.serialization import add_safe_globals
    add_safe_globals([TARSQuantumHybrid])

    model = torch.load(path, weights_only=False)
    model.eval()
    return model
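
# Note: load_tars() unpickles the entire module object, so "tars_v1.pt" is assumed to
# have been produced by saving the whole model, e.g. (illustrative sketch):
#   model = TARSQuantumHybrid()
#   torch.save(model, "tars_v1.pt")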

# ✅ Start chat loop
if __name__ == "__main__":
    print("🤖 TARS model loaded successfully. Ready to chat!")
    model = load_tars()

    while True:
        prompt = input("You: ")
        if prompt.strip().lower() in ["exit", "quit"]:
            print("TARS: Till we meet again in the quantum field. 🌌")
            break
        response = model.chat(prompt)
        print(f"TARS: {response}")