asahi417 commited on
Commit
da4f293
1 Parent(s): 9768dae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -9,16 +9,27 @@ import tempfile
9
  import os
10
 
11
  MODEL_NAME = "kotoba-tech/kotoba-whisper-v1.0"
12
- BATCH_SIZE = 8
 
13
  FILE_LIMIT_MB = 1000
14
  YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
15
 
16
- device = 0 if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
17
  pipe = pipeline(
18
  task="automatic-speech-recognition",
19
  model=MODEL_NAME,
20
- chunk_length_s=15,
 
21
  device=device,
 
22
  )
23
 
24
 
@@ -26,7 +37,8 @@ pipe = pipeline(
26
  def transcribe(inputs):
27
  if inputs is None:
28
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
29
- return pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
 
30
 
31
 
32
  def _return_yt_html_embed(yt_url):
@@ -68,7 +80,8 @@ def yt_transcribe(yt_url, max_filesize=75.0):
68
  inputs = f.read()
69
  inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
70
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
71
- text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
 
72
  return html_embed_str, text
73
 
74
 
 
9
  import os
10
 
11
  MODEL_NAME = "kotoba-tech/kotoba-whisper-v1.0"
12
+ BATCH_SIZE = 16
13
+ CHUNK_LENGTH_S = 15
14
  FILE_LIMIT_MB = 1000
15
  YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
16
 
17
+ if torch.cuda.is_available():
18
+ torch_dtype = torch.bfloat16
19
+ device = "cuda:0"
20
+ model_kwargs = {'attn_implementation': 'sdpa'}
21
+ else:
22
+ torch_dtype = torch.float32
23
+ device = "cpu"
24
+ model_kwargs = {}
25
+
26
  pipe = pipeline(
27
  task="automatic-speech-recognition",
28
  model=MODEL_NAME,
29
+ chunk_length_s=CHUNK_LENGTH_S,
30
+ torch_dtype=torch_dtype,
31
  device=device,
32
+ model_kwargs=model_kwargs
33
  )
34
 
35
 
 
37
  def transcribe(inputs):
38
  if inputs is None:
39
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
40
+ generate_kwargs = {"language": "japanese", "task": "transcribe"}
41
+ return pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
42
 
43
 
44
  def _return_yt_html_embed(yt_url):
 
80
  inputs = f.read()
81
  inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
82
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
83
+ generate_kwargs = {"language": "japanese", "task": "transcribe"}
84
+ text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
85
  return html_embed_str, text
86
 
87