Crystalcareai
commited on
Create Inference-improved.py
Browse files- Inference-improved.py +108 -0
Inference-improved.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
3 |
+
|
4 |
+
model_path = "cognitivecompuations/Quiet-STaR-Base"
|
5 |
+
|
6 |
+
n_ahead = 8
|
7 |
+
n_ahead_talk = 4
|
8 |
+
merged_talk_heads = True
|
9 |
+
|
10 |
+
# Load the model
|
11 |
+
model = AutoModelForCausalLM.from_pretrained(
|
12 |
+
model_path,
|
13 |
+
max_thoughts=n_ahead + n_ahead_talk + 1,
|
14 |
+
merged_talk_heads=merged_talk_heads,
|
15 |
+
merged_lm_and_talk_heads=False,
|
16 |
+
merged_lm_and_think_heads=True,
|
17 |
+
use_concat_talk_head=True,
|
18 |
+
use_shallow_think=True,
|
19 |
+
use_shallow_talk=False,
|
20 |
+
use_complex_think_head=False,
|
21 |
+
use_complex_talk_head=True,
|
22 |
+
use_weighted_talk_head=True,
|
23 |
+
trust_remote_code=True,
|
24 |
+
torch_dtype=torch.bfloat16,
|
25 |
+
device_map="auto",
|
26 |
+
)
|
27 |
+
|
28 |
+
# Load the tokenizer and assign it to the model instance for compatibility
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
30 |
+
model.tokenizer = tokenizer
|
31 |
+
|
32 |
+
model.use_end_thought_token = True
|
33 |
+
model.use_start_thought_token = True
|
34 |
+
model.wandb_enabled = True
|
35 |
+
model.n_ahead = n_ahead
|
36 |
+
model.n_passes = 2
|
37 |
+
model.eval_mode = True
|
38 |
+
model.first_run = False
|
39 |
+
model.kill_after = 100
|
40 |
+
model.rm_initialized = True
|
41 |
+
model.original_mode = False
|
42 |
+
|
43 |
+
def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
|
44 |
+
with torch.no_grad():
|
45 |
+
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
|
46 |
+
for cur_token_idx in range(max_new_tokens):
|
47 |
+
# Sample the next token
|
48 |
+
new_ids = model(
|
49 |
+
input_ids[~finished_generating],
|
50 |
+
attention_mask=attention_mask[~finished_generating]
|
51 |
+
)['logits']
|
52 |
+
# Mask out the start and end thought tokens so we don't accidentally sample them
|
53 |
+
new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
|
54 |
+
for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
|
55 |
+
# Find the index of the last token that is not padding
|
56 |
+
base_answer_ids = input_ids[answer_idx]
|
57 |
+
new_answer_ids = new_ids[list_idx]
|
58 |
+
last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
|
59 |
+
|
60 |
+
new_ids_sampled = torch.multinomial(
|
61 |
+
torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
|
62 |
+
# Assign the new id to the last token
|
63 |
+
if last_token_idx + 1 >= len(base_answer_ids):
|
64 |
+
# Add padding everywhere
|
65 |
+
new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
|
66 |
+
device=input_ids.device)
|
67 |
+
input_ids = torch.cat([input_ids, new_padding], dim=-1)
|
68 |
+
attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
|
69 |
+
attention_mask[answer_idx, last_token_idx + 1] = 1
|
70 |
+
input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
|
71 |
+
if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
|
72 |
+
finished_generating[answer_idx] = 1
|
73 |
+
# Check if the end token is generated
|
74 |
+
if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
|
75 |
+
finished_generating[answer_idx] = 1
|
76 |
+
if finished_generating.all():
|
77 |
+
break
|
78 |
+
streamer.put(new_ids_sampled)
|
79 |
+
return input_ids, attention_mask
|
80 |
+
|
81 |
+
prompt = " How would a typical person answer each of the following questions about causation? Frank T., had an ongoing dispute with his neighbor over a stretch of land and one day decided to shoot his neighbor in the body. Frank T. had no experience with guns, his hand slipped on the barrel of the gun, and the shot went wild. Nonetheless, the bullet bounced off a large boulder several feet away and hit the neighbor's body, causing significant injury. Did Frank T. intentionally shoot his neighbor in the body?"
|
82 |
+
|
83 |
+
input_ids = tokenizer(
|
84 |
+
prompt=prompt,
|
85 |
+
return_tensors='pt'
|
86 |
+
).input_ids.cuda()
|
87 |
+
|
88 |
+
# Convert prompt to tokens
|
89 |
+
tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors='pt').input_ids.to(model.device)
|
90 |
+
|
91 |
+
# Generate an attention mask
|
92 |
+
attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)
|
93 |
+
|
94 |
+
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
|
95 |
+
|
96 |
+
output_ids, _ = custom_generate(
|
97 |
+
model,
|
98 |
+
input_ids=tokens,
|
99 |
+
attention_mask=attention_mask,
|
100 |
+
max_new_tokens=512,
|
101 |
+
streamer=streamer,
|
102 |
+
temperature=0.9,
|
103 |
+
)
|
104 |
+
|
105 |
+
generated_text = ""
|
106 |
+
|
107 |
+
print()
|
108 |
+
|