Update modeling_llada.py
update generate_function with streaming output
- modeling_llada.py +81 -31
 
    	
modeling_llada.py  CHANGED

@@ -1181,7 +1181,8 @@ class LLaDAModel(nn.Module):
         attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
-        last_logits_only: bool = False,
+        last_block_logits_only: bool = False,
+        block_length: int = 64,
         output_hidden_states: Optional[bool] = None,
     ) -> LLaDAOutput:
         """
@@ -1351,10 +1352,9 @@ class LLaDAModel(nn.Module):
                 assert cache is not None
                 attn_key_values.extend(cache)
 
-        if last_logits_only:
-            # shape: (batch_size, 1, d_model)
-            x = x[:, -1:, :]
-
+        if last_block_logits_only:
+            # shape: (batch_size, block_length, d_model)
+            x = x[:, -block_length:, :]
         # Apply final layer norm.
         # shape: (batch_size, seq_len or 1, d_model)
         x = self.transformer.ln_f(x)  # type: ignore
@@ -1406,6 +1406,7 @@ class LLaDAModelLM(PreTrainedModel):
             self.model = LLaDAModel(model_config, init_params=init_params)
         else:
             self.model = model
+        self.mask_id = model_config.mask_token_id
 
     def forward(
         self,
@@ -1419,7 +1420,8 @@ class LLaDAModelLM(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        …
+        last_block_logits_only: bool = False,
+        block_length: int = 64,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if use_cache is None:
             use_cache = self.config.use_cache
@@ -1438,6 +1440,8 @@ class LLaDAModelLM(PreTrainedModel):
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
+            last_block_logits_only=last_block_logits_only,
+            block_length=block_length,
         )
 
         logits = outputs.logits
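For reference, a minimal sketch (not part of the commit) of the shape contract implied by the two new arguments threaded through the hunks above: with last_block_logits_only=True the hidden states are cut down to the trailing block_length positions before the LM head, so the returned logits cover only the freshly appended block. All tensor names and sizes below are illustrative.

import torch

# Illustrative sizes: batch of 2, a 10-token prompt followed by a 4-token masked block.
batch_size, prompt_len, block_length, d_model = 2, 10, 4, 8
seq_len = prompt_len + block_length

x = torch.randn(batch_size, seq_len, d_model)   # hidden states for the full sequence
x_block = x[:, -block_length:, :]               # what last_block_logits_only keeps
assert x_block.shape == (batch_size, block_length, d_model)

# After the LM head the logits therefore have shape (batch_size, block_length, vocab_size),
# i.e. exactly the block that generate() is currently unmasking.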
@@ -1457,31 +1461,6 @@ class LLaDAModelLM(PreTrainedModel):
             hidden_states=hidden_states,
         )
 
-    def can_generate(self) -> bool:
-        return True
-
-    def prepare_inputs_for_generation(
-        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
-    ):
-        if past_key_values:
-            # This is because we want the model to only process the last generated token.
-            input_ids = input_ids[:, -1:]
-        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
-
-        model_inputs.update(kwargs)
-        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
-        return model_inputs
-
-    # TODO: these are required to make the implementation complete.
-    # def resize_position_embeddings(self, new_num_position_embeddings: int):
-    #     pass
-    #
-    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
-    #     pass
-    #
-    # def _reorder_cache(self, past_key_values, beam_idx):
-    #     pass
-
     def get_input_embeddings(self) -> torch.nn.Module:
         return self.model.transformer.wte
 
@@ -1504,5 +1483,76 @@ class LLaDAModelLM(PreTrainedModel):
         if self.config.weight_tying:
             self.model.transformer.ff_out = self.model.transformer.wte
 
+
+    def prefill_phase(self, input_ids, block_length):
+        """Prefill phase: Process initial prompt and generate KV cache."""
+        with torch.no_grad():
+            outputs = self(
+                input_ids=input_ids,
+                use_cache=True,
+                return_dict=True,
+                last_block_logits_only=True,
+                block_length=block_length
+            )
+        output_past_key_values = []
+        for i in range(len(outputs.past_key_values)):
+            k,v = outputs.past_key_values[i]
+            new_k,new_v = k[:,:,:-block_length,:],v[:,:,:-block_length,:]
+            output_past_key_values.append((new_k,new_v))
+        output_past_key_values = tuple(output_past_key_values)
+        return {
+            'input_ids': input_ids,
+            'logits': outputs.logits,
+            'past_key_values': output_past_key_values,
+        }
+
+    def unmask_function_greedy(self, logits, x, threshold=0.9):
+        """Greedy unmasking function with confidence threshold."""
+        mask_index = x == self.mask_id
+        x_top_0 = torch.argmax(logits, dim=-1)
+        p = F.softmax(logits, dim=-1)
+        confidence = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x_top_0, -1)), -1)
+        transfer_index = torch.zeros_like(x_top_0, dtype=torch.bool, device=x_top_0.device)
+        confidence = torch.where(mask_index, confidence, -torch.inf)
+        for j in range(confidence.shape[0]):
+            mask = confidence[j] > threshold
+            if mask.sum() == 0:
+                max_conf_idx = torch.argmax(confidence[j])
+                mask[max_conf_idx] = True
+            transfer_index[j] = mask
+        x[transfer_index] = x_top_0[transfer_index]
+        return x
+
+    @torch.no_grad()
+    def generate(self, input_ids, attention_mask, max_gen_length=1024, block_length=64, threshold=0.9, streaming=False, eos_token_id=126081):
+        batchsize, prompt_length = input_ids.shape
+        max_num_blocks = max_gen_length // block_length
+        output_ids = input_ids
+        block_x = torch.full((batchsize, block_length), self.mask_id, dtype=torch.long).to(self.device)
+        output_ids = torch.cat([output_ids, block_x], dim=-1)
+        # prefilling block loop
+        prefill_outputs = self.prefill_phase(output_ids, block_length)
+        past_key_values = prefill_outputs['past_key_values']
+        logits = prefill_outputs['logits']
+        output_ids[:,-block_length:] = self.unmask_function_greedy(logits=logits, x=output_ids[:,-block_length:], threshold=threshold)
+        # decoding block loop
+        for j in range(max_num_blocks):
+            while (output_ids[:,-block_length:] == self.mask_id).sum():
+                outputs = self(
+                    input_ids=output_ids[:,-block_length:],
+                    past_key_values=past_key_values,
+                    use_cache=True,
+                    return_dict=True
+                )
+                output_ids[:,-block_length:] = self.unmask_function_greedy(logits=outputs.logits, x=output_ids[:,-block_length:], threshold=threshold)
+            past_key_values = outputs.past_key_values
+            if streaming:
+                yield output_ids[:,-block_length:]
+            if (output_ids == eos_token_id).any():
+                return output_ids[:, prompt_length:]
+            block_x = torch.full((batchsize, block_length), self.mask_id, dtype=torch.long).to(self.device)
+            output_ids = torch.cat([output_ids, block_x], dim=-1)
+        return output_ids[:, prompt_length:]
+
 # Register the model so that it is available for transformer pipelines, auto-loading, etc.
 AutoModel.register(LLaDAConfig, LLaDAModelLM)
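Since the new generate() contains a yield, calling it always returns a generator. A minimal usage sketch, assuming a LLaDA checkpoint that ships this modeling_llada.py and is loaded with trust_remote_code=True; the checkpoint name, prompt, and decoding settings below are placeholders, not part of the commit.

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "GSAI-ML/LLaDA-8B-Instruct"  # placeholder checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda").eval()

inputs = tokenizer("Explain diffusion language models in one paragraph.", return_tensors="pt").to("cuda")

# streaming=True: each block is yielded as soon as all of its mask tokens are resolved.
stream = model.generate(
    inputs.input_ids,
    inputs.attention_mask,
    max_gen_length=256,
    block_length=64,
    threshold=0.9,
    streaming=True,
)
for block in stream:
    print(tokenizer.batch_decode(block, skip_special_tokens=True)[0], end="", flush=True)

Note that because of the yield, the call returns a generator even with streaming=False; in that mode nothing is yielded and the final token ids come back as the generator's StopIteration value rather than as a plain return value.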