import torch
from typing import Any, Dict, List

from transformers import pipeline

# Run on the first GPU if one is available, otherwise fall back to CPU.
device = 0 if torch.cuda.is_available() else -1

# Models to serve from a single endpoint, each paired with its pipeline task.
multi_model_list = [
    {"model_id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"},
    {"model_id": "Helsinki-NLP/opus-mt-en-de", "task": "translation"},
    {"model_id": "facebook/bart-large-cnn", "task": "summarization"},
    {"model_id": "dslim/bert-base-NER", "task": "token-classification"},
    {"model_id": "textattack/bert-base-uncased-ag-news", "task": "text-classification"},
]


class EndpointHandler:
    def __init__(self, path=""):
        # Load every model once at startup and keep it on the selected device.
        self.multi_model: Dict[str, Any] = {}
        for model in multi_model_list:
            self.multi_model[model["model_id"]] = pipeline(
                model["task"], model=model["model_id"], device=device
            )

    def __call__(self, data: Dict[str, Any]) -> Any:
        # Deserialize the incoming request. The output schema varies by task
        # (e.g. classification returns scores, summarization returns text).
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        model_id = data.pop("model_id", None)

        # Check that a known model_id was requested.
        if model_id is None or model_id not in self.multi_model:
            raise ValueError(
                f"model_id: {model_id} is not valid. "
                f"Available models are: {list(self.multi_model.keys())}"
            )

        # Run inference, forwarding any extra pipeline parameters as kwargs.
        if parameters is not None:
            prediction = self.multi_model[model_id](inputs, **parameters)
        else:
            prediction = self.multi_model[model_id](inputs)

        # Return the pipeline output as-is; no extra postprocessing is needed.
        return prediction
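

# --- Usage sketch (an assumption for local testing, not part of the handler
# contract) --- Inference Endpoints normally instantiate EndpointHandler and
# call it with the deserialized request body, so this block mimics that flow
# locally. The payload keys ("inputs", "model_id", "parameters") mirror what
# __call__ above expects, and the model_id values come from multi_model_list;
# the example inputs and parameter values are illustrative only. Run this only
# where all five models can be downloaded and loaded.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Sentiment classification request routed to the DistilBERT model.
    print(handler({
        "inputs": "I love using multi-model endpoints!",
        "model_id": "distilbert-base-uncased-finetuned-sst-2-english",
    }))

    # Summarization request; "parameters" is forwarded to the pipeline as
    # kwargs (max_length/min_length are standard summarization options).
    print(handler({
        "inputs": "Hugging Face Inference Endpoints let you deploy models "
                  "behind a managed HTTP API without writing serving code.",
        "model_id": "facebook/bart-large-cnn",
        "parameters": {"max_length": 40, "min_length": 10},
    }))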