# mamba/code/inference.py
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 200,
             sample: bool = True,
             top_k: int = 40):
    model.eval()
    # Assumes the model lives on the GPU; move the prompt tokens there too.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            # Run the full sequence through the model and keep only the
            # logits for the last position.
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape

        if top_k is not None:
            # Top-k filtering: zero out everything below the k-th largest
            # probability, then renormalize the remaining mass to sum to 1.
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        # Append the chosen token and feed the extended sequence back in.
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Decode the first (and only) sequence in the batch.
    output_completions = tokenizer.decode(input_ids[0].tolist())
    return output_completions
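

# A minimal usage sketch, not part of the original file. It assumes a Mamba
# implementation that exposes `Mamba.from_pretrained` and pairs with the
# EleutherAI/gpt-neox-20b tokenizer (as in the state-spaces Mamba checkpoints);
# the import path and checkpoint name below are assumptions to adjust.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    from model import Mamba  # hypothetical import path for the Mamba class

    model = Mamba.from_pretrained("state-spaces/mamba-370m").to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

    completion = generate(model, tokenizer, prompt="Mamba is a sequence model that")
    print(completion)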