import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
    def __init__(self, model, **kwargs):  # type: ignore
        super().__init__(model=model, **kwargs)
        # Load tokenizers
        model_name = "InstaDeepAI/ChatNT"
        self.english_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="english_tokenizer"
        )
        self.bio_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="bio_tokenizer"
        )

    def _sanitize_parameters(self, **kwargs: dict) -> tuple[dict, dict, dict]:
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}  # type: ignore
        if "max_num_tokens_to_decode" in kwargs:
            forward_kwargs["max_num_tokens_to_decode"] = kwargs[
                "max_num_tokens_to_decode"
            ]
        if "english_tokens_max_length" in kwargs:
            preprocess_kwargs["english_tokens_max_length"] = kwargs[
                "english_tokens_max_length"
            ]
        if "bio_tokens_max_length" in kwargs:
            preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        inputs: dict,
        english_tokens_max_length: int = 512,
        bio_tokens_max_length: int = 512,
    ) -> dict:
        english_sequence = inputs["english_sequence"]
        dna_sequences = inputs["dna_sequences"]

        # Wrap the user prompt in the chat template and append the assistant marker
        context = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "  # noqa
        space = " "
        if english_sequence[-1] == " ":
            space = ""
        english_sequence = context + english_sequence + space + "ASSISTANT:"

        english_tokens = self.english_tokenizer(
            english_sequence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=english_tokens_max_length,
        ).input_ids
        bio_tokens = self.bio_tokenizer(
            dna_sequences,
            return_tensors="pt",
            padding="max_length",
            max_length=bio_tokens_max_length,
            truncation=True,
        ).input_ids.unsqueeze(0)

        return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
        english_tokens = model_inputs["english_tokens"].clone()
        bio_tokens = model_inputs["bio_tokens"].clone()
        projected_bio_embeddings = None
        actual_num_steps = 0

        # Greedy decoding: at each step, predict one token and write it into the
        # first padding position of the English sequence
        with torch.no_grad():
            for _ in range(max_num_tokens_to_decode):
                # Stop if there is no pad token id left to fill
                if (
                    self.english_tokenizer.pad_token_id
                    not in english_tokens[0].cpu().numpy()
                ):
                    break

                # Predictions
                outs = self.model(
                    multi_omics_tokens_ids=(english_tokens, bio_tokens),
                    projection_english_tokens_ids=english_tokens,
                    projected_bio_embeddings=projected_bio_embeddings,
                )
                projected_bio_embeddings = outs["projected_bio_embeddings"]
                logits = outs["logits"].detach().cpu().numpy()

                # Get predicted token
                first_idx_pad_token = np.where(
                    english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
                )[0][0]
                predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])

                # If it's the EOS token then stop, else add the predicted token
                if predicted_token == self.english_tokenizer.eos_token_id:
                    break
                else:
                    english_tokens[0, first_idx_pad_token] = predicted_token
                    actual_num_steps += 1

        # Get the position where generation started
        idx_begin_generation = np.where(
            model_inputs["english_tokens"][0].cpu()
            == self.english_tokenizer.pad_token_id
        )[0][0]

        # Get generated tokens
        generated_tokens = english_tokens[
            0, idx_begin_generation : idx_begin_generation + actual_num_steps
        ]

        return {
            "generated_tokens": generated_tokens,
        }

    def postprocess(self, model_outputs: dict) -> str:
        generated_tokens = model_outputs["generated_tokens"]
        generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
        return generated_sequence
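

# Minimal usage sketch (not part of the original pipeline code): it assumes the
# ChatNT weights are loadable from the Hub with trust_remote_code=True and that
# the English prompt references the DNA input via a "<DNA>" placeholder; the DNA
# sequence below is an arbitrary illustrative placeholder.
if __name__ == "__main__":
    from transformers import AutoModel

    model = AutoModel.from_pretrained("InstaDeepAI/ChatNT", trust_remote_code=True)
    pipeline = TextGenerationPipeline(model=model)

    output = pipeline(
        {
            "english_sequence": "Is there any evidence of a donor splice site in this sequence <DNA> ?",
            "dna_sequences": ["ATTCCGGTTAAACGGATTCCGGTTAAACGGATTCCGGTTAAACGG"],
        },
        max_num_tokens_to_decode=50,
    )
    print(output)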