komyu1227 commited on
Commit
b3834f9
·
1 Parent(s): 5b2a8fd

change whisper

Browse files
Files changed (2) hide show
  1. serve.py +34 -1
  2. uv.lock +0 -0
serve.py CHANGED
@@ -7,13 +7,17 @@ import io
7
  from pydub import AudioSegment
8
  import time
9
  import logging
 
10
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
- model = load_model(device)
 
 
 
17
 
18
  def transcribe_audio(audio_data_bytes):
19
  try:
@@ -33,6 +37,35 @@ def transcribe_audio(audio_data_bytes):
33
  return result
34
  except Exception as e:
35
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  app = FastAPI()
38
 
 
7
  from pydub import AudioSegment
8
  import time
9
  import logging
10
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
11
 
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
# model = load_model(device)

# NOTE: WhisperProcessor bundles the tokenizer and feature extractor; it is not
# an nn.Module and has no .to() method — calling .to(device) on it raises
# AttributeError at import time. Only the model moves to the device.
processor = WhisperProcessor.from_pretrained("Ivydata/whisper-small-japanese")
model = WhisperForConditionalGeneration.from_pretrained("Ivydata/whisper-small-japanese").to(device)
21
 
22
  def transcribe_audio(audio_data_bytes):
23
  try:
 
37
  return result
38
  except Exception as e:
39
  raise HTTPException(status_code=500, detail=str(e))
40

def transcribe_whisper(audio_data_bytes):
    """Transcribe MP3 audio bytes to Japanese text with the HF Whisper model.

    Args:
        audio_data_bytes: Raw MP3 file content.

    Returns:
        dict: {"text": <transcribed string>}.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        start_time = time.time()
        audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))

        # Whisper's feature extractor expects 16 kHz mono audio and rejects other
        # sampling rates; resample and downmix here instead of passing the source
        # rate through. Downmixing also avoids feeding interleaved stereo samples.
        audio_segment = audio_segment.set_frame_rate(16000).set_channels(1)

        # Normalize to float32 in [-1, 1] using the actual sample width rather
        # than assuming 16-bit (32768) — pydub may yield 8/24/32-bit samples.
        samples = np.array(audio_segment.get_array_of_samples())
        full_scale = float(1 << (8 * audio_segment.sample_width - 1))
        audio_data_float32 = samples.astype(np.float32) / full_scale

        # Extract log-mel features and move them to the model's device.
        input_features = processor(
            audio=audio_data_float32,
            sampling_rate=audio_segment.frame_rate,
            return_tensors="pt",
        ).input_features.to(device)

        predicted_ids = model.generate(input_features=input_features)

        decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        result_text = decoded[0] if isinstance(decoded, list) and len(decoded) > 0 else str(decoded)
        result = {
            "text": result_text
        }
        end_time = time.time()
        # Use the module logger (configured at top of file) instead of print.
        logger.info("Time taken: %s seconds", end_time - start_time)
        return result
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

69
 
70
  app = FastAPI()
71
 
uv.lock CHANGED
The diff for this file is too large to render. See raw diff