update handler
handler.py  CHANGED  (+33 -28)
@@ -1,48 +1,53 @@
-import torch
 from typing import Dict, List, Any
+from transformers import pipeline
+
+import sys
+import torch
 from transformers import (
-
-
-
-
-    pipeline
+    AutomaticSpeechRecognitionPipeline,
+    WhisperForConditionalGeneration,
+    WhisperTokenizer,
+    WhisperProcessor
 )
 from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, PeftConfig
 
 class EndpointHandler():
     def __init__(self, path=""):
-
-        peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
+
         language = "Chinese"
-        task = "transcribe"
-        peft_config = PeftConfig.from_pretrained(
-        model= WhisperForConditionalGeneration.from_pretrained(
+        task = "transcribe"
+        peft_config = PeftConfig.from_pretrained(path)
+        model = WhisperForConditionalGeneration.from_pretrained(
             peft_config.base_model_name_or_path
         )
-        model = PeftModel.from_pretrained(model,
+        model = PeftModel.from_pretrained(model, path)
         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         feature_extractor = processor.feature_extractor
         self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
-
-        self.pipeline = pipeline(
-        self.pipeline.model.
-
-        # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
-        # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
-        # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
-
+        self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor = feature_extractor)
+        self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language=language, task=task)
+        self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
-
-        inputs (:obj: `str`
-
-
+        data args:
+            inputs (:obj: `str`)
+            date (:obj: `str`)
+        Return:
             A :obj:`list` | `dict`: will be serialized and returned
         """
-
+        # get inputs
+
+        # run normal prediction
        inputs = data.pop("inputs", data)
-
-
-
+        print("a1", inputs)
+        print("a2", inputs, file=sys.stderr)
+        print("a3", inputs, file=sys.stdout)
+
+        prediction = self.pipeline(inputs, return_timestamps=False)
+
+        print("b1", prediction)
+        print("b2", prediction, file=sys.stderr)
+        print("b3", prediction, file=sys.stdout)
         return prediction
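For context: Inference Endpoints instantiate EndpointHandler with the repository path and call it with the deserialized request body. A minimal local smoke test of the updated handler might look like the sketch below; it is hypothetical and not part of the commit, and "smoke_test.py" and "sample.wav" are placeholder names. It assumes handler.py and the PEFT adapter files (adapter_config.json plus the adapter weights) sit in the current directory.

    # smoke_test.py -- hypothetical local check, not part of the commit
    from handler import EndpointHandler

    # path="." assumes the adapter config/weights are in the working directory
    handler = EndpointHandler(path=".")
    # the ASR pipeline accepts a filename, URL, raw bytes, or a numpy array
    result = handler({"inputs": "sample.wav"})  # placeholder audio file
    print(result)  # e.g. {"text": "..."}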
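The substantive fix is that the decoder prompt is now attached to the live pipeline: the forced_decoder_ids are copied onto both model.config and model.generation_config, since generate() prefers the generation config in recent transformers releases. The snippet below illustrates what those ids contain; it is a hypothetical illustration, not from the commit, and it assumes the adapter's base model is openai/whisper-large-v2.

    from transformers import WhisperProcessor

    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-large-v2", language="Chinese", task="transcribe"
    )
    # a list of (position, token_id) pairs that pin the language, task, and
    # no-timestamps tokens at the start of decoding, roughly:
    # [(1, <|zh|>), (2, <|transcribe|>), (3, <|notimestamps|>)]
    print(processor.get_decoder_prompt_ids(language="Chinese", task="transcribe"))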