---
license: mit
---

# nGPT-enwiki8

small [nGPT](https://arxiv.org/abs/2410.01131) model trained on enwiki8 for testing purposes with [nGPT-pytorch](https://github.com/lucidrains/nGPT-pytorch)

## inference

1. download a weights file from this repo

```sh
wget -O ./nGPT_best.pt "https://huggingface.co/pszemraj/nGPT-enwiki8/resolve/main/nGPT_best.pt"
```

2. install dependencies

```sh
# assuming you already have torch
pip install fire nGPT-pytorch
```

3. run inference with the script below: `python inference.py ./nGPT_best.pt "Once upon a time"`

```py
# inference.py
import json
import sys
from pathlib import Path

import fire
import torch
from nGPT_pytorch import nGPT


def exists(v):
    """Return True when `v` is not None."""
    return v is not None


def decode_token(token):
    """Map one token id to a character; ids below 32 are rendered as a space."""
    return str(chr(max(32, int(token))))


def decode_tokens(tokens):
    """Decode an iterable of token ids into a string, one character per id."""
    return "".join(map(decode_token, tokens))


def log(t, eps=1e-20):
    """Elementwise log of `t`, clamped below at `eps` so log(0) never occurs."""
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t):
    """Return -log(-log(u)) for u ~ U[0, 1), shaped like `t`."""
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
    """Sample indices along `dim` by adding Gumbel noise to `t / temperature`.

    The temperature is floored at 1e-10 so a zero temperature does not divide
    by zero.
    """
    scaled = t / max(temperature, 1e-10)
    return (scaled + gumbel_noise(t)).argmax(dim=dim, keepdim=keepdim)


def min_p_filter(logits, min_p=0.1):
    """Set to -inf every logit whose probability is below `min_p * max_prob`."""
    probs = logits.softmax(dim=-1)
    limit = min_p * probs.amax(dim=-1, keepdim=True)
    return logits.masked_fill(probs < limit, float("-inf"))


def base_decoding(
    net,
    prompt: torch.Tensor,
    seq_len: int,
    temperature=1.5,
    min_p=1e-1,
    filter_thres=0.9,
):
    """Autoregressively sample tokens until the sequence is `seq_len` long.

    Args:
        net: callable returning logits for a batch of token ids; the logits
            of the last position are used for sampling.
        prompt: token ids with the sequence on the last dimension.
        seq_len: TOTAL target length (prompt + generated). When the prompt is
            already at least this long, nothing is generated.
        temperature: sampling temperature passed to `gumbel_sample`.
        min_p: threshold passed to `min_p_filter`.
        filter_thres: unused; kept so existing callers keep working.

    Returns:
        Only the newly generated token ids (the prompt is stripped).
    """
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    for _ in range(sample_num_times):
        # The full sequence so far is re-fed on every step.
        logits = net(out)
        logits = logits[:, -1]

        logits = min_p_filter(logits, min_p=min_p)
        sample = gumbel_sample(logits, temperature=temperature, dim=-1)

        out = torch.cat((out, sample), dim=-1)

    return out[..., prompt_seq_len:]


def main(
    checkpoint_path: str,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    min_p: float = 0.1,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Generate text using a trained nGPT model.

    Args:
        checkpoint_path: path to a checkpoint containing "model_state_dict"
            and optionally "config".
        prompt: text to continue; must not be empty.
        max_new_tokens: number of tokens to generate after the prompt.
        temperature: sampling temperature.
        min_p: min-p filtering threshold.
        device: torch device string the model and prompt are placed on.

    Returns:
        The generated continuation (without the prompt).
    """
    # Load checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        sys.exit(1)

    # Fire infers argument types from their values, so a prompt such as "123"
    # arrives as an int; normalise it back to text.
    prompt = str(prompt)
    # The model has a 256-entry vocabulary, so the prompt is fed as UTF-8 bytes
    # (assumed to match the byte-level enwik8 training data). ord() would
    # produce out-of-range ids for characters above U+00FF.
    prompt_bytes = prompt.encode("utf-8")
    if not prompt_bytes:
        print("Error: Prompt must not be empty")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Get config from checkpoint or file
    config = checkpoint.get("config", {})
    config_path = checkpoint_path.parent / "config.json"
    if not config and config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

    use_parametrize = config.get("use_parametrize", True)

    # Initialize model
    model = nGPT(
        num_tokens=256,
        dim=512,
        depth=8,
        tied_embedding=True,
        add_value_residual=True,
        attn_norm_qk=False,
        manual_norm_weights=not use_parametrize,
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("\nModel loaded successfully. Generating with:")
    print(f"  Temperature: {temperature}")
    print(f"  Min-p: {min_p}")
    print(f"  Max new tokens: {max_new_tokens}")

    # Convert prompt to a (1, prompt_len) tensor of byte ids
    prompt_tensor = torch.tensor(
        list(prompt_bytes), dtype=torch.long, device=device
    ).unsqueeze(0)

    # Generate. base_decoding takes the TOTAL sequence length, so add the
    # prompt length to get exactly `max_new_tokens` new tokens.
    with torch.no_grad():
        sampled = base_decoding(
            model,
            prompt_tensor,
            seq_len=prompt_tensor.shape[-1] + max_new_tokens,
            temperature=temperature,
            min_p=min_p,
        )

    generated = decode_tokens(sampled[0].tolist())

    print("\nGenerated text:")
    print("-" * 80)
    print(prompt + generated)
    print("-" * 80)

    return generated


if __name__ == "__main__":
    fire.Fire(main)
```