Habib-HF committed
Commit d6990c3 · verified · 1 Parent(s): 4e98f15

Delete handler.py

Files changed (1)
  1. handler.py +0 -75
handler.py DELETED
@@ -1,75 +0,0 @@
- import torch
- import io
- import soundfile as sf
- import librosa  # used to resample audio to the 16 kHz rate Whisper expects
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
-
- # Renamed class from InferenceHandler to EndpointHandler
- class EndpointHandler:
-     def __init__(self):
-         self.processor = None
-         self.model = None
-         self.device = None
-
-     def load(self, model_path):
-         """
-         Loads the model and processor from the specified path.
-         """
-         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-         print(f"Loading model on device: {self.device}")
-
-         self.processor = AutoProcessor.from_pretrained(model_path)
-         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
-             model_path,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-         )
-
-         if torch.cuda.is_available():
-             self.model.to(self.device)  # Only move to device, no BetterTransformer
-         self.model.eval()  # Set model to evaluation mode
-
-         # Set generation parameters
-         self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
-             language="arabic", task="transcribe"
-         )
-         self.model.config.suppress_tokens = []  # Allow all tokens to be generated
-
-         print("Model and processor loaded successfully.")
-
-     def preprocess(self, input_data):
-         """
-         Preprocesses the incoming audio data.
-         input_data will be bytes (audio file content).
-         """
-         # Read audio from bytes using soundfile
-         audio_bytes_io = io.BytesIO(input_data)
-         audio, original_sampling_rate = sf.read(audio_bytes_io, dtype='float32')
-
-         # Ensure it's a 1-D array: convert to mono if stereo
-         if audio.ndim > 1:
-             audio = audio.mean(axis=1)
-
-         # Whisper expects 16 kHz input and the feature extractor does not resample,
-         # so resample here whenever the source rate differs
-         target_sampling_rate = 16000
-         if original_sampling_rate != target_sampling_rate:
-             audio = librosa.resample(audio, orig_sr=original_sampling_rate, target_sr=target_sampling_rate)
-
-         return audio, target_sampling_rate
-
-     def predict(self, preprocessed_data):
-         """
-         Performs inference using the loaded model.
-         """
-         audio_array, sampling_rate = preprocessed_data
-
-         # Use the processor to create log-mel input features at 16 kHz
-         input_features = self.processor.feature_extractor(
-             audio_array,
-             sampling_rate=sampling_rate,
-             return_tensors="pt"
-         ).input_features.to(self.device, dtype=self.model.dtype)  # match the model's dtype (fp16 on GPU)
-
-         with torch.no_grad():
-             generated_ids = self.model.generate(input_features, max_new_tokens=225)
-
-         transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-         return transcription
-
-     def postprocess(self, prediction_output):
-         """
-         Postprocesses the prediction output.
-         """
-         # For ASR, the prediction output is already the string transcription
-         return {"transcription": prediction_output}