from typing import Dict, List, Any

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)


class EndpointHandler:
    def __init__(self, path=""):
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Causal-LM tokenizers often lack a dedicated pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(path).to("cuda")
        self.tokenizer = tokenizer
        # StopAtPeriodCriteria (defined below) halts generation as soon as the
        # most recently generated token contains a period.
        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to complete
            additional_bad_words_ids (:obj:`list`, optional): extra token-id
                sequences to forbid during generation
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned
        """
        inputs = data.pop("inputs", data)
        additional_bad_words_ids = data.pop("additional_bad_words_ids", [])
        # [3070], [10456], and [313, 334] decode to variants of "(*", which
        # opens a comment, and we do not want the model to output a comment.
        # [13] is a newline character.
        # [1976, 441, 29889] and [4920, 441, 29889] decode to "Abort.",
        # [4920, 18054, 29889] to "Aborted.", and
        # [2087, 29885, 4430, 29889] to "Admitted."
        bad_words_ids = [
            [3070],
            [313, 334],
            [10456],
            [13],
            [1976, 441, 29889],
            [2087, 29885, 4430, 29889],
            [4920, 441],
            [4920, 441, 29889],
            [4920, 18054, 29889],
        ]
        bad_words_ids.extend(additional_bad_words_ids)
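        # A quick sanity check (a sketch, assuming the same tokenizer as the
        # checkpoint) is to decode an id sequence and confirm it matches the
        # comments above, e.g.:
        #   self.tokenizer.decode([4920, 18054, 29889])  # expected: "Aborted."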
        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to("cuda")
        max_generation_length = 75  # desired number of new tokens to generate
        # max_input_length = 4092 - max_generation_length  # maximum input length, leaving room for generation
        # # Truncate input_ids to the most recent tokens that fit within max_input_length
        # if input_ids.shape[1] > max_input_length:
        #     input_ids = input_ids[:, -max_input_length:]
        max_length = input_ids.shape[1] + max_generation_length
        generated_ids = self.model.generate(
            input_ids,
            max_length=max_length,  # prompt length + max_generation_length new tokens
            bad_words_ids=bad_words_ids,
            temperature=1,
            top_k=40,
            do_sample=True,
            stopping_criteria=self.stopping_criteria,
        )
        # Decode only the newly generated tokens, dropping the prompt prefix.
        generated_text = self.tokenizer.decode(
            generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        prediction = [
            {
                "generated_text": generated_text,
                "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist(),
            }
        ]
        return prediction

class StopAtPeriodCriteria(StoppingCriteria):
    """Stop generation once the most recently generated token contains a period."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the last generated token to text
        last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
        # Stop as soon as the decoded text contains a period
        return '.' in last_token_text
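
# A minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract). It assumes a causal-LM checkpoint saved at the placeholder path
# "./model" and an available CUDA device; the Coq-style prompt is illustrative,
# chosen to match the Coq commands ("Abort.", "Admitted.") filtered above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    payload = {"inputs": "Lemma add_0_r : forall n : nat, n + 0 = n.\nProof."}
    print(handler(payload))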