Plachta committed on
Commit
53eb71f
1 Parent(s): b4a870d

removed offloading

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -97,7 +97,7 @@ model.eval()
97
  audio_tokenizer = AudioTokenizer(device)
98
 
99
  # ASR
100
- whisper_model = whisper.load_model("medium").cpu()
101
 
102
  def clear_prompts():
103
  try:
@@ -125,7 +125,7 @@ def transcribe_one(model, audio_path):
125
  print(f"Detected language: {max(probs, key=probs.get)}")
126
  lang = max(probs, key=probs.get)
127
  # decode the audio
128
- options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=100)
129
  result = whisper.decode(model, mel, options)
130
 
131
  # print the recognized text
@@ -168,7 +168,6 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio):
168
  def make_prompt(name, wav, sr, save=True):
169
 
170
  global whisper_model
171
- whisper_model.to(device)
172
  if not isinstance(wav, torch.FloatTensor):
173
  wav = torch.tensor(wav)
174
  if wav.abs().max() > 1:
@@ -188,7 +187,6 @@ def make_prompt(name, wav, sr, save=True):
188
  os.remove(f"./prompts/{name}.wav")
189
  os.remove(f"./prompts/{name}.txt")
190
 
191
- whisper_model.cpu()
192
  torch.cuda.empty_cache()
193
  return text, lang
194
 
 
97
  audio_tokenizer = AudioTokenizer(device)
98
 
99
  # ASR
100
+ whisper_model = whisper.load_model("medium")
101
 
102
  def clear_prompts():
103
  try:
 
125
  print(f"Detected language: {max(probs, key=probs.get)}")
126
  lang = max(probs, key=probs.get)
127
  # decode the audio
128
+ options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
129
  result = whisper.decode(model, mel, options)
130
 
131
  # print the recognized text
 
168
  def make_prompt(name, wav, sr, save=True):
169
 
170
  global whisper_model
 
171
  if not isinstance(wav, torch.FloatTensor):
172
  wav = torch.tensor(wav)
173
  if wav.abs().max() > 1:
 
187
  os.remove(f"./prompts/{name}.wav")
188
  os.remove(f"./prompts/{name}.txt")
189
 
 
190
  torch.cuda.empty_cache()
191
  return text, lang
192