| from typing import List, Optional, Union, Dict, Any, Tuple |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from transformers import AutoTokenizer |
| from transformers.processing_utils import ( |
| CommonKwargs, |
| ProcessingKwargs, |
| ProcessorMixin, |
| Unpack, |
| ) |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
| from transformers.utils import logging |
|
|
| from bioreason.utils.dna_utils import DNAInput |
|
|
class DLDNAKwargs(CommonKwargs, total=False):
    """Keyword arguments specific to DNA processing.

    Every key is optional (``total=False``), which matches the ``Optional[int]``
    annotations; without it a TypedDict subclass would make both keys required.
    """

    # Counterpart of the ``max_length_text`` argument of ``DLProcessor.__call__``.
    max_length_text: Optional[int]
    # Counterpart of the ``max_length_dna`` argument of ``DLProcessor.__call__``.
    max_length_dna: Optional[int]
|
|
|
|
class DLProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument groups accepted by the DL processor.

    Adds a ``dna_kwargs`` group on top of the standard ``ProcessingKwargs``
    groups and sets the default for text padding to ``False``.
    """

    dna_kwargs: DLDNAKwargs

    # Per-group defaults, following the transformers ``ProcessingKwargs``
    # ``_defaults`` convention (merged via ``_merge_kwargs``).
    _defaults = dict(
        text_kwargs=dict(padding=False),
    )
|
|
class DLProcessor(ProcessorMixin):
    r"""
    Constructs a DL processor that wraps a DNA tokenizer and a text tokenizer into a single processor.

    The text tokenizer is a Qwen2/GPT2-style tokenizer and the DNA tokenizer an ``EsmTokenizer``
    (e.g. NucleotideTransformer) or ``Evo2Tokenizer``. This processor handles both text and DNA
    sequence processing to prepare inputs for the DNALLMModel.

    Args:
        tokenizer (PreTrainedTokenizerBase, *optional*):
            The text tokenizer used for processing text inputs.
        dna_tokenizer (PreTrainedTokenizerBase, *optional*):
            The DNA tokenizer used for processing DNA sequences.
        chat_template (`str`, *optional*):
            A Jinja template for chat formatting. If None, will use the tokenizer's template.
    """

    attributes = ["tokenizer", "dna_tokenizer"]
    valid_kwargs = ["model", "chat_template"]
    tokenizer_class = (
        "Qwen2Tokenizer",
        "Qwen2TokenizerFast",
        "GPT2TokenizerFast",
    )
    dna_tokenizer_class = ("EsmTokenizer", "Evo2Tokenizer")

    # Pad id assumed for the DNA tokenizer when it does not expose ``pad_token_id``.
    _DEFAULT_DNA_PAD_TOKEN_ID = 1

    def __init__(
        self, tokenizer=None, dna_tokenizer=None, chat_template=None, **kwargs
    ):
        """
        Initialize the processor with text and DNA tokenizers.

        Args:
            tokenizer: Text tokenizer (usually from a language model).
            dna_tokenizer: DNA tokenizer (usually from a DNA model).
            chat_template: Template for formatting chat conversations. Falls back to
                ``tokenizer.chat_template`` when not given.
            **kwargs: Additional arguments (accepted but not forwarded).
        """
        self.tokenizer = tokenizer
        self.dna_tokenizer = dna_tokenizer

        # Marker string in the text for "a DNA sequence goes here"; a text tokenizer
        # may override it by exposing a ``dna_token`` attribute.
        self.dna_token = getattr(self.tokenizer, "dna_token", "<|dna_pad|>")

        if chat_template is None and hasattr(self.tokenizer, "chat_template"):
            chat_template = self.tokenizer.chat_template
        super().__init__(tokenizer, dna_tokenizer, chat_template=chat_template)

        # Text is tokenized with padding=True in __call__, so a pad token is required;
        # reuse EOS when none is defined.
        if self.tokenizer is not None and getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize_dna_sequences(
        self,
        batch_dna_sequences: List[List[str]],
        max_length: int = 2048,
        return_tensors: str = "pt",
        device: str = "cuda",
    ) -> Dict[str, Any]:
        """
        Tokenize a batch of DNA sequences.

        The sequences of all batch items are flattened (batch order, then sequence
        order) and tokenized together with padding and truncation.

        Args:
            batch_dna_sequences: List of lists of DNA sequences per batch item.
            max_length: Maximum length passed to the DNA tokenizer (truncation is on).
            return_tensors: Return format for tensors ("pt" for PyTorch).
            device: Currently unused; kept for API compatibility. Tensors are returned
                as produced by the DNA tokenizer.

        Returns:
            Dict containing:
                - dna_tokenized: The tokenized DNA sequences, or ``None`` when there are
                  no sequences at all.
                - batch_idx_map: For each flattened sequence, the index of the batch item
                  it belongs to.
        """
        all_sequences: List[str] = []
        batch_idx_map: List[int] = []
        for batch_idx, dna_sequences in enumerate(batch_dna_sequences):
            for seq in dna_sequences:
                all_sequences.append(seq)
                batch_idx_map.append(batch_idx)

        if not all_sequences:
            return {"dna_tokenized": None, "batch_idx_map": []}

        dna_tokenized = self.dna_tokenizer(
            all_sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=True,
        )

        return {"dna_tokenized": dna_tokenized, "batch_idx_map": batch_idx_map}

    def _dna_pad_token_id(self) -> int:
        """Return the DNA tokenizer's pad id, or the historical default of 1."""
        pad_id = getattr(self.dna_tokenizer, "pad_token_id", None)
        return self._DEFAULT_DNA_PAD_TOKEN_ID if pad_id is None else pad_id

    def _expand_dna_tokens(self, text: List[str], dna_tokenized: Any) -> List[str]:
        """
        Return a new list where each DNA marker is repeated once per DNA token.

        Occurrences of ``self.dna_token`` are matched, in reading order across the
        whole batch, to the flattened sequences from ``tokenize_dna_sequences``: the
        k-th occurrence becomes ``self.dna_token`` repeated as many times as the k-th
        sequence has non-padding ids. The input list is not modified.

        Raises:
            ValueError: If ``text`` holds more DNA markers than there are sequences.
        """
        input_ids = None if dna_tokenized is None else dna_tokenized["input_ids"]
        num_sequences = 0 if input_ids is None else len(input_ids)
        pad_id = self._dna_pad_token_id()

        expanded: List[str] = []
        seq_index = 0
        for item in text:
            pieces = item.split(self.dna_token)
            out = [pieces[0]]
            for piece in pieces[1:]:
                if seq_index >= num_sequences:
                    raise ValueError(
                        f"`text` contains more '{self.dna_token}' markers than the "
                        f"{num_sequences} DNA sequence(s) provided."
                    )
                # torch.as_tensor accepts tensors, numpy arrays and plain lists alike.
                num_dna_tokens = int(
                    (torch.as_tensor(input_ids[seq_index]) != pad_id).sum()
                )
                out.append(self.dna_token * num_dna_tokens)
                out.append(piece)
                seq_index += 1
            expanded.append("".join(out))

        if seq_index < num_sequences:
            logging.get_logger(__name__).warning(
                "Only %d of %d DNA sequences matched a '%s' marker in the text.",
                seq_index,
                num_sequences,
                self.dna_token,
            )
        return expanded

    def __call__(
        self,
        batch_dna_sequences: Optional[List[List[str]]] = None,
        text: Optional[
            Union[
                TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
            ]
        ] = None,
        max_length_text: int = 512,
        max_length_dna: int = 2048,
        return_tensors: str = "pt",
        device: str = "cuda",
        **kwargs: Unpack[DLProcessorKwargs],
    ) -> BatchFeature:
        """
        Process text and DNA sequences for model input.

        Args:
            batch_dna_sequences: List of lists of DNA sequences per batch item.
            text: Input text or list of texts. Not modified.
            max_length_text: Maximum length for text sequences.
            max_length_dna: Maximum length for DNA sequences.
            return_tensors: Return format for tensors.
            device: Forwarded to ``tokenize_dna_sequences`` (which does not use it).
            **kwargs: Additional processor keyword arguments.

        Returns:
            BatchFeature with tokenized text inputs plus ``dna_tokenized`` and
            ``batch_idx_map`` when DNA sequences are given.

        Raises:
            ValueError: If ``text`` is None, or it has more DNA markers than sequences.
        """
        if text is None:
            raise ValueError("`text` must be provided to DLProcessor.")

        output_kwargs = self._merge_kwargs(
            DLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Copy so that DNA-marker expansion never mutates the caller's list.
        text = list(text) if isinstance(text, list) else [text]

        dna_inputs = {}
        if batch_dna_sequences is not None:
            dna_processing_result = self.tokenize_dna_sequences(
                batch_dna_sequences,
                max_length=max_length_dna,
                return_tensors=return_tensors,
                device=device,
            )
            text = self._expand_dna_tokens(text, dna_processing_result["dna_tokenized"])
            dna_inputs = {
                "dna_tokenized": dna_processing_result["dna_tokenized"],
                "batch_idx_map": dna_processing_result["batch_idx_map"],
            }

        # The explicit settings win over merged text kwargs; passing both as separate
        # keyword arguments would raise "got multiple values for keyword argument".
        # The max_length budget leaves headroom for the expanded DNA markers.
        # NOTE(review): truncation can still cut a run of DNA markers - confirm budget.
        tokenizer_kwargs = dict(output_kwargs.get("text_kwargs", {}))
        tokenizer_kwargs.update(
            max_length=max_length_text + 2 * max_length_dna,
            return_tensors=return_tensors,
            padding=True,
            truncation=True,
        )
        text_inputs = self.tokenizer(text, **tokenizer_kwargs)

        return BatchFeature(data={**text_inputs, **dna_inputs})

    def batch_decode(self, *args, **kwargs) -> List[str]:
        """
        This method forwards all its arguments to the tokenizer's batch_decode.

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs) -> str:
        """
        This method forwards all its arguments to the tokenizer's decode.

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_dna_to_text(
        self,
        generated_outputs: torch.Tensor,
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> List[str]:
        """
        Post-process the model output to decode the text.

        Args:
            generated_outputs: The token IDs generated by the model.
            skip_special_tokens: Whether to skip special tokens in the output.
            **kwargs: Additional arguments for the decoder.

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        """
        Get the input names expected by the model.

        Returns:
            The text tokenizer's input names followed by the DNA input names,
            de-duplicated while preserving order.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        dna_input_names = ["dna_tokenized", "batch_idx_map"]

        return list(dict.fromkeys(tokenizer_input_names + dna_input_names))
|
|