from typing import Dict, List

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

# Generation settings, in line with the default config of the model.
CONFIG = {
    'max_length': 512,
    'num_return_sequences': 1,
    'no_repeat_ngram_size': 2,
    'top_k': 50,
    'top_p': 0.95,
    'do_sample': True,
}


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and seq2seq model from the given repository path.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> List[Dict[str, str]]:
        inputs = data.pop('inputs', None)
        if inputs is None or inputs == '':
            return [{'generated_text': 'No input provided'}]

        # preprocess: tokenize the input text into model input IDs
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids

        # inference: generate output token IDs
        output_ids = self.model.generate(input_ids, **CONFIG)

        # postprocess: decode the generated IDs back into text
        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return [{'generated_text': response}]
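

# Example of exercising the handler locally (a minimal sketch, not part of the
# deployed handler): "./model" is a hypothetical local directory containing the
# tokenizer and model files, and the prompt is an arbitrary placeholder.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler({"inputs": "Summarize: custom handlers let you control pre- and post-processing."})
    print(result)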