# File size: 4,841 Bytes
# b0e4fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

def compute_memory_used_pct(device):
    """Return the peak allocated CUDA memory on ``device`` as a percentage of its total memory.

    Both quantities are queried from ``torch.cuda`` and converted to GiB
    before the ratio is taken.
    """
    bytes_per_gib = 1024**3
    peak_allocated_gib = torch.cuda.max_memory_allocated(device) / bytes_per_gib
    device_total_gib = torch.cuda.get_device_properties(device).total_memory / bytes_per_gib
    return peak_allocated_gib / device_total_gib * 100

# Directory both the model and the tokenizer are loaded from.
model_path = "./out"

n_ahead, n_ahead_talk = 8, 4
merged_talk_heads = True

# Load the model. The head/thought keyword arguments are passed straight to
# from_pretrained (presumably consumed by the checkpoint's remote code, which
# trust_remote_code=True allows to run — confirm against the checkpoint).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Load the tokenizer; custom_generate reads it back through model.tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Attributes set on the model instance, applied in this exact order.
for _attr, _value in (
    ("tokenizer", tokenizer),
    ("use_end_thought_token", True),
    ("use_start_thought_token", True),
    ("wandb_enabled", True),
    ("n_ahead", n_ahead),
    ("n_passes", 2),
    ("eval_mode", True),
    ("first_run", False),
    ("kill_after", 100),
    ("rm_initialized", True),
    ("original_mode", False),
):
    setattr(model, _attr, _value)
# Drop the loop helpers so the module namespace matches the plain-assignment form.
del _attr, _value

# Custom generate function
def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    """Sample up to ``max_new_tokens`` tokens per sequence, one forward pass per step.

    Args:
        model: Callable as ``model(input_ids, attention_mask=...)`` returning a
            mapping whose ``'logits'`` entry is indexed ``[row, position, vocab]``.
            Must expose ``model.tokenizer`` with ``vocab_size``, ``pad_token_id``,
            ``eos_token_id``, ``bos_token_id`` and ``convert_tokens_to_ids``.
        input_ids: 2-D tensor of token ids, one row per sequence. Positions
            holding ``pad_token_id`` are treated as padding.
        attention_mask: Tensor aligned with ``input_ids``; the slot of every
            newly written token is set to 1.
        max_new_tokens: Maximum number of sampling steps.
        streamer: Object with ``put(tensor)`` and ``end()`` (for example a
            ``transformers.TextStreamer``), or ``None`` to disable streaming.
            Only the first sequence of the batch is streamed, because a text
            streamer can follow a single sequence only.
        **kwargs: Only ``temperature`` (default ``1.0``) is read.

    Returns:
        Tuple ``(input_ids, attention_mask)`` including the sampled tokens.
        Rows that stop early keep ``pad_token_id`` / 0 in the unused columns.

    Note:
        When a row still has trailing padding, the new token is written into
        the tensors passed in, so the caller's tensors may be modified in place.
    """
    tokenizer = model.tokenizer
    pad_token_id = tokenizer.pad_token_id
    # Loop invariants: look them up once instead of once per sampled token.
    temperature = kwargs.get("temperature", 1.0)
    # Sampling any of these ids ends the sequence it was sampled for.
    stop_token_ids = {
        token_id
        for token_id in (
            tokenizer.eos_token_id,
            tokenizer.bos_token_id,
            pad_token_id,
            tokenizer.convert_tokens_to_ids("<|/assistant|>"),
        )
        if token_id is not None
    }

    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            # Checked before the forward pass so an empty batch never reaches the model.
            if finished_generating.all():
                break
            active_rows = ~finished_generating
            logits = model(
                input_ids[active_rows],
                attention_mask=attention_mask[active_rows],
            )['logits']
            # Mask out the start and end thought tokens so we don't accidentally sample them
            logits[:, :, tokenizer.vocab_size:] = -float("inf")

            token_to_stream = None
            for list_idx, answer_idx in enumerate(active_rows.nonzero(as_tuple=True)[0]):
                # Find the index of the last token that is not padding
                base_answer_ids = input_ids[answer_idx]
                last_token_idx = (base_answer_ids != pad_token_id).nonzero(as_tuple=True)[0].max()

                probabilities = torch.nn.functional.softmax(
                    logits[list_idx][last_token_idx] / temperature, dim=-1)
                new_ids_sampled = torch.multinomial(probabilities, 1)

                if last_token_idx + 1 >= len(base_answer_ids):
                    # The row is full: grow every row by one padded, unattended column
                    new_padding = torch.full((len(input_ids), 1), pad_token_id, dtype=torch.long,
                                             device=input_ids.device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                # Write the sampled token into the first free slot of this row
                attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled

                if new_ids_sampled.item() in stop_token_ids:
                    finished_generating[answer_idx] = True
                elif answer_idx.item() == 0:
                    # Stream only the first sequence; terminating tokens are not streamed.
                    token_to_stream = new_ids_sampled

            if streamer is not None and token_to_stream is not None:
                streamer.put(token_to_stream)

    if streamer is not None:
        # Flush the text the streamer is still buffering (e.g. the last partial word).
        streamer.end()
    return input_ids, attention_mask

# Formulate your prompt
prompt_template = "[INST] {prompt} [/INST]"

prompt = (
    "You're standing on the surface of the Earth. "
    "You walk one mile south, one mile west and one mile north. "
    "You end up exactly where you started. Where are you?"
)

# Convert prompt to tokens
tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors='pt').input_ids.to(model.device)

# Generate an attention mask: 1 for every token that is not the pad token, 0 otherwise
attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)

# Prints decoded text to stdout as tokens are handed to it by custom_generate
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

# Generate output using the custom generate function
output_ids, _ = custom_generate(
    model,
    input_ids=tokens,
    attention_mask=attention_mask,
    max_new_tokens=512,
    streamer=streamer,
    temperature=0.9,
)

# Keep the full decoded sequence (prompt plus completion) instead of discarding
# output_ids; previously generated_text was left as an empty string.
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print()  # Print a newline after streaming is complete

# Cleanup if necessary
torch.cuda.empty_cache()