| import torch |
| import tiktoken |
| import os |
| from model import GPTConfig, GPT |
|
|
# --- Model setup: restore the trained GPT from its checkpoint ---
out_dir = 'out'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ckpt_path = os.path.join(out_dir, 'ckpt.pt')
ckpt = torch.load(ckpt_path, map_location=device)

model = GPT(GPTConfig(**ckpt['model_args']))

# Checkpoints saved from a torch.compile()d model carry an '_orig_mod.'
# prefix on every parameter name; strip it so the keys line up with a
# plain (uncompiled) GPT instance.
prefix = '_orig_mod.'
weights = {
    (key[len(prefix):] if key.startswith(prefix) else key): tensor
    for key, tensor in ckpt['model'].items()
}
model.load_state_dict(weights)
model.to(device)
model.eval()


# --- Tokenizer: GPT-2 BPE via tiktoken ---
enc = tiktoken.get_encoding("gpt2")
EOS_TOKEN_ID = 50256  # id of '<|endoftext|>' in the GPT-2 vocabulary
|
|
def ask_gpt(prompt, max_new_tokens=150, temperature=0.7, top_k=25):
    """Generate a completion for *prompt* with the loaded GPT model.

    Args:
        prompt: Text to condition on (assumed non-empty; an empty prompt
            yields a (1, 0) input tensor — TODO confirm model.generate
            handles that).
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Sampling temperature passed through to generate().
        top_k: Top-k sampling cutoff passed through to generate().

    Returns:
        The generated continuation only (the prompt is not echoed back),
        truncated at the first end-of-text token.
    """
    start_ids = enc.encode(prompt)
    # Shape (1, T): add a batch dimension for model.generate.
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    with torch.no_grad():
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)

    # Keep only the newly generated tokens (drop the echoed prompt).
    new_ids = y[0].tolist()[len(start_ids):]

    # Truncate at the first EOS token *before* decoding: cutting at the
    # token level is reliable, whereas splitting the decoded string can
    # miss the marker if BPE merged its characters into neighbors.
    if EOS_TOKEN_ID in new_ids:
        new_ids = new_ids[:new_ids.index(EOS_TOKEN_ID)]

    response = enc.decode(new_ids)
    # Defensive second pass: also strip anything after a literal marker.
    response = response.split('<|endoftext|>')[0]
    return response
|
|
# --- Interactive REPL: read a prompt, print the model's completion ---
print("--- Crest Completion Chat started ---")
while True:
    try:
        user_input = input("\nYour Prompt: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C: exit the chat loop cleanly instead of dying
        # with a traceback.
        print()
        break
    if user_input.lower() in ['exit', 'quit']:
        break

    completion = ask_gpt(user_input)

    # The model is a pure completion model, so echo the prompt followed
    # by its continuation as one line.
    print(f"\nCrest Completion: {user_input}{completion}")
    print("-" * 30)