# Hugging Face Hub file-view header captured with this file:
# author onisj, commit 0243724 ("updated", verified), 3.91 kB.
import logging
import os
from typing import Any, Dict, List

import torch
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)
class EndpointHandler:
    """Inference handler serving one local model directory.

    The task ("text-generation", "text-classification" or
    "sentence-embedding") is detected from the directory at construction
    time; it decides which backend is loaded and how ``__call__`` shapes
    its results.
    """

    # Directory names that pin a task regardless of config.json's model_type.
    _GENERATION_NAMES = frozenset({"fine_tuned_gpt2", "merged_distilgpt2"})
    _CLASSIFICATION_NAMES = frozenset(
        {"emotion_classifier", "intent_classifier", "intent_fallback"}
    )
    _EMBEDDING_NAMES = frozenset({"intent_encoder", "sentence_transformer"})
    # config.json model_type values, consulted only when no name or marker decides.
    _GENERATION_TYPES = frozenset({"gpt2"})
    _CLASSIFICATION_TYPES = frozenset({"bert", "distilbert", "roberta"})
    # Marker file found in sentence-transformers model directories; such models
    # still report e.g. model_type "bert" in config.json.
    _SENTENCE_TRANSFORMERS_MARKER = "sentence_bert_config.json"

    def __init__(self, path=""):
        """Detect the task for the model at ``path`` and load it.

        Raises:
            ValueError: if ``path`` has no config.json or no task rule matches.
        """
        self.path = path
        self.task = self._determine_task()
        if self.task == "text-generation":
            self._load_pipeline(AutoModelForCausalLM)
        elif self.task == "text-classification":
            self._load_pipeline(AutoModelForSequenceClassification)
        elif self.task == "sentence-embedding":
            self.model = SentenceTransformer(path)
        else:
            # Defensive: _determine_task only returns the three tasks above.
            raise ValueError(f"Unsupported task: {self.task} for model at {path}")

    def _load_pipeline(self, model_cls):
        """Load model + tokenizer from ``self.path`` into a pipeline for ``self.task``."""
        self.model = model_cls.from_pretrained(self.path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
        self.pipeline = pipeline(
            self.task,
            model=self.model,
            tokenizer=self.tokenizer,
            # Pipeline device convention: 0 = first CUDA device, -1 = CPU.
            device=0 if torch.cuda.is_available() else -1,
        )

    def _determine_task(self):
        """Inspect ``self.path`` and return the task name to serve.

        Raises:
            ValueError: if config.json is missing or no rule matches.
        """
        config_path = os.path.join(self.path, "config.json")
        if not os.path.exists(config_path):
            raise ValueError(f"config.json not found in {self.path}")
        config = AutoConfig.from_pretrained(self.path)
        model_type = getattr(config, "model_type", None)
        # normpath drops a trailing separator, so "models/foo/" yields "foo"
        # (a plain split("/")[-1] would yield "").
        model_name = os.path.basename(os.path.normpath(self.path))
        has_marker = os.path.exists(
            os.path.join(self.path, self._SENTENCE_TRANSFORMERS_MARKER)
        )
        return self._resolve_task(model_type, model_name, has_marker)

    @classmethod
    def _resolve_task(cls, model_type, model_name, has_sentence_bert_config):
        """Map detection signals to a task name, most specific signal first.

        Precedence: explicit directory name, then the sentence-transformers
        marker file, then ``model_type``. The marker must win over
        ``model_type`` because sentence-BERT encoders report "bert", which
        would otherwise route them to text classification. Unlisted names
        (e.g. "emotion_model", which may be a classifier or a generator)
        are decided by ``model_type``.

        Raises:
            ValueError: if no rule matches.
        """
        if model_name in cls._GENERATION_NAMES:
            return "text-generation"
        if model_name in cls._CLASSIFICATION_NAMES:
            return "text-classification"
        if model_name in cls._EMBEDDING_NAMES or has_sentence_bert_config:
            return "sentence-embedding"
        if model_type in cls._GENERATION_TYPES:
            return "text-generation"
        if model_type in cls._CLASSIFICATION_TYPES:
            return "text-classification"
        raise ValueError(
            f"Could not determine task for model_type: {model_type}, model_name: {model_name}"
        )

    @staticmethod
    def _iter_records(result):
        """Yield the result dicts from a pipeline output of any nesting.

        Accepts a single dict, a list of dicts, or a list of lists of dicts
        (the shape produced for a batch of inputs, and for legacy
        all-scores classification); records are yielded in order.
        """
        if isinstance(result, dict):
            yield result
            return
        for entry in result:
            if isinstance(entry, dict):
                yield entry
            else:
                yield from entry

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on ``data["inputs"]`` (a string or a list of strings).

        An optional ``data["parameters"]`` mapping is forwarded to the
        pipeline / ``encode`` call, overriding the defaults used here.

        Returns:
            A list of result dicts. On empty input, or if inference raises, a
            single-element ``[{"error": ...}]`` list is returned instead.
        """
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No inputs provided"}]
        try:
            parameters = dict(data.get("parameters") or {})
            if self.task == "text-generation":
                kwargs = {"max_length": 50, "num_return_sequences": 1}
                if "max_new_tokens" in parameters:
                    # max_length also counts prompt tokens; an explicit
                    # max_new_tokens replaces it rather than conflicting with it.
                    del kwargs["max_length"]
                kwargs.update(parameters)
                result = self.pipeline(inputs, **kwargs)
                return [
                    {"generated_text": item["generated_text"]}
                    for item in self._iter_records(result)
                ]
            if self.task == "text-classification":
                # NOTE(review): transformers deprecates return_all_scores in favour
                # of top_k=None, which (per its docs) also reorders labels by
                # score; kept so existing responses keep their label order.
                kwargs = {"return_all_scores": True}
                kwargs.update(parameters)
                result = self.pipeline(inputs, **kwargs)
                return [
                    {"label": item["label"], "score": item["score"]}
                    for item in self._iter_records(result)
                ]
            if self.task == "sentence-embedding":
                embeddings = self.model.encode(inputs, **parameters)
                return [{"embeddings": embeddings.tolist()}]
            return [{"error": f"Unsupported task: {self.task}"}]
        except Exception as e:
            # Endpoint boundary: report the failure in the response body, but
            # keep the traceback in the logs instead of discarding it.
            logging.getLogger(__name__).exception("Inference failed for task %s", self.task)
            return [{"error": f"Inference failed: {str(e)}"}]