import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
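    """Conversational text generation pipeline for ChatNT.

    Wraps a ChatNT model together with its English (instruction) tokenizer and
    DNA (bio) tokenizer, and generates an answer token by token with a simple
    greedy decoding loop.
    """
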
def __init__(self, model, **kwargs): # type: ignore
super().__init__(model=model, **kwargs)
        # Load the English (instruction) and DNA tokenizers from the ChatNT repo
model_name = "InstaDeepAI/ChatNT"
self.english_tokenizer = AutoTokenizer.from_pretrained(
model_name, subfolder="english_tokenizer"
)
self.bio_tokenizer = AutoTokenizer.from_pretrained(
model_name, subfolder="bio_tokenizer"
)

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
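        """Split the pipeline keyword arguments into the kwargs expected by
        `preprocess`, `_forward` and `postprocess`, following the Hugging Face
        `Pipeline` contract."""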
preprocess_kwargs = {}
forward_kwargs = {}
postprocess_kwargs = {} # type: ignore
if "max_num_tokens_to_decode" in kwargs:
forward_kwargs["max_num_tokens_to_decode"] = kwargs[
"max_num_tokens_to_decode"
]
if "english_tokens_max_length" in kwargs:
preprocess_kwargs["english_tokens_max_length"] = kwargs[
"english_tokens_max_length"
]
if "bio_tokens_max_length" in kwargs:
preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]
return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
self,
inputs: dict,
english_tokens_max_length: int = 512,
bio_tokens_max_length: int = 512,
) -> dict:
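        """Build the ChatNT prompt and tokenize it.

        The user question is wrapped in the fixed conversation context and
        suffixed with "ASSISTANT:". English tokens are padded to
        `english_tokens_max_length` so the decoding loop can fill the remaining
        pad positions; DNA tokens get an extra leading batch dimension.
        """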
english_sequence = inputs["english_sequence"]
dna_sequences = inputs["dna_sequences"]
context = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: " # noqa
        # Add the assistant cue, avoiding a double space when the question
        # already ends with a space
        space = "" if english_sequence.endswith(" ") else " "
        english_sequence = context + english_sequence + space + "ASSISTANT:"
english_tokens = self.english_tokenizer(
english_sequence,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=english_tokens_max_length,
).input_ids
bio_tokens = self.bio_tokenizer(
dna_sequences,
return_tensors="pt",
padding="max_length",
max_length=bio_tokens_max_length,
truncation=True,
).input_ids.unsqueeze(0)
return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
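        """Generate an answer with greedy decoding.

        At each step the model is run on the current English tokens and the DNA
        tokens; the most likely next token is written into the first remaining
        pad position. Generation stops at <eos>, when the prompt buffer is
        full, or after `max_num_tokens_to_decode` steps.
        """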
english_tokens = model_inputs["english_tokens"].clone()
bio_tokens = model_inputs["bio_tokens"].clone()
projected_bio_embeddings = None
actual_num_steps = 0
with torch.no_grad():
for _ in range(max_num_tokens_to_decode):
                # Stop once the prompt buffer is full (no pad positions left to fill)
if (
self.english_tokenizer.pad_token_id
not in english_tokens[0].cpu().numpy()
):
break
                # Run the model on the current English tokens and the DNA tokens
outs = self.model(
multi_omics_tokens_ids=(english_tokens, bio_tokens),
projection_english_tokens_ids=english_tokens,
projected_bio_embeddings=projected_bio_embeddings,
)
projected_bio_embeddings = outs["projected_bio_embeddings"]
logits = outs["logits"].detach().cpu().numpy()
                # Greedy decoding: take the most likely token at the last filled
                # position (just before the first pad token)
first_idx_pad_token = np.where(
english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
)[0][0]
predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])
# If it's <eos> then stop, else add the predicted token
if predicted_token == self.english_tokenizer.eos_token_id:
break
else:
english_tokens[0, first_idx_pad_token] = predicted_token
actual_num_steps += 1
# Get the position where generation started
idx_begin_generation = np.where(
model_inputs["english_tokens"][0].cpu()
== self.english_tokenizer.pad_token_id
)[0][0]
# Get generated tokens
generated_tokens = english_tokens[
0, idx_begin_generation : idx_begin_generation + actual_num_steps
]
return {
"generated_tokens": generated_tokens,
}

    def postprocess(self, model_outputs: dict) -> str:
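        """Decode the generated token ids into the English answer string."""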
generated_tokens = model_outputs["generated_tokens"]
generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
return generated_sequence
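

# Example usage (a sketch, not part of the pipeline): ChatNT ships custom code,
# so the loading call below assumes `trust_remote_code=True` works for this
# repository, and the exact question format (including the "<DNA>" placeholder)
# is illustrative rather than prescribed by this file.
#
# from transformers import AutoModel
#
# model = AutoModel.from_pretrained("InstaDeepAI/ChatNT", trust_remote_code=True)
# pipeline = TextGenerationPipeline(model=model)
# answer = pipeline(
#     {
#         "english_sequence": "Is there a promoter in this sequence: <DNA>?",
#         "dna_sequences": ["ATGCCGTAACGTTAGCATTGCAAGGT"],
#     },
#     max_num_tokens_to_decode=100,
# )
# print(answer)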