import torch
from torch.distributions import Uniform


def forward_process_block_fixed_p(x, mask_id, p_mask):
    """Mask each token of x independently with probability p_mask.

    p_mask may be a Python float or a per-sample tensor of shape (B,) or (B, 1).
    Returns the noised ids and the boolean mask of the masked positions.
    """
    B, L = x.shape
    if isinstance(p_mask, float):
        p_mask = torch.full((B, 1), p_mask, device=x.device)
    elif p_mask.ndim == 1:
        p_mask = p_mask[:, None]
    rand = torch.rand((B, L), device=x.device)
    mask = rand < p_mask
    x_masked = torch.where(mask, mask_id, x)
    return x_masked, mask
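

# A minimal sanity check for forward_process_block_fixed_p, using a hypothetical
# toy batch (not part of the original file): every masked position should carry
# mask_id and every unmasked position should be left untouched.
def _demo_forward_process_block_fixed_p():
    x = torch.tensor([[11, 12, 13, 14, 15, 16]])
    x_noisy, mask = forward_process_block_fixed_p(x, mask_id=0, p_mask=0.5)
    assert (x_noisy[mask] == 0).all()
    assert torch.equal(x_noisy[~mask], x[~mask])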
def generate_monotonic_pmasks(batch_size, max_blocks, device):
    """
    Generate a (B, max_blocks) tensor of monotonically non-decreasing random
    mask ratios: the first block's ratio is drawn uniformly from [0.2, 0.7),
    and every later block's ratio is at least as large as the previous one.
    """
    # Mask ratio for the first block.
    p0 = torch.rand(batch_size, 1, device=device) / 2 + 0.2
    # Non-negative increments for the remaining blocks, scaled so a row never
    # exceeds 0.7 (and additionally clamped to 1.0 below as a safeguard).
    increments = torch.rand(batch_size, max_blocks - 1, device=device) * (0.7 - p0) / (max_blocks - 1)
    # Cumulative sums keep the sequence non-decreasing.
    cum_increments = torch.cumsum(increments, dim=1)
    # Total p_mask per block = p0 + accumulated increments, capped at 1.0.
    p_masks = torch.cat([p0, p0 + cum_increments], dim=1)
    p_masks = torch.clamp(p_masks, max=1.0)
    return p_masks  # (B, max_blocks)
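

# Assumed invariants of generate_monotonic_pmasks, checked on hypothetical sizes:
# rows are non-decreasing and, given the 0.2 offset and the (0.7 - p0) scaling
# above, stay within [0.2, 0.7] even before the final clamp.
def _demo_generate_monotonic_pmasks():
    p = generate_monotonic_pmasks(batch_size=4, max_blocks=5, device="cpu")
    assert p.shape == (4, 5)
    assert (p[:, 1:] >= p[:, :-1]).all()  # monotonically non-decreasing per row
    assert (p >= 0.2).all() and (p <= 1.0).all()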
def forward_process_length(input_ids, mask_id, block_size, prompt_lengths, eos_id=None):
    """
    Block-wise forward masking: the response region is split into blocks of
    block_size tokens and each block is masked with its own probability, taken
    from a monotonically non-decreasing per-sample schedule.

    Args:
        input_ids: (B, L)
        prompt_lengths: (B,)
    Returns:
        noisy_batch, masked_indices, p_mask_tensor
    """
    B, L = input_ids.shape
    device = input_ids.device
    noisy_batch = input_ids.clone()
    eos_indices = (input_ids == eos_id) if eos_id is not None else None  # currently unused
    masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
    p_mask_tensor = torch.zeros((B, L), device=device)
    # Number of blocks per sample (the last block may be shorter than block_size).
    non_prompt_lens = L - prompt_lengths
    full_blocks = non_prompt_lens // block_size
    remainders = non_prompt_lens % block_size
    total_blocks = full_blocks + (remainders > 0).long()
    max_blocks = total_blocks.max().item()
    # Per-sample, per-block mask ratios: random first value, non-decreasing afterwards.
    p_masks = generate_monotonic_pmasks(B, max_blocks, device)  # shape (B, max_blocks)
    for i in range(B):
        prompt_len = prompt_lengths[i].item()
        num_blocks = total_blocks[i].item()
        start_block = 0  # fixed at the first block; could be randomized to skip leading blocks
        for block_idx in range(num_blocks):
            if block_idx < start_block:
                continue
            start = prompt_len + block_idx * block_size
            end = min(start + block_size, L)
            p_block = p_masks[i, block_idx - start_block].item()
            block = noisy_batch[i, start:end].unsqueeze(0)
            masked_block, mask = forward_process_block_fixed_p(block, mask_id, p_block)
            noisy_batch[i, start:end] = masked_block.squeeze(0)
            masked_indices[i, start:end] = mask.squeeze(0)
            # (disabled) blocks consisting entirely of eos_id could be excluded from masking here
            p_mask_tensor[i, start:end] = p_block
    return noisy_batch, masked_indices, p_mask_tensor
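

# Sketch of two properties forward_process_length is expected to satisfy
# (toy inputs are hypothetical): prompt tokens are never masked, and the
# per-token mask probability is non-decreasing across the response region.
def _demo_forward_process_length():
    ids = torch.arange(1, 25).reshape(2, 12)
    prompt_lengths = torch.tensor([3, 5])
    noisy, mask, p = forward_process_length(ids, mask_id=0, block_size=4,
                                            prompt_lengths=prompt_lengths)
    for i, plen in enumerate(prompt_lengths.tolist()):
        assert not mask[i, :plen].any()       # prompt region untouched
        tail = p[i, plen:]
        assert (tail[1:] >= tail[:-1]).all()  # block-wise non-decreasing p_mask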
# Earlier variant: linearly increasing per-block mask probability from p_min to p_max.
# def forward_process_length(input_ids, mask_id, block_size, prompt_lengths, p_min=0.2, p_max=0.9):
#     """
#     Also returns the per-token mask probability tensor for the non-prompt region
#     (prompt positions stay 0).
#     """
#     B, L = input_ids.shape
#     device = input_ids.device
#     noisy_batch = input_ids.clone()
#     masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
#     p_mask_tensor = torch.zeros((B, L), device=device)  # final return value
#     for i in range(B):
#         prompt_len = prompt_lengths[i].item()
#         non_prompt_len = L - prompt_len
#         full_blocks = non_prompt_len // block_size
#         remainder = non_prompt_len % block_size
#         total_blocks = full_blocks + (1 if remainder > 0 else 0)
#         for block_idx in range(total_blocks):
#             start = prompt_len + block_idx * block_size
#             end = min(start + block_size, L)
#             # Per-block mask probability (linear ramp).
#             if total_blocks > 1:
#                 p_block = p_min + (p_max - p_min) * (block_idx / (total_blocks - 1))
#             else:
#                 p_block = p_max
#             block = noisy_batch[i, start:end].unsqueeze(0)
#             masked_block, mask = forward_process_block_fixed_p(block, mask_id, p_block)
#             noisy_batch[i, start:end] = masked_block.squeeze(0)
#             masked_indices[i, start:end] = mask.squeeze(0)
#             # Record p_mask in the output tensor.
#             p_mask_tensor[i, start:end] = p_block
#     return noisy_batch, masked_indices, p_mask_tensor
def forward_process(input_ids, mask_id, t_max=1.0, eps=1e-4):
    """Standard masked-diffusion forward process: one t ~ U(0, t_max) per sample."""
    B, L = input_ids.shape
    dist = Uniform(0., t_max)
    t = dist.sample((B,)).to(input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, L)
    masked_indices = torch.rand((B, L), device=input_ids.device) < p_mask
    noisy_batch = torch.where(masked_indices, mask_id, input_ids)
    return noisy_batch, masked_indices, p_mask
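

# Quick check of forward_process on a hypothetical batch: a single t is drawn
# per sample, so every row of p_mask is constant, and masked positions carry
# mask_id while unmasked ones are unchanged.
def _demo_forward_process():
    ids = torch.randint(1, 100, (3, 8))
    noisy, mask, p = forward_process(ids, mask_id=0)
    assert p.shape == ids.shape
    assert torch.equal(p, p[:, :1].expand_as(p))  # one probability per row
    assert (noisy[mask] == 0).all() and torch.equal(noisy[~mask], ids[~mask])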
def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested dict, joining keys with sep (e.g. for logging configs)."""
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
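

# Example of flatten_dict on a hypothetical nested config: keys are joined with
# the default '_' separator and non-dict values (including lists) are kept as-is.
def _demo_flatten_dict():
    cfg = {"optim": {"lr": 3e-4, "betas": [0.9, 0.999]}, "seed": 42}
    assert flatten_dict(cfg) == {"optim_lr": 3e-4, "optim_betas": [0.9, 0.999], "seed": 42}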
def shift_logits(logits):
    """Shift logits one step to the right along the sequence dimension;
    position 0 is filled with the constant 1.0."""
    shifted_logits = torch.zeros_like(logits)
    shifted_logits[:, 1:, :] = logits[:, :-1, :]
    shifted_logits[:, 0, :] = 1.0
    return shifted_logits
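

# Behaviour of shift_logits on hypothetical random logits: position t of the
# output holds the logits that were at position t-1, and position 0 is the
# constant 1.0 placeholder.
def _demo_shift_logits():
    logits = torch.randn(2, 4, 7)
    shifted = shift_logits(logits)
    assert torch.equal(shifted[:, 1:, :], logits[:, :-1, :])
    assert (shifted[:, 0, :] == 1.0).all()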
if __name__ == '__main__':
    input_ids = torch.tensor([[1, 5, 4, 3, 25, 6, 7, 9, 5, 8, 7, 6],
                              [1, 3, 8, 9, 7, 34, 6, 9, 5, 8, 7, 6]])
    mask_id = 0
    block_size = 3
    prompt_length = torch.tensor([2, 1])
    noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id, block_size, prompt_length)
    print("noisy_batch:", noisy_batch)
    print("masked_indices:", masked_indices)
    print("p_mask:", p_mask)