from typing import Any, Dict, List
from datetime import datetime

import torch
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Model repo loaded regardless of the runtime-supplied `path` (see __init__ note).
MODEL_PATH = "mosaicml/mpt-7b-instruct"
# MPT models were trained with the GPT-NeoX-20B tokenizer.
TOKENIZER_PATH = "EleutherAI/gpt-neox-20b"


class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving MPT-7B-Instruct text generation."""

    def __init__(self, path: str = ""):
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            path: Model directory supplied by the Inference Endpoints runtime.
                NOTE(review): the original handler deliberately ignores this and
                always loads ``MODEL_PATH`` from the Hub; that behavior is kept.
        """
        logger.info("Hugging Face handler path %s (overridden with %s)", path, MODEL_PATH)
        path = MODEL_PATH  # hard override — the runtime-supplied path is ignored

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,    # MPT ships custom modeling code on the Hub
            torch_dtype=torch.bfloat16,
            max_seq_len=32000,         # extended context window
        )

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
        logger.info("tokenizer created %s", datetime.now())

        # Stop generation as soon as the model emits the end-of-text token.
        stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

        class StopOnTokens(StoppingCriteria):
            """Halt generation when the last produced token is a stop token."""

            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs,
            ) -> bool:
                # Only the most recently generated token needs checking.
                return any(input_ids[0][-1] == stop_id for stop_id in stop_token_ids)

        stopping_criteria = StoppingCriteriaList([StopOnTokens()])

        self.generate_text = transformers.pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            stopping_criteria=stopping_criteria,
            task="text-generation",
            return_full_text=True,     # echo the prompt along with the completion
            temperature=0.1,           # near-deterministic sampling
            top_p=0.15,
            top_k=0,                   # 0 disables top-k filtering
            max_new_tokens=2048,
            repetition_penalty=1.1,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation on an endpoint request payload.

        Args:
            data: Request body; the prompt is expected under the ``"inputs"``
                key. When the key is absent, the whole payload is passed to the
                pipeline unchanged (original fallback behavior, preserved).

        Returns:
            The pipeline output: a list of ``{"generated_text": ...}`` dicts.
        """
        inputs = data.pop("inputs", data)
        logger.info("generation inputs: %s", inputs)
        return self.generate_text(inputs)