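"""Generate text samples from a trained GPT-2 checkpoint using top-k sampling."""
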
import torch
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass

from model import GPT


class GPT2Inference:
    """Generates text sequences from a trained GPT-2 model."""

    def __init__(self, model, token_encoder, device):
        self.model = model
        self.token_encoder = token_encoder
        self.device = device
        # torch.autocast expects the device *type* ('cuda' or 'cpu'), not e.g. 'cuda:0'
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
        self.model.eval()

        # Encode the prompt once and replicate it so each row is an independent sample.
        tokens = self.token_encoder.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1)
        gen_tokens = tokens.to(self.device)

        # Dedicated generator so sampling is reproducible regardless of global RNG state.
        sample_rng = torch.Generator(device=self.device).manual_seed(42)

        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, _ = self.model(gen_tokens)
                # Only the logits at the last position are needed to predict the next token.
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                # Top-k sampling: keep the 50 most likely tokens and sample among them
                # (torch.multinomial treats the rows as unnormalised weights).
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)
                next_tok = torch.gather(topk_indices, -1, ix)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)

        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> sample {i}: {gen_text}")


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="Hello, I am a language model,")
    parser.add_argument('--num_seq', type=int, default=5)
    parser.add_argument('--max_tokens', type=int, default=50)
    args = parser.parse_args()
    return args


@dataclass
class GPTConfig:
    """GPT-2 small (124M) hyperparameters."""
    context_length: int = 1024
    vocab_size: int = 50257
    num_layers: int = 12
    embd_size: int = 768
    num_heads: int = 12


def inference(args=None):
    if args is None:
        args = parse_args()

    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = 'mps'
    print(f'using device: {device}')

    # weights_only=False because the checkpoint also stores the pickled model config;
    # map_location lets a GPU-saved checkpoint load on a CPU/MPS-only machine.
    model_path = './logs/model_95364.pt'
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    print(f"loaded model from: {model_path}")

    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)

    token_encoder = tiktoken.get_encoding('gpt2')
    generator = GPT2Inference(model, token_encoder, device)
    generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens)
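
# Example invocation (the file name below is a placeholder for this script's actual name):
#   python inference.py --prompt "Hello, I am a language model," --num_seq 3 --max_tokens 64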
if __name__ == '__main__':
    inference()