pt-sk committed on
Commit
3f0e944
·
verified ·
1 Parent(s): 3c1b667

Create inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +37 -0
code/inference.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 200,
             sample: bool = True,
             top_k: int = 40):
    """Autoregressively generate text from ``prompt`` with ``model``.

    Args:
        model: A callable language model mapping token ids of shape
            (batch, seq) to logits of shape (batch, seq, vocab).
        tokenizer: HF-style tokenizer; called with
            ``return_tensors='pt'`` and used for ``decode``.
        prompt: Seed text to condition generation on.
        n_tokens_to_gen: Number of new tokens to append.
        sample: If True, sample from the (optionally top-k filtered)
            distribution; otherwise take the argmax token.
        top_k: Keep only the k most probable tokens before sampling;
            ``None`` disables the filter.

    Returns:
        The decoded string for the first sequence in the batch
        (prompt plus generated tokens).
    """
    model.eval()

    # Run on whatever device the model lives on.
    # BUG FIX: the original hard-coded .to("cuda"), which crashes on
    # CPU-only hosts; inferring the device is backward-compatible.
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            # BUG FIX: the original called the undefined global
            # `mamba_model` instead of the `model` parameter, raising
            # NameError whenever that global was absent.
            next_token_logits = model(input_ids)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # Zero out everything below the k-th largest probability,
            # then renormalize so the row sums to 1 again.
            values, _ = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            # BUG FIX: use torch's dim/keepdim kwargs instead of the
            # NumPy aliases axis/keepdims.
            probs = probs / probs.sum(dim=1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        # Greedy/sampled token is appended; the whole (growing) context
        # is re-fed on the next step (no KV caching here).
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Equivalent to the original list-comprehension-then-[0]: decode
    # only the first batch element.
    return tokenizer.decode(input_ids[0].tolist())