Create inference.py
code/inference.py
ADDED
@@ -0,0 +1,37 @@
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 200,
             sample: bool = True,
             top_k: int = 40):
    model.eval()

    # Encode the prompt and move it to the GPU (the model is assumed to live on CUDA).
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")

    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            # Run the whole sequence through the model and keep only the logits
            # of the last position, which predict the next token.
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape

        if top_k is not None:
            # Top-k filtering: zero out every probability below the k-th largest,
            # then renormalize so the remaining probabilities sum to 1.
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if sample:
            # Draw the next token from the (filtered) distribution.
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy decoding: take the most likely token.
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        # Append the chosen token; the extended sequence is fed back in on the next step.
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Decode the first (and only) sequence in the batch back into text.
    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions
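
For context, a minimal sketch of how generate might be called. The model class, checkpoint name, and tokenizer name below are assumptions for illustration only (this commit adds just the function itself); the GPT-NeoX tokenizer is the one commonly paired with Mamba checkpoints.

# Usage sketch -- the `Mamba` class, the checkpoint name, and the tokenizer name
# are assumptions; substitute whatever this repository actually provides.
from transformers import AutoTokenizer

from model import Mamba          # hypothetical: the repo's own Mamba implementation
from inference import generate

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")   # assumed tokenizer
model = Mamba.from_pretrained("state-spaces/mamba-370m").to("cuda")    # assumed loader

print(generate(model, tokenizer, prompt="The meaning of life is", n_tokens_to_gen=50))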