Pipeline tag: Text Generation
Libraries: Transformers, Safetensors
Language: English
Model type: model_binary_affine_code_n_layer_32
Tags: feature-extraction, causal-lm, transformer, decoder-only, table-free-input, binary-token-codes, affine-recoding, research, custom_code
Instructions for using E6E831728/affine-recoded-minimal-code-table-free with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use E6E831728/affine-recoded-minimal-code-table-free with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="E6E831728/affine-recoded-minimal-code-table-free",
    trust_remote_code=True,
)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "E6E831728/affine-recoded-minimal-code-table-free",
    trust_remote_code=True,
    dtype="auto",
)
```

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use E6E831728/affine-recoded-minimal-code-table-free with vLLM:
Install vLLM from pip and serve the model:

```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "E6E831728/affine-recoded-minimal-code-table-free"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "E6E831728/affine-recoded-minimal-code-table-free",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```

Use Docker:
```bash
docker run --runtime nvidia --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  -p 8000:8000 \
  --ipc=host \
  vllm/vllm-openai:latest \
  --model "E6E831728/affine-recoded-minimal-code-table-free"
```
- SGLang
How to use E6E831728/affine-recoded-minimal-code-table-free with SGLang:
Install SGLang from pip and serve the model:

```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "E6E831728/affine-recoded-minimal-code-table-free" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "E6E831728/affine-recoded-minimal-code-table-free",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```

Use Docker images:
```bash
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "E6E831728/affine-recoded-minimal-code-table-free" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "E6E831728/affine-recoded-minimal-code-table-free",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```

- Docker Model Runner
How to use E6E831728/affine-recoded-minimal-code-table-free with Docker Model Runner:
```bash
docker model run hf.co/E6E831728/affine-recoded-minimal-code-table-free
```
Custom modeling code shipped with the repository (loaded via `trust_remote_code=True`):

```python
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


def gf2_rank(M: torch.Tensor) -> int:
    """
    Rank over GF(2).
    M: [n, m] tensor with 0/1 entries.
    """
    M = M.clone().to(torch.uint8) & 1
    n_rows, n_cols = M.shape
    rank = 0
    for col in range(n_cols):
        pivot = None
        for r in range(rank, n_rows):
            if M[r, col].item():
                pivot = r
                break
        if pivot is None:
            continue
        if pivot != rank:
            tmp = M[rank].clone()
            M[rank] = M[pivot]
            M[pivot] = tmp
        for r in range(n_rows):
            if r != rank and M[r, col].item():
                M[r] ^= M[rank]
        rank += 1
        if rank == n_rows:
            break
    return rank
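
# Illustrative check (not part of the module):
#   gf2_rank(torch.tensor([[1, 0], [1, 1]], dtype=torch.uint8)) == 2
#   gf2_rank(torch.tensor([[1, 1], [1, 1]], dtype=torch.uint8)) == 1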

def gf2_inverse(A: torch.Tensor) -> torch.Tensor:
    """
    Inverse over GF(2).
    A: [n, n] with 0/1 entries, invertible over GF(2).
    """
    A = A.clone().to(torch.uint8) & 1
    n = A.shape[0]
    I = torch.eye(n, dtype=torch.uint8, device=A.device)
    # Gauss-Jordan elimination on the augmented matrix [A | I].
    aug = torch.cat([A, I], dim=1)
    row = 0
    for col in range(n):
        pivot = None
        for r in range(row, n):
            if aug[r, col].item():
                pivot = r
                break
        if pivot is None:
            raise ValueError("Matrix is not invertible over GF(2).")
        if pivot != row:
            tmp = aug[row].clone()
            aug[row] = aug[pivot]
            aug[pivot] = tmp
        for r in range(n):
            if r != row and aug[r, col].item():
                aug[r] ^= aug[row]
        row += 1
    left = aug[:, :n]
    if not torch.equal(left, I):
        raise RuntimeError("GF(2) inverse construction failed.")
    return aug[:, n:]
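
# Illustrative round-trip check (assumed usage; A must be GF(2)-invertible):
#   A_inv = gf2_inverse(A)
#   assert torch.equal((A.long() @ A_inv.long()) % 2,
#                      torch.eye(A.shape[0], dtype=torch.long))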

def make_random_invertible_binary_matrix(
    code_bits: int,
    seed: int = 0,
    min_row_weight: int = 4,
    min_col_weight: int = 4,
    device: str = "cpu",
):
    """
    Random dense-ish invertible matrix A in GL(code_bits, 2)
    and random shift b in {0,1}^{code_bits}.
    min_row_weight / min_col_weight are optional constraints
    to avoid trivial near-permutation matrices.
    """
    g = torch.Generator(device=device)
    g.manual_seed(seed)
    # Rejection-sample until A is invertible and dense enough.
    while True:
        A = torch.randint(
            0, 2, (code_bits, code_bits),
            generator=g, dtype=torch.uint8, device=device,
        )
        if gf2_rank(A) != code_bits:
            continue
        if min_row_weight is not None and not torch.all(A.sum(dim=1) >= min_row_weight):
            continue
        if min_col_weight is not None and not torch.all(A.sum(dim=0) >= min_col_weight):
            continue
        b = torch.randint(
            0, 2, (code_bits,),
            generator=g, dtype=torch.uint8, device=device,
        )
        return A, b
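
# Illustrative usage (assumed; mirrors the config defaults below):
#   A, b = make_random_invertible_binary_matrix(code_bits=16, seed=12345)
#   assert A.shape == (16, 16) and b.shape == (16,)
#   assert gf2_rank(A) == 16  # A is in GL(16, 2)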

class BVVConfig(PretrainedConfig):
    model_type = "model_binary_affine_code_n_layer_32"

    def __init__(
        self,
        vocab_size=65536,
        code_bits=None,
        n_embed=16,  # backward-compatible alias for code_bits
        d_model=1024,
        n_head=32,
        n_layer=32,
        block_size=1024,
        dropout=0.00,
        layer_norm_eps=1e-5,
        initializer_range=0.02,
        pad_token_id=57344,
        pad_id=57344,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        use_cache=False,
        # affine code params
        code_seed=12345,
        code_matrix=None,   # optional explicit A
        code_shift=None,    # optional explicit b
        min_row_weight=4,
        min_col_weight=4,
        zero_pad_code=True,
        **kwargs,
    ):
        if pad_token_id is None:
            pad_token_id = 57344 if pad_id is None else pad_id
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            use_cache=use_cache,
            **kwargs,
        )
        if code_bits is None:
            code_bits = n_embed
        if vocab_size != (1 << code_bits):
            raise ValueError(
                f"The exact minimal-code experiment requires "
                f"vocab_size == 2**code_bits, got vocab_size={vocab_size}, code_bits={code_bits}."
            )
        if d_model % code_bits != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by code_bits ({code_bits})")
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        if (d_model // n_head) % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings")
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size
        self.code_bits = code_bits
        self.n_embed = code_bits  # alias for old scripts
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.scale = d_model // code_bits
        # code params
        self.code_seed = code_seed
        self.code_matrix = code_matrix
        self.code_shift = code_shift
        self.min_row_weight = min_row_weight
        self.min_col_weight = min_col_weight
        self.zero_pad_code = zero_pad_code
        # backward compatibility
        self.pad_id = pad_token_id
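
# Illustrative defaults (derived from the signature above):
#   cfg = BVVConfig()   # vocab_size=65536 == 2**16, so code_bits=16 passes the check
#   cfg.scale == 64     # d_model // code_bits == 1024 // 16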

class BinaryAffineCodeInput(nn.Module):
    """
    Table-free token input:
        token id -> 16-bit code -> affine GF(2) mixing -> tiled lift to d_model
    No trainable parameters.
    """

    def __init__(self, config: BVVConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.code_bits = config.code_bits
        self.d_model = config.d_model
        self.scale = config.scale
        self.pad_token_id = config.pad_token_id
        self.zero_pad_code = config.zero_pad_code
        self.register_buffer(
            "bit_positions",
            torch.arange(self.code_bits, dtype=torch.long),
            persistent=False,
        )
        if config.code_matrix is None:
            A, _ = make_random_invertible_binary_matrix(
                code_bits=self.code_bits,
                seed=config.code_seed,
                min_row_weight=config.min_row_weight,
                min_col_weight=config.min_col_weight,
                device="cpu",
            )
        else:
            A = torch.tensor(config.code_matrix, dtype=torch.uint8)
            if A.shape != (self.code_bits, self.code_bits):
                raise ValueError(
                    f"code_matrix must have shape {(self.code_bits, self.code_bits)}, got {tuple(A.shape)}"
                )
        if gf2_rank(A) != self.code_bits:
            raise ValueError("Provided/generated code_matrix is not invertible over GF(2).")
        # --- choose b so that pad_token_id maps to 0^K ---
        if config.code_shift is None:
            pad = torch.tensor(config.pad_token_id, dtype=torch.long, device=A.device)
            bit_positions = torch.arange(self.code_bits, dtype=torch.long, device=A.device)
            # LSB-first, same convention as ids_to_bits()
            pad_bits = ((pad >> bit_positions) & 1).to(torch.float32)  # [K]
            # because forward uses: codes = bits @ A.T xor b
            b = torch.remainder(pad_bits @ A.to(torch.float32).T, 2.0).to(torch.uint8)
        else:
            b = torch.tensor(config.code_shift, dtype=torch.uint8)
            if b.shape != (self.code_bits,):
                raise ValueError(
                    f"code_shift must have shape {(self.code_bits,)}, got {tuple(b.shape)}"
                )
        self.register_buffer("A_gf2", (A & 1).contiguous(), persistent=True)
        self.register_buffer("b_gf2", (b & 1).contiguous(), persistent=True)

    def ids_to_bits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        input_ids: [B, T] int64
        returns: [B, T, K] float32 in {0,1}
        """
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()
        if input_ids.min().item() < 0 or input_ids.max().item() >= self.vocab_size:
            raise ValueError(
                f"input_ids out of range: min={input_ids.min().item()}, "
                f"max={input_ids.max().item()}, vocab_size={self.vocab_size}"
            )
        bits = ((input_ids.unsqueeze(-1) >> self.bit_positions) & 1).to(torch.float32)
        return bits
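
    # Bit convention is LSB-first: bit k of a token id is (id >> k) & 1.
    # Illustrative example with code_bits=16:
    #   id 6 == 0b110  ->  bits [0, 1, 1, 0, 0, ..., 0]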

    def mix_bits_affine(self, bits: torch.Tensor) -> torch.Tensor:
        """
        bits: [B, T, K] float32 with entries 0/1
        returns c = bits @ A^T + b mod 2
        """
        A = self.A_gf2.to(device=bits.device, dtype=torch.float32)
        b = self.b_gf2.to(device=bits.device, dtype=torch.float32)
        mixed = torch.remainder(torch.matmul(bits, A.T) + b, 2.0)
        return mixed

    def encode_bits(self, input_ids: torch.Tensor) -> torch.Tensor:
        bits = self.ids_to_bits(input_ids)
        codes = self.mix_bits_affine(bits)
        return codes

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        returns x: [B, T, d_model]
        """
        codes = self.encode_bits(input_ids)  # [B, T, K]
        x = codes.repeat(1, 1, self.scale)   # [B, T, d_model]
        # Optional: keep pad positions exactly zero in the continuous input tensor
        if self.zero_pad_code and self.pad_token_id is not None:
            pad_mask = input_ids.eq(self.pad_token_id).unsqueeze(-1)
            x = x.masked_fill(pad_mask, 0.0)
        return x
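
# The affine code is invertible over GF(2), so token ids are recoverable
# from the codes (illustrative decode path, not used by the model itself):
#   c = bits @ A.T xor b  =>  bits = (c xor b) @ gf2_inverse(A).T  (mod 2)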

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
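
# Shape convention, for reference (as used in MultiHeadSelfAttention below):
#   xq, xk: [B, T, n_head, head_dim];  freqs_cis: [T, head_dim // 2], complex64.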

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.0):
        super().__init__()
        assert d_model % n_head == 0
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        assert self.head_dim % 2 == 0, "head_dim must be even for rotary embeddings"
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis, mask=None):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim)
        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
        q = q.transpose(1, 2)  # (B, n_head, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn_scores = attn_scores + mask
        attn_probs = F.softmax(attn_scores.float(), dim=-1).type_as(q)
        attn_probs = self.dropout(attn_probs)
        out = torch.matmul(attn_probs, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)


class TransformerMLP(nn.Module):
    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.0, layer_norm_eps=1e-5):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, n_head, dropout=dropout)
        self.mlp = TransformerMLP(d_model, dropout=dropout)
        self.input_layernorm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, x, freqs_cis, mask=None):
        x = x + self.self_attn(self.input_layernorm(x), freqs_cis, mask)
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x
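
# Pre-norm residual layout: LayerNorm is applied before each sublayer and the
# sublayer output is added back to the residual stream.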

class BVVForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = BVVConfig
    main_input_name = "input_ids"

    def __init__(self, config: BVVConfig):
        super().__init__(config)
        # no nn.Embedding here
        self.input_code = BinaryAffineCodeInput(config)
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                config.d_model,
                n_head=config.n_head,
                dropout=config.dropout,
                layer_norm_eps=config.layer_norm_eps,
            )
            for _ in range(config.n_layer)
        ])
        self.final_layernorm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.d_model // config.n_head,
                config.block_size,
            ),
            persistent=False,
        )
        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        # there is no embedding table
        return None

    def set_input_embeddings(self, value):
        raise NotImplementedError("This model uses algorithmic binary token codes, not nn.Embedding.")

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        if input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]
        if attention_mask is not None:
            attention_mask = attention_mask[:, -self.config.block_size:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        targets=None,
        return_dict=None,
        output_logits=True,
        **kwargs,
    ):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        if labels is not None and targets is not None:
            raise ValueError("Use either labels or targets, not both.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        B, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(f"Sequence length {T} exceeds block_size {self.config.block_size}")
        # ---- table-free input coding ----
        x = self.input_code(input_ids)
        # cast to model dtype if needed
        x = x.to(dtype=self.final_layernorm.weight.dtype)
        freqs_cis = self.freqs_cis[:T]
        if not torch.is_complex(freqs_cis):
            freqs_cis = torch.view_as_complex(freqs_cis.contiguous())
        freqs_cis = freqs_cis.to(x.device)
        mask = None
        mask_value = torch.finfo(x.dtype).min
        if T > 1:
            mask = torch.full((1, 1, T, T), mask_value, device=x.device, dtype=x.dtype)
            mask = torch.triu(mask, diagonal=1)
        if attention_mask is not None:
            if attention_mask.shape != (B, T):
                raise ValueError(f"attention_mask must have shape {(B, T)}, got {tuple(attention_mask.shape)}")
            pad_mask = torch.zeros((B, 1, 1, T), device=x.device, dtype=x.dtype)
            pad_mask = pad_mask.masked_fill(attention_mask[:, None, None, :].eq(0), mask_value)
            mask = pad_mask if mask is None else mask + pad_mask
        for layer in self.transformer_layers:
            x = layer(x, freqs_cis, mask)
        x = self.final_layernorm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Standard causal LM loss: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            if attention_mask is not None:
                shift_labels = shift_labels.masked_fill(attention_mask[:, 1:].eq(0), -100)
            if self.config.pad_token_id is not None:
                shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                shift_logits.float().view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        elif targets is not None:
            # Legacy path: targets are taken as-is (assumed pre-shifted).
            legacy_targets = targets.contiguous()
            if attention_mask is not None:
                legacy_targets = legacy_targets.masked_fill(attention_mask.eq(0), -100)
            if self.config.pad_token_id is not None:
                legacy_targets = legacy_targets.masked_fill(legacy_targets == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                logits.float().view(-1, logits.size(-1)),
                legacy_targets.view(-1),
                ignore_index=-100,
            )
        if not return_dict:
            if output_logits:
                output = (logits,)
                return ((loss,) + output) if loss is not None else output
            return (loss,) if loss is not None else tuple()
        if output_logits:
            return CausalLMOutput(loss=loss, logits=logits)
        return CausalLMOutput(loss=loss, logits=None)

    def generate(self, input_ids, max_new_tokens, attention_mask=None, do_sample=False):
        was_training = self.training
        self.eval()
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        with torch.no_grad():
            for _ in range(max_new_tokens):
                input_ids_cond = input_ids[:, -self.config.block_size:]
                attention_mask_cond = attention_mask[:, -self.config.block_size:]
                outputs = self(
                    input_ids=input_ids_cond,
                    attention_mask=attention_mask_cond,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]
                if do_sample:
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                attention_mask = torch.cat(
                    [attention_mask, torch.ones_like(next_token, dtype=attention_mask.dtype)],
                    dim=1,
                )
        if was_training:
            self.train()
        return input_ids
```
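
The classes above can be exercised end to end with a reduced configuration. This is an illustrative smoke test under assumed hyperparameters (a small `d_model`/`n_layer` chosen for speed, not the shipped checkpoint):

```python
import torch

# vocab_size must equal 2**code_bits; d_model must be divisible by
# both code_bits and n_head, and head_dim must be even.
config = BVVConfig(
    vocab_size=65536,  # 2**16
    code_bits=16,
    d_model=128,       # assumed small values for a quick test
    n_head=4,
    n_layer=2,
    block_size=64,
)
model = BVVForCausalLM(config)

ids = torch.randint(0, config.vocab_size, (1, 8))
out = model.generate(ids, max_new_tokens=4)
print(out.shape)  # torch.Size([1, 12])
```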