jlondonobo committed
Commit 75d8ce0 · 1 Parent(s): 1be8515

🌟 convert to whisper

Files changed (3)
  1. app.py +9 -11
  2. hf_to_whisper.py +70 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,22 +1,20 @@
 import gradio as gr
 import pytube as pt
 import torch
-from huggingface_hub import model_info
-from transformers import pipeline
+import whisper
+from hf_to_whisper import write_whisper_model_to_memory
+import os
 
 MODEL_NAME = "jlondonobo/whisper-medium-pt" #this always needs to stay in line 8 :D sorry for the hackiness
 lang = "pt"
 
 device = 0 if torch.cuda.is_available() else "cpu"
 
-pipe = pipeline(
-    task="automatic-speech-recognition",
-    model=MODEL_NAME,
-    chunk_length_s=30,
-    device=device,
-)
+local_model_path = "whisper-pt.pt"
+if not os.path.exists(local_model_path):
+    write_whisper_model_to_memory(MODEL_NAME, local_model_path)
 
-pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
+model = whisper.load_model(local_model_path)
 
 def transcribe(microphone, file_upload):
     warn_output = ""
@@ -31,7 +29,7 @@ def transcribe(microphone, file_upload):
 
     file = microphone if microphone is not None else file_upload
 
-    text = pipe(file)["text"]
+    text = model.transcribe(file)["text"]
 
     return warn_output + text
 
@@ -51,7 +49,7 @@ def yt_transcribe(yt_url):
    stream = yt.streams.filter(only_audio=True)[0]
    stream.download(filename="audio.mp3")
 
-    text = pipe("audio.mp3")["text"]
+    text = model.transcribe("audio.mp3", language=lang)["text"]
 
     return html_embed_str, text

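In the last hunk, note that openai-whisper's transcribe() gathers unknown keyword arguments into **decode_options and forwards them to DecodingOptions, so decoding options such as the language are passed as plain keywords rather than as a dict. A minimal sketch of exercising the converted checkpoint outside Gradio, assuming the step above has already written whisper-pt.pt and that audio.mp3 is any local audio file (both names are taken from app.py):

import whisper

# load_model() accepts a filesystem path as well as an official model name,
# so the converted checkpoint loads like a built-in one.
model = whisper.load_model("whisper-pt.pt")

# Extra keyword arguments flow into DecodingOptions; language and task are
# therefore plain kwargs, not a decode_options dict.
result = model.transcribe("audio.mp3", language="pt", task="transcribe")
print(result["text"])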
hf_to_whisper.py ADDED
@@ -0,0 +1,70 @@
+# Original script: bayartsogt-ya/whisper-multiple-hf-datasets
+from copy import deepcopy
+import torch
+from transformers import WhisperForConditionalGeneration
+
+
+WHISPER_MAPPING = {
+    "layers": "blocks",
+    "fc1": "mlp.0",
+    "fc2": "mlp.2",
+    "final_layer_norm": "mlp_ln",
+    ".self_attn.q_proj": ".attn.query",
+    ".self_attn.k_proj": ".attn.key",
+    ".self_attn.v_proj": ".attn.value",
+    ".self_attn_layer_norm": ".attn_ln",
+    ".self_attn.out_proj": ".attn.out",
+    ".encoder_attn.q_proj": ".cross_attn.query",
+    ".encoder_attn.k_proj": ".cross_attn.key",
+    ".encoder_attn.v_proj": ".cross_attn.value",
+    ".encoder_attn_layer_norm": ".cross_attn_ln",
+    ".encoder_attn.out_proj": ".cross_attn.out",
+    "decoder.layer_norm.": "decoder.ln.",
+    "encoder.layer_norm.": "encoder.ln_post.",
+    "embed_tokens": "token_embedding",
+    "encoder.embed_positions.weight": "encoder.positional_embedding",
+    "decoder.embed_positions.weight": "decoder.positional_embedding",
+    "layer_norm": "ln_post",
+}
+
+
+def rename_keys(s_dict):
+    # Rewrite every Hugging Face key into its openai-whisper counterpart.
+    keys = list(s_dict.keys())
+    for key in keys:
+        new_key = key
+        for k, v in WHISPER_MAPPING.items():
+            if k in key:
+                new_key = new_key.replace(k, v)
+
+        print(f"{key} -> {new_key}")
+
+        s_dict[new_key] = s_dict.pop(key)
+    return s_dict
+
+
+def write_whisper_model_to_memory(
+    hf_model_name_or_path: str,
+    whisper_state_path: str
+):
+    transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
+    config = transformer_model.config
+
+    # Build the dims dict that whisper's ModelDimensions expects
+    dims = {
+        'n_mels': config.num_mel_bins,
+        'n_vocab': config.vocab_size,
+        'n_audio_ctx': config.max_source_positions,
+        'n_audio_state': config.d_model,
+        'n_audio_head': config.encoder_attention_heads,
+        'n_audio_layer': config.encoder_layers,
+        'n_text_ctx': config.max_target_positions,
+        'n_text_state': config.d_model,
+        'n_text_head': config.decoder_attention_heads,
+        'n_text_layer': config.decoder_layers
+    }
+
+    state_dict = deepcopy(transformer_model.model.state_dict())
+    state_dict = rename_keys(state_dict)
+
+    torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
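
The converter only renames state-dict keys and records model dimensions; the tensors themselves are copied unchanged, so the written file should load through whisper.load_model() exactly like an official checkpoint. A quick sanity-check sketch, assuming openai-whisper is installed; openai/whisper-tiny is chosen only to keep the download small, and any Hugging Face Whisper checkpoint should convert the same way:

import torch
import whisper
from hf_to_whisper import write_whisper_model_to_memory

# Convert a small Hugging Face checkpoint and write it to disk.
write_whisper_model_to_memory("openai/whisper-tiny", "whisper-tiny.pt")

# The saved file carries exactly the two keys whisper.load_model() reads.
checkpoint = torch.load("whisper-tiny.pt", map_location="cpu")
assert set(checkpoint) == {"dims", "model_state_dict"}

# load_model() rebuilds the architecture from "dims" and then does a strict
# state-dict load, so any fragment missing from WHISPER_MAPPING surfaces
# here as a key error rather than as silently wrong transcriptions.
model = whisper.load_model("whisper-tiny.pt")
print(model.dims)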
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 git+https://github.com/huggingface/transformers
+git+https://github.com/openai/whisper.git
 torch
 pytube