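"""Interactive command-line chat for the Turing model.

Loads a Turing_*.pt checkpoint from the current directory and streams
character-level completions (ChatGCLM and MAX_SEQ_LEN come from model.py).
"""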
import os
import torch
import torch.nn.functional as F
from collections import OrderedDict
import string
import sys
from model import ChatGCLM, MAX_SEQ_LEN

# --- Configuration ---
EOS_ID = 2
OFFSET = 3
CHARS = string.printable
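# Token ID layout: IDs below OFFSET are reserved (EOS_ID = 2 marks end of
# sequence), and each character in string.printable maps to
# CHARS.index(c) + OFFSET, so the effective vocabulary size is len(CHARS) + OFFSET.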

def get_model_path():
    """Finds the first model file starting with Turing_ in the current directory."""
    for f in os.listdir("."):
        if f.startswith("Turing_") and f.endswith(".pt"):
            return f
    return None

MODEL_PATH = get_model_path()
if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    sys.exit(1)

# --- Helper Functions ---
def encode(text):
    """Map each printable character to its token ID; characters outside CHARS are dropped."""
    return [CHARS.index(c) + OFFSET for c in text if c in CHARS]


def decode(ids):
    """Map token IDs back to characters, skipping the reserved IDs below OFFSET."""
    return "".join([CHARS[i - OFFSET] for i in ids if i >= OFFSET])

def load_model(device):
    vocab_size = len(CHARS) + OFFSET
    model = ChatGCLM(vocab_size).to(device)
    if os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0:
        print(f"Loading model from: {MODEL_PATH}")
        ckpt = torch.load(MODEL_PATH, map_location=device)
        # The checkpoint may be a bare state_dict or a dict wrapping one under a common key.
        if isinstance(ckpt, dict):
            if 'model_state_dict' in ckpt:
                state_dict = ckpt['model_state_dict']
            elif 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            else:
                state_dict = ckpt
        else:
            state_dict = ckpt

        # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel, if present.
        def _strip_module_prefix(sd):
            keys = list(sd.keys())
            if any(k.startswith('module.') for k in keys):
                new_sd = OrderedDict()
                for k, v in sd.items():
                    new_key = k[len('module.'):] if k.startswith('module.') else k
                    new_sd[new_key] = v
                return new_sd
            return sd

        state_dict = _strip_module_prefix(state_dict)
        res = model.load_state_dict(state_dict, strict=False)
        missing = getattr(res, 'missing_keys', None)
        unexpected = getattr(res, 'unexpected_keys', None)
        if missing:
            print(f"Warning: missing keys when loading state_dict: {missing}")
        if unexpected:
            print(f"Warning: unexpected keys in state_dict: {unexpected}")
        model.eval()
        return model
    else:
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None

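# Note: load_model() returns None instead of raising when the checkpoint cannot
# be read, so callers are expected to check the result (as main() does below).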

@torch.no_grad()
def generate_stream(model, prompt, device, max_new_tokens=500, temperature=0.7, top_k=50):
    """
    Generates text from the model and streams it to stdout.
    Returns the full generated text.
    """
    model.eval()
    input_ids = encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    # The prompt is not re-printed; only newly sampled tokens are streamed.
    generated_ids = []
    for _ in range(max_new_tokens):
        # Crop the context if it exceeds the model's maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Top-k filtering: mask out every logit below the k-th largest.
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated_ids.append(idx)
        token_text = decode([idx])
        print(token_text, end="", flush=True)
        # If the model starts the next user turn ("<u>"), erase those three
        # characters from the terminal and stop generating.
        if len(generated_ids) >= 3 and decode(generated_ids[-3:]) == "<u>":
            print('\b\b\b   \b\b\b', end="", flush=True)
            break
    return decode(generated_ids)

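# Example of non-interactive use (a sketch assuming a Turing_*.pt checkpoint is
# present; these are the functions defined above, not a separate API):
#
#   model = load_model("cpu")
#   if model is not None:
#       reply = generate_stream(model, "<u> Hello <a>", device="cpu", max_new_tokens=100)
#
# main() below wraps the same calls in an interactive loop.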

def main():
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")
    model = load_model(device)
    if model is None:
        sys.exit(1)

    print("\n" + "="*50)
    print("Turing | Chat Interface")
    print(f"Model: {MODEL_PATH}")
    print("Type 'quit', 'exit', or 'q' to end the session.")
    print("="*50 + "\n")

    history = ""
    while True:
        try:
            # Get user input
            user_input = input("\n\033[1;36mYou:\033[0m ")  # Cyan color for "You:"
            if user_input.strip().lower() in ['quit', 'exit', 'q']:
                print("\nGoodbye!")
                break
            if not user_input.strip():
                continue

            print("\033[1;32mModel:\033[0m ", end="", flush=True)  # Green color for "Model:"

            # This is a raw completion model, so the prompt is the running
            # transcript: prior turns plus the new user turn, ending with "<a>"
            # so the model continues with the assistant's reply.
            current_turn = f"<u> {user_input} <a>"
            full_prompt = history + current_turn

            # Generate response
            response = generate_stream(model, full_prompt, device=device)

            # Update history, stripping a trailing "<u>" if it was generated as a stop token.
            cleaned_response = response
            if cleaned_response.endswith("<u>"):
                cleaned_response = cleaned_response[:-3]
            history += current_turn + cleaned_response

            print()  # Newline after generation
        except KeyboardInterrupt:
            print("\n\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {e}")


if __name__ == "__main__":
    main()