med-flan-t5-large / handler.py
from typing import Any, Dict, List
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_TOKENS_IN_BATCH = 4_000  # Hard limit to prevent OOMs
DEFAULT_MAX_NEW_TOKENS = 10  # By default, limit the output to 10 tokens


class EndpointHandler:
"""
This class is used to handle the inference with pre and post process for
text2text models. See
https://huggingface.co/docs/inference-endpoints/guides/custom_handler for
more details.
"""

    def __init__(self, path: str = ""):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
        except Exception:
            # Log the installed accelerate version before re-raising, since
            # device_map="auto" requires a working accelerate install
            import accelerate

            print(f"ACCELERATE VERSION: {accelerate.__version__}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        This method is called when the endpoint is invoked.

        Arguments
        ---------
        data (Dict[str, Any]):
            Must contain the input text(s) under the `inputs` key and any
            parameters for the inference under the `parameters` key.

        Returns
        -------
        output (List[Dict[str, Any]]):
            A list with one entry per input sequence, where each entry is a
            dictionary containing `generated_text`, `perplexity` and, if
            `check_first_tokens` was passed, `first_token_probs`.
        """
input_texts = data["inputs"]
generate_kwargs = data.get("parameters", {})
# This is not technically a generate_kwarg, but needs to live under parameters
check_first_tokens = generate_kwargs.pop("check_first_tokens", None)
max_new_tokens = (
generate_kwargs.pop("max_new_tokens", None) or DEFAULT_MAX_NEW_TOKENS
)
# Tokenizing input texts
inputs = self.tokenizer(
input_texts, return_tensors="pt", padding=True, truncation=True,
)["input_ids"]
        # Make sure not to OOM if too many inputs
        assert inputs.dim() == 2, f"Inputs have dimension {inputs.dim()} != 2"
        # Worst-case token budget: batch size x (padded input length + new tokens)
        total_tokens = inputs.shape[0] * (inputs.shape[1] + max_new_tokens - 1)
        assert (
            total_tokens <= MAX_TOKENS_IN_BATCH
        ), f"Passed {total_tokens} tokens (shape: {inputs.shape}, max_new_tokens: {max_new_tokens}), which is greater than the limit of {MAX_TOKENS_IN_BATCH}"
# Run inference on GPU
inputs = inputs.to("cuda:0")
with torch.no_grad():
outputs = self.model.generate(
inputs,
output_scores=True,
return_dict_in_generate=True,
max_new_tokens=max_new_tokens,
**generate_kwargs,
)
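        # With output_scores=True and return_dict_in_generate=True,
        # outputs.scores is a tuple holding one (batch_size, vocab_size)
        # tensor of scores per generated step; the post-processing below
        # works entirely off these scores rather than outputs.sequences.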
inputs = inputs.to("cpu")
scores = [s.to("cpu") for s in outputs.scores]
del outputs
# process outputs
to_return: Dict[str, Any] = {
"generated_text": self._output_text_from_scores(scores),
"perplexity": [float(p) for p in self._perplexity(scores)],
}
if check_first_tokens:
to_return["first_token_probs"] = self._get_first_token_probs(
check_first_tokens, scores
)
# Reformat output to conform to HF Pipeline format
return [
{key: to_return[key][ndx] for key in to_return.keys()}
for ndx in range(len(to_return["generated_text"]))
]
def _output_text_from_scores(self, scores: List[torch.Tensor]) -> List[str]:
"""
Returns the decoded text from the scores.
TODO (ENG-20823): Use the returned sequences so we pay attention to
things like bad_words, force_words etc.
"""
        # Greedily pick the highest-scoring token id at each generation step,
        # building one list of token ids per sequence (always a list of
        # sequences, even for a single input)
        batch_token_ids = [
            [int(score[ndx].argmax()) for score in scores]
            for ndx in range(scores[0].shape[0])
        ]
# Fix for new tokens being generated after EOS
new_batch_token_ids = []
for token_ids in batch_token_ids:
try:
new_token_ids = token_ids[
: token_ids.index(self.tokenizer.eos_token_id)
]
except ValueError:
new_token_ids = token_ids[:-1]
new_batch_token_ids.append(new_token_ids)
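        # Example of the truncation above (illustrative ids, assuming a T5-style
        # tokenizer where eos_token_id == 1): greedy ids [2163, 55, 1, 0] become
        # [2163, 55]; if no EOS was generated, only the final token is dropped.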
return self.tokenizer.batch_decode(new_batch_token_ids)
def _perplexity(self, scores: List[torch.Tensor]) -> List[float]:
"""
Returns the perplexity (model confidence) of the outputted text.
e^( sum(ln(p(word))) / N)
TODO (ENG-20823): don't include the trailing pad tokens in perplexity
"""
return torch.exp(
torch.stack(
[score.softmax(axis=1).log().max(axis=1)[0] for score in scores]
).sum(axis=0)
/ len(scores)
).tolist()
def _get_first_token_probs(
self, tokens: List[str], scores: List[torch.Tensor]
) -> List[Dict[str, float]]:
"""
Return the softmaxed probabilities of the specific tokens for each
output
"""
first_token_probs = []
softmaxed_scores = scores[0].softmax(axis=1)
# Finding the correct token IDs
# TODO (ENG-20824): Support multi-token words
token_ids = {}
for token in tokens:
encoded_token: List[int] = self.tokenizer.encode(token)
            if len(encoded_token) > 2:
                # encode() appends EOS, so a length above 2 means the tokenizer
                # broke the token up into multiple parts
                token_ids[token] = -1
else:
token_ids[token] = encoded_token[0]
# Now finding the scores for each token in the list
for seq_ndx in range(scores[0].shape[0]):
curr_token_probs: Dict[str, float] = {}
for token in tokens:
                if token_ids[token] == -1:
                    # Multi-token words are not supported yet (see TODO above)
                    curr_token_probs[token] = 0.0
else:
curr_token_probs[token] = float(
softmaxed_scores[seq_ndx, token_ids[token]]
)
first_token_probs.append(curr_token_probs)
return first_token_probs
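

# Minimal local usage sketch (not part of the deployed handler). It assumes a
# CUDA-capable machine and a seq2seq checkpoint at `path`; the
# "google/flan-t5-small" checkpoint below is an illustrative stand-in, not
# necessarily the model this repository serves.
if __name__ == "__main__":
    handler = EndpointHandler("google/flan-t5-small")
    payload = {
        "inputs": ["Is aspirin an anticoagulant? Answer yes or no."],
        "parameters": {"max_new_tokens": 5, "check_first_tokens": ["yes", "no"]},
    }
    print(handler(payload))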