kiranpantha committed on
Commit
03c4019
·
verified ·
1 Parent(s): e7e598c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -13,35 +13,47 @@ model_urls = [
13
  "kiranpantha/whisper-large-v3-turbo-nepali",
14
  ]
15
 
 
 
 
 
 
 
 
 
 
 
16
  # Cache models and processors
17
  model_cache = {}
18
 
19
  def load_model(model_name):
20
  """Loads and caches the model and processor with proper device management."""
21
  if model_name not in model_cache:
22
- processor_name = model_name.replace("kiranpantha", "openai").replace(
23
- "-nepali", "").replace("-ne", "").replace("-np", "")
24
 
25
- # Load processor and model
26
  processor = AutoProcessor.from_pretrained(processor_name)
27
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
28
  model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
 
29
  model_cache[model_name] = (processor, model, device)
 
30
  return model_cache[model_name]
31
 
32
  def create_pipeline(model_name):
33
  """Creates an ASR pipeline with proper configuration."""
34
  processor, model, device = load_model(model_name)
 
35
  return AutomaticSpeechRecognitionPipeline(
36
  model=model,
37
  processor=processor,
38
- device=device,
39
- generate_kwargs={"task": "transcribe", "language": "nepali"} # Verify language code
40
  )
41
 
42
  def process_audio(model_url, audio_chunk):
43
  """Processes audio and returns transcription with error handling."""
44
  try:
 
45
  audio_array, sample_rate = audio_chunk
46
 
47
  # Convert stereo to mono
@@ -51,7 +63,7 @@ def process_audio(model_url, audio_chunk):
51
  # Resample to 16kHz if needed
52
  if sample_rate != 16000:
53
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
54
- audio_array = resampler(torch.tensor(audio_array)).numpy()
55
 
56
  # Create pipeline and process
57
  asr_pipeline = create_pipeline(model_url)
@@ -65,24 +77,15 @@ def process_audio(model_url, audio_chunk):
65
  with gr.Blocks() as demo:
66
  gr.Markdown("# Nepali Speech Recognition with Whisper Models")
67
 
68
- model_dropdown = gr.Dropdown(
69
- choices=model_urls,
70
- label="Select Model",
71
- value=model_urls[0]
72
- )
73
-
74
- audio_input = gr.Audio(
75
- type="numpy",
76
- label="Input Audio",
77
- streaming=True
78
- )
79
-
80
  output_text = gr.Textbox(label="Transcription")
81
-
82
- audio_input.stream(
 
83
  fn=process_audio,
84
  inputs=[model_dropdown, audio_input],
85
  outputs=output_text,
86
  )
87
 
88
- demo.launch()
 
13
  "kiranpantha/whisper-large-v3-turbo-nepali",
14
  ]
15
 
# Map each fine-tuned checkpoint to the base OpenAI Whisper repo used to
# load its processor (see load_model: AutoProcessor.from_pretrained).
# Names absent from this dict fall back to the checkpoint itself.
# NOTE(review): large-v3-turbo is mapped to the large-v3 processor —
# presumably they are compatible; confirm against the turbo repo.
processor_mappings = {
    "kiranpantha/whisper-tiny-ne": "openai/whisper-tiny",
    "kiranpantha/whisper-base-ne": "openai/whisper-base",
    "kiranpantha/whisper-small-np": "openai/whisper-small",
    "kiranpantha/whisper-medium-nepali": "openai/whisper-medium",
    "kiranpantha/whisper-large-v3-nepali": "openai/whisper-large-v3",
    "kiranpantha/whisper-large-v3-turbo-nepali": "openai/whisper-large-v3",
}

# Cache of model_name -> (processor, model, device), populated by load_model().
model_cache = {}
28
 
29
def load_model(model_name):
    """Load and cache the processor, model, and device for *model_name*.

    Returns a ``(processor, model, device)`` tuple. The first call for a
    given model name downloads/loads the weights; subsequent calls reuse
    the cached entry, so each checkpoint is loaded at most once.
    """
    cached = model_cache.get(model_name)
    if cached is None:
        # Fine-tuned repos resolve their processor via processor_mappings
        # when an entry exists; otherwise load it from the checkpoint itself.
        processor_repo = processor_mappings.get(model_name, model_name)
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        loaded_processor = AutoProcessor.from_pretrained(processor_repo)
        loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(target)

        cached = (loaded_processor, loaded_model, target)
        model_cache[model_name] = cached

    return cached
41
 
42
def create_pipeline(model_name):
    """Create an ASR pipeline for *model_name* with proper configuration.

    Uses the cached (processor, model, device) from load_model and returns
    an AutomaticSpeechRecognitionPipeline configured to transcribe Nepali
    ("ne") speech.
    """
    processor, model, device = load_model(model_name)

    # BUG FIX: load_model builds torch.device("cuda"), whose .index is None,
    # so the previous `device.index if device.type == "cuda" else -1`
    # passed device=None on GPU hosts. transformers pipelines accept a
    # torch.device directly, which is unambiguous on both CPU and GPU.
    return AutomaticSpeechRecognitionPipeline(
        model=model,
        processor=processor,
        device=device,
        generate_kwargs={"task": "transcribe", "language": "ne"},  # Whisper language code for Nepali
    )
52
 
53
  def process_audio(model_url, audio_chunk):
54
  """Processes audio and returns transcription with error handling."""
55
  try:
56
+ # Unpack audio_chunk (tuple) into audio array and sample rate
57
  audio_array, sample_rate = audio_chunk
58
 
59
  # Convert stereo to mono
 
63
  # Resample to 16kHz if needed
64
  if sample_rate != 16000:
65
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
66
+ audio_array = resampler(torch.tensor(audio_array).unsqueeze(0)).squeeze(0).numpy()
67
 
68
  # Create pipeline and process
69
  asr_pipeline = create_pipeline(model_url)
 
77
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Nepali Speech Recognition with Whisper Models")

    # NOTE: component creation order determines on-page layout — keep stable.
    model_dropdown = gr.Dropdown(
        choices=model_urls,
        label="Select Model",
        value=model_urls[0],
    )
    audio_input = gr.Audio(
        type="numpy",
        label="Input Audio",
    )
    output_text = gr.Textbox(label="Transcription")
    transcribe_button = gr.Button("Transcribe")

    # Transcribe on demand: the button forwards the selected model and the
    # recorded audio to process_audio and shows the resulting text.
    transcribe_button.click(
        fn=process_audio,
        inputs=[model_dropdown, audio_input],
        outputs=output_text,
    )

demo.launch()