hunterbown committed on
Commit c42dd2b · verified · 1 Parent(s): 751c787

Upload generation_utils.py with huggingface_hub

Files changed (1)
  1. generation_utils.py +133 -0
generation_utils.py ADDED
@@ -0,0 +1,133 @@
# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import torch
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM  # not referenced directly in this module
from transformers.cache_utils import DynamicCache  # not referenced directly in this module


def add_gumbel_noise(logits, temperature):
    # Gumbel-max style sampling: at temperature 0 this reduces to plain argmax,
    # so the logits are returned unchanged.
    if temperature == 0:
        return logits
    # float64 keeps exp() and the division by the noise numerically stable.
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    # Spread the masked positions of each sequence evenly over the denoising steps;
    # the first `remainder` steps unmask one extra token each.
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens

def make_block_causal_mask(seq_len, block_size=2, device=None, dtype=torch.bool):
    # Block-causal visibility: every token attends to all positions inside its own
    # block and to every position in earlier blocks.
    num_blocks = (seq_len + block_size - 1) // block_size
    block_mask = torch.tril(torch.ones((num_blocks, num_blocks), dtype=torch.bool, device=device))
    local_block = torch.ones((block_size, block_size), dtype=torch.bool, device=device)
    mask = torch.kron(block_mask, local_block)[:seq_len, :seq_len]

    # Additive attention mask: visible positions get a constant 1.0 (uniform per row,
    # so harmless under softmax) and hidden positions get -inf. Callers should pass a
    # floating dtype such as torch.bfloat16; casting to the default torch.bool would
    # turn the -inf entries back into True.
    attention_mask = mask.float()
    attention_mask.masked_fill_(~mask, float('-inf'))
    attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).to(dtype)
    return attention_mask

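# Example (illustration): make_block_causal_mask(4, block_size=2, dtype=torch.float32)
# is built from the boolean visibility pattern
#     [[1, 1, 0, 0],
#      [1, 1, 0, 0],
#      [1, 1, 1, 1],
#      [1, 1, 1, 1]]
# i.e. bidirectional attention within a block plus causal attention to earlier blocks,
# returned as an additive mask of shape (1, 1, 4, 4).
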
@torch.no_grad()
def generate_block(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
                   remasking='low_confidence', tokenizer=None, mask_id=5, threshold=0.95, shift=False, eos_id=None):
    # Block-wise diffusion decoding: the response is produced block by block, and each
    # block is iteratively denoised until it contains no mask tokens.
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    # When the prompt does not end on a block boundary, the first decoding block covers
    # the prompt's unfinished tail plus the tokens needed to complete that block.
    prompt_len = prompt.shape[1]
    res_block = block_length - prompt_len % block_length
    every_block = [block_length for _ in range(num_blocks)]
    if res_block > 0:
        every_block = [res_block] + every_block
        every_block[-1] = block_length - res_block
    cum_block = [sum(every_block[:i + 1]) for i in range(len(every_block))]
    num_block = len(cum_block)

    block_diffusion_attention_mask = make_block_causal_mask(prompt.shape[1] + gen_length, block_length, model.device, dtype=torch.bfloat16)
    nfe = 0
    final_flag = 0

    # Prefill the KV cache with the block-aligned part of the prompt.
    prefill_length = prompt_len // block_length * block_length
    past_key_values = None  # keeps the name defined when the prompt is shorter than one block
    if prefill_length > 0:
        cur_attn_mask = block_diffusion_attention_mask[:, :, :prefill_length, :prefill_length]
        past_key_values = model(x[:, :prefill_length], attention_mask=cur_attn_mask, use_cache=True).past_key_values

    for num_block in range(num_blocks):
        current_block_start = prompt_len + cum_block[num_block - 1] if num_block > 0 else prefill_length
        current_block_end = prompt_len + cum_block[num_block]

        block_mask_index = (x[:, current_block_start:current_block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, current_block_start:current_block_end] = 1
        i = 0
        while True:
            nfe += 1
            mask_index = (x[:, current_block_start:current_block_end] == mask_id)
            cur_attn_mask = block_diffusion_attention_mask[:, :, current_block_start:current_block_end, :current_block_end]
            output = model(x[:, current_block_start:current_block_end], attention_mask=cur_attn_mask,
                           past_key_values=past_key_values, use_cache=True,
                           cache_position=replace_position.nonzero(as_tuple=True)[1])
            logits = output.logits
            x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index,
                                                    x[:, current_block_start:current_block_end],
                                                    num_transfer_tokens[:, i] if threshold is None else None,
                                                    threshold, shift=False)
            x[:, current_block_start:current_block_end][transfer_index] = x0[transfer_index]

            if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0:
                # The block is fully denoised.
                if eos_id is not None and (x[:, current_block_start:current_block_end] == eos_id).sum() > 0:
                    # An EOS token appeared: truncate to the current block and stop generating.
                    final_flag = 1
                    x = x[:, :current_block_end]
                    break
                # One more forward pass writes the finished block into the KV cache.
                past_key_values = model(x[:, current_block_start:current_block_end], attention_mask=cur_attn_mask,
                                        past_key_values=past_key_values, use_cache=True,
                                        cache_position=replace_position.nonzero(as_tuple=True)[1]).past_key_values
                break
            i += 1
        if final_flag == 1:
            # EOS was reached inside this block; skip the remaining blocks.
            break
    return x, nfe

def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, shift=False):
    # Decide which masked positions get unmasked at this step, either by a per-step
    # budget (num_transfer_tokens) or by a confidence threshold.
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
    if shift:
        # Shift predictions right by one position for next-token-style heads.
        x0 = torch.cat([x[:, :1], x0[:, :-1]], dim=-1)
        pad = torch.zeros_like(logits[:, :1])
        logits = torch.cat([pad, logits[:, :-1]], dim=1)
    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    # Keep already-decoded tokens; only masked positions are candidates for transfer.
    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    if threshold is not None:
        # Threshold mode: every remaining masked position is a candidate.
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
    for j in range(confidence.shape[0]):
        _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
        transfer_index[j, select_index] = True
        if threshold is not None:
            # Always keep the single most confident token; drop the rest if below threshold.
            for k in range(1, num_transfer_tokens[j]):
                if confidence[j, select_index[k]] < threshold:
                    transfer_index[j, select_index[k]] = False
    return x0, transfer_index
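
For reference, a minimal usage sketch of generate_block. The checkpoint id, the use of AutoModel with trust_remote_code, and the eos id below are placeholder assumptions for illustration, not defined by this commit; only mask_id=5 and the parameter names come from the file above.

# Hypothetical usage sketch; checkpoint id and eos id are placeholders.
import torch
from transformers import AutoTokenizer, AutoModel
from generation_utils import generate_block

model_name = "path/to/checkpoint"  # placeholder, not part of this commit
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).to(device).eval()

prompt_ids = tokenizer("Write a short poem about the sea.", return_tensors="pt").input_ids.to(device)

# steps must be divisible by gen_length // block_length (asserted in generate_block).
out, nfe = generate_block(model, prompt_ids, steps=128, gen_length=128, block_length=32,
                          temperature=0., remasking='low_confidence', mask_id=5,
                          threshold=0.95, eos_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0, prompt_ids.shape[1]:], skip_special_tokens=True))
print("forward passes:", nfe)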