maomaocun committed on
Commit 01fd92c · verified · 1 Parent(s): 8bd8789

Update modeling_llada.py

update generate_function with streaming output

Files changed (1)
  1. modeling_llada.py +81 -31
modeling_llada.py CHANGED
@@ -1181,7 +1181,8 @@ class LLaDAModel(nn.Module):
         attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
-        last_logits_only: bool = False,
+        last_block_logits_only: bool = False,
+        block_length: int = 64,
         output_hidden_states: Optional[bool] = None,
     ) -> LLaDAOutput:
         """
@@ -1351,10 +1352,9 @@ class LLaDAModel(nn.Module):
                     assert cache is not None
                     attn_key_values.extend(cache)

-        if last_logits_only:
-            # shape: (batch_size, 1, d_model)
-            x = x[:, -1, :].unsqueeze(1)
-
+        if last_block_logits_only:
+            # shape: (batch_size, block_length, d_model)
+            x = x[:, -block_length:, :]
         # Apply final layer norm.
         # shape: (batch_size, seq_len or 1, d_model)
         x = self.transformer.ln_f(x) # type: ignore
@@ -1406,6 +1406,7 @@ class LLaDAModelLM(PreTrainedModel):
             self.model = LLaDAModel(model_config, init_params=init_params)
         else:
             self.model = model
+        self.mask_id = model_config.mask_token_id

     def forward(
         self,
@@ -1419,7 +1420,8 @@ class LLaDAModelLM(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
+        last_block_logits_only: bool = False,
+        block_length: int = 64,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if use_cache is None:
             use_cache = self.config.use_cache
@@ -1438,6 +1440,8 @@ class LLaDAModelLM(PreTrainedModel):
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
+            last_block_logits_only=last_block_logits_only,
+            block_length=block_length,
         )

         logits = outputs.logits
@@ -1457,31 +1461,6 @@ class LLaDAModelLM(PreTrainedModel):
             hidden_states=hidden_states,
         )

-    def can_generate(self) -> bool:
-        return True
-
-    def prepare_inputs_for_generation(
-        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
-    ):
-        if past_key_values:
-            # This is because we want the model to only process the last generated token.
-            input_ids = input_ids[:, -1:]
-        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
-
-        model_inputs.update(kwargs)
-        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
-        return model_inputs
-
-    # TODO: these are required to make the implementation complete.
-    # def resize_position_embeddings(self, new_num_position_embeddings: int):
-    #     pass
-    #
-    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
-    #     pass
-    #
-    # def _reorder_cache(self, past_key_values, beam_idx):
-    #     pass
-
     def get_input_embeddings(self) -> torch.nn.Module:
         return self.model.transformer.wte

@@ -1504,5 +1483,76 @@ class LLaDAModelLM(PreTrainedModel):
         if self.config.weight_tying:
             self.model.transformer.ff_out = self.model.transformer.wte

+
+    def prefill_phase(self, input_ids, block_length):
+        """Prefill phase: Process initial prompt and generate KV cache."""
+        with torch.no_grad():
+            outputs = self(
+                input_ids=input_ids,
+                use_cache=True,
+                return_dict=True,
+                last_block_logits_only=True,
+                block_length=block_length
+            )
+        output_past_key_values = []
+        for i in range(len(outputs.past_key_values)):
+            k,v = outputs.past_key_values[i]
+            new_k,new_v = k[:,:,:-block_length,:],v[:,:,:-block_length,:]
+            output_past_key_values.append((new_k,new_v))
+        output_past_key_values = tuple(output_past_key_values)
+        return {
+            'input_ids': input_ids,
+            'logits': outputs.logits,
+            'past_key_values': output_past_key_values,
+        }
+
+    def unmask_function_greedy(self, logits, x, threshold=0.9):
+        """Greedy unmasking function with confidence threshold."""
+        mask_index = x == self.mask_id
+        x_top_0 = torch.argmax(logits, dim=-1)
+        p = F.softmax(logits, dim=-1)
+        confidence = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x_top_0, -1)), -1)
+        transfer_index = torch.zeros_like(x_top_0, dtype=torch.bool, device=x_top_0.device)
+        confidence = torch.where(mask_index, confidence, -torch.inf)
+        for j in range(confidence.shape[0]):
+            mask = confidence[j] > threshold
+            if mask.sum() == 0:
+                max_conf_idx = torch.argmax(confidence[j])
+                mask[max_conf_idx] = True
+            transfer_index[j] = mask
+        x[transfer_index] = x_top_0[transfer_index]
+        return x
+
+    @torch.no_grad()
+    def generate(self, input_ids, attention_mask, max_gen_length=1024, block_length=64, threshold=0.9, streaming=False, eos_token_id=126081):
+        batchsize, prompt_length = input_ids.shape
+        max_num_blocks = max_gen_length // block_length
+        output_ids = input_ids
+        block_x = torch.full((batchsize, block_length), self.mask_id, dtype=torch.long).to(self.device)
+        output_ids = torch.cat([output_ids, block_x], dim=-1)
+        # prefilling block loop
+        prefill_outputs = self.prefill_phase(output_ids, block_length)
+        past_key_values = prefill_outputs['past_key_values']
+        logits = prefill_outputs['logits']
+        output_ids[:,-block_length:] = self.unmask_function_greedy(logits=logits, x=output_ids[:,-block_length:], threshold=threshold)
+        # decoding block loop
+        for j in range(max_num_blocks):
+            while (output_ids[:,-block_length:] == self.mask_id).sum():
+                outputs = self(
+                    input_ids=output_ids[:,-block_length:],
+                    past_key_values=past_key_values,
+                    use_cache=True,
+                    return_dict=True
+                )
+                output_ids[:,-block_length:] = self.unmask_function_greedy(logits=outputs.logits, x=output_ids[:,-block_length:], threshold=threshold)
+            past_key_values = outputs.past_key_values
+            if streaming:
+                yield output_ids[:,-block_length:]
+            if (output_ids == eos_token_id).any():
+                return output_ids[:, prompt_length:]
+            block_x = torch.full((batchsize, block_length), self.mask_id, dtype=torch.long).to(self.device)
+            output_ids = torch.cat([output_ids, block_x], dim=-1)
+        return output_ids[:, prompt_length:]
+
 # Register the model so that it is available for transformer pipelines, auto-loading, etc.
 AutoModel.register(LLaDAConfig, LLaDAModelLM)
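
For context, below is a minimal usage sketch of the new streaming interface; it is not part of the commit. It assumes the checkpoint is loaded with trust_remote_code=True so that this modeling file (and the AutoModel registration above) is used; the checkpoint id, prompt, device, and decoding parameters are illustrative placeholders.

# Usage sketch (not part of this commit): checkpoint id, prompt, and parameters are
# illustrative; only generate()'s signature and streaming behaviour come from the diff above.
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "GSAI-ML/LLaDA-8B-Instruct"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(
    checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda").eval()

inputs = tokenizer("Explain masked diffusion decoding in one paragraph.", return_tensors="pt").to("cuda")

# With streaming=True, generate() is consumed as a Python generator: each yielded tensor
# holds the token ids of the newest generation block, so text can be shown incrementally.
for block_ids in model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_gen_length=256,
    block_length=64,
    threshold=0.9,
    streaming=True,
):
    print(tokenizer.batch_decode(block_ids, skip_special_tokens=True)[0], flush=True)

One Python detail worth noting: because generate now contains yield, calling it always returns a generator object; with streaming=False the final ids arrive as the generator's StopIteration value, so iterating with streaming=True as above is the most direct way to consume the output.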