Habib-HF committed
Commit c16d89a · verified · 1 Parent(s): b7eacda

Add custom handler.py for Inference Endpoints

Files changed (1)
  1. handler.py +88 -0
handler.py ADDED
@@ -0,0 +1,88 @@
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+ from optimum.bettertransformer import BetterTransformer
+ import torch
+ import io
+ import soundfile as sf
+
+
+ class InferenceHandler:
+     def __init__(self):
+         self.processor = None
+         self.model = None
+         self.device = None
+
+     def load(self, model_path):
+         """
+         Loads the model and processor from the specified path.
+         """
+         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         print(f"Loading model on device: {self.device}")
+
+         self.processor = AutoProcessor.from_pretrained(model_path)
+         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+             model_path,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         )
+
+         if torch.cuda.is_available():
+             self.model = BetterTransformer.transform(self.model)  # Optimize attention for faster GPU inference
+         self.model.to(self.device)
+         self.model.eval()  # Set model to evaluation mode
+
+         # Set generation parameters: always transcribe in Arabic
+         self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
+             language="arabic", task="transcribe"
+         )
+         self.model.config.suppress_tokens = []  # Allow all tokens to be generated
+
+         print("Model and processor loaded successfully.")
+
+     def preprocess(self, input_data):
+         """
+         Preprocesses the incoming audio data.
+         input_data will be bytes (audio file content).
+         """
+         # Read audio from bytes using soundfile
+         audio_bytes_io = io.BytesIO(input_data)
+         audio, original_sampling_rate = sf.read(audio_bytes_io, dtype="float32")
+
+         # Convert to mono if stereo
+         if audio.ndim > 1:
+             audio = audio.mean(axis=1)
+
+         # Whisper's feature extractor expects 16 kHz input and raises an error
+         # on other rates rather than resampling, so resample explicitly here
+         target_sampling_rate = self.processor.feature_extractor.sampling_rate  # 16000 for Whisper
+         if original_sampling_rate != target_sampling_rate:
+             import torchaudio.functional as AF  # assumes torchaudio is available in the endpoint image
+             audio = AF.resample(
+                 torch.from_numpy(audio), original_sampling_rate, target_sampling_rate
+             ).numpy()
+
+         return audio, target_sampling_rate
+
+     def predict(self, preprocessed_data):
+         """
+         Performs inference using the loaded model.
+         """
+         audio_array, sampling_rate = preprocessed_data
+
+         # Use the processor to turn the raw waveform into log-mel input features
+         input_features = self.processor.feature_extractor(
+             audio_array,
+             sampling_rate=sampling_rate,
+             return_tensors="pt",
+         ).input_features.to(self.device, dtype=self.model.dtype)  # Match model dtype (fp16 on GPU)
+
+         with torch.no_grad():
+             generated_ids = self.model.generate(input_features, max_new_tokens=225)
+
+         transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+         return transcription
+
+     def postprocess(self, prediction_output):
+         """
+         Postprocesses the prediction output.
+         """
+         # For ASR, the prediction output is already the string transcription
+         return {"transcription": prediction_output}
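
For context, Inference Endpoints loads a custom handler.py by instantiating a class named EndpointHandler with the repository path and calling it with the deserialized request. The InferenceHandler above does not expose that interface itself, so a thin wrapper would be needed. The sketch below is illustrative glue, not part of this commit; it assumes the standard EndpointHandler contract and that binary audio requests arrive as raw bytes under data["inputs"].

from typing import Any, Dict

class EndpointHandler:
    def __init__(self, path: str = ""):
        # path points at the downloaded repository snapshot with the model weights
        self.handler = InferenceHandler()
        self.handler.load(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        audio_bytes = data["inputs"]  # raw audio file content (bytes), per the assumed payload shape
        preprocessed = self.handler.preprocess(audio_bytes)
        transcription = self.handler.predict(preprocessed)
        return self.handler.postprocess(transcription)

A quick local smoke test of the same path (the file name is a placeholder):

handler = EndpointHandler(path=".")
with open("sample.wav", "rb") as f:
    print(handler({"inputs": f.read()}))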