onisj committed
Commit 0243724 · verified · 1 Parent(s): 78f5a83
Files changed (1)
  1. merged_distilgpt2/handler.py +39 -33
merged_distilgpt2/handler.py CHANGED
@@ -1,11 +1,13 @@
 from typing import Dict, List, Any
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
+from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
 from sentence_transformers import SentenceTransformer
 import torch
+import os
 
 class EndpointHandler:
     def __init__(self, path=""):
-        self.task = self._determine_task(path)
+        self.path = path
+        self.task = self._determine_task()
         if self.task == "text-generation":
             self.model = AutoModelForCausalLM.from_pretrained(path)
             self.tokenizer = AutoTokenizer.from_pretrained(path)
@@ -27,45 +29,49 @@ class EndpointHandler:
         elif self.task == "sentence-embedding":
             self.model = SentenceTransformer(path)
         else:
-            raise ValueError(f"Unsupported task: {self.task}")
+            raise ValueError(f"Unsupported task: {self.task} for model at {path}")
 
-    def _determine_task(self, path):
-        model_name = path.split("/")[-1]
-        text_generation_models = [
-            "distilgpt2",
-            "fine_tuned_distilgpt2_lora",
-            "fine_tuned_gpt2",
-            "merged_distilgpt2",
-            "gpt2"
-        ]
-        text_classification_models = [
-            "emotion_classifier",
-            "emotion_model",
-            "intent_classifier",
-            "intent_fallback"
-        ]
-        embedding_models = ["intent_encoder", "sentence_transformer"]
+    def _determine_task(self):
+        # Load config to determine model_type
+        config_path = os.path.join(self.path, "config.json")
+        if not os.path.exists(config_path):
+            raise ValueError(f"config.json not found in {self.path}")
 
-        if model_name in text_generation_models:
+        config = AutoConfig.from_pretrained(self.path)
+        model_type = config.model_type if hasattr(config, "model_type") else None
+
+        # Map model_type or model name to tasks
+        text_generation_types = ["gpt2"]
+        text_classification_types = ["bert", "distilbert", "roberta"]
+        embedding_types = ["bert"]  # Sentence-BERT models use bert model_type
+
+        model_name = self.path.split("/")[-1]
+        if model_type in text_generation_types or model_name in ["fine_tuned_gpt2", "merged_distilgpt2"]:
             return "text-generation"
-        elif model_name in text_classification_models:
+        elif model_type in text_classification_types or model_name in ["emotion_classifier", "intent_classifier", "intent_fallback"]:
             return "text-classification"
-        elif model_name in embedding_models:
+        elif model_name in ["intent_encoder", "sentence_transformer"] or "sentence_bert_config.json" in os.listdir(self.path):
             return "sentence-embedding"
-        return "text-generation"  # Default
+        elif model_type in text_classification_types and model_name == "emotion_model":
+            # Handle emotion_model, which may be classification or generation
+            return "text-classification"  # Assume classification; adjust if needed
+        raise ValueError(f"Could not determine task for model_type: {model_type}, model_name: {model_name}")
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         inputs = data.get("inputs", "")
         if not inputs:
             return [{"error": "No inputs provided"}]
 
-        if self.task == "text-generation":
-            result = self.pipeline(inputs, max_length=50, num_return_sequences=1)
-            return [{"generated_text": item["generated_text"]} for item in result]
-        elif self.task == "text-classification":
-            result = self.pipeline(inputs, return_all_scores=True)
-            return result
-        elif self.task == "sentence-embedding":
-            embeddings = self.model.encode(inputs)
-            return [{"embeddings": embeddings.tolist()}]
-        return [{"error": f"Unsupported task: {self.task}"}]
+        try:
+            if self.task == "text-generation":
+                result = self.pipeline(inputs, max_length=50, num_return_sequences=1)
+                return [{"generated_text": item["generated_text"]} for item in result]
+            elif self.task == "text-classification":
+                result = self.pipeline(inputs, return_all_scores=True)
+                return [{"label": item["label"], "score": item["score"]} for sublist in result for item in sublist]
+            elif self.task == "sentence-embedding":
+                embeddings = self.model.encode(inputs)
+                return [{"embeddings": embeddings.tolist()}]
+            return [{"error": f"Unsupported task: {self.task}"}]
+        except Exception as e:
+            return [{"error": f"Inference failed: {str(e)}"}]