import torch
from src.utils import encode, decode


def generate(prompt, model, block_size, max_new_tokens, device):
    # Encode the prompt into token ids and move them to the target device.
    X = torch.tensor(encode(prompt), dtype=torch.long, device=device)
    # Truncate the prompt to the model's context window and add a batch dimension.
    X = X[:block_size].unsqueeze(0)
    # Autoregressively sample max_new_tokens new tokens, then decode the full
    # sequence (prompt + completion) back into text.
    results = decode(model.generate(X, max_new_tokens=max_new_tokens)[0].tolist())
    return results
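

# Minimal usage sketch (not part of the original file): it assumes a trained
# model object exposing a nanoGPT-style `generate(idx, max_new_tokens)` method,
# saved as a whole module at a hypothetical path `checkpoints/model.pt`; the
# path and hyperparameter values below are example assumptions only.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.load("checkpoints/model.pt", map_location=device)  # hypothetical checkpoint
    model.eval()
    text = generate(
        "Once upon a time",
        model,
        block_size=256,        # example context length
        max_new_tokens=100,    # example number of sampled tokens
        device=device,
    )
    print(text)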