voidful committed on
Commit 61e51f4
1 Parent(s): 6683b74

prediction code with GPT

Files changed (1): README.md (+19 -5)
README.md CHANGED
@@ -19,20 +19,25 @@ from datasets import load_dataset, load_metric
 from transformers import (
     Wav2Vec2ForCTC,
     Wav2Vec2Processor,
+    AutoTokenizer,
+    AutoModelWithLMHead
 )
 import torch
 import re
 import sys
 
-model_name = "voidful/wav2vec2-large-xlsr-53-tw"
+model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
 device = "cuda"
-processor_name = "voidful/wav2vec2-large-xlsr-53-tw"
+processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
 
 chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"
 
 model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
 processor = Wav2Vec2Processor.from_pretrained(processor_name)
 
+tokenizer = AutoTokenizer.from_pretrained("ckiplab/gpt2-base-chinese")
+gpt_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)
+
 resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
 
 def load_file_to_data(file):
@@ -42,16 +47,25 @@ def load_file_to_data(file):
     batch["sampling_rate"] = resampler.new_freq
     return batch
 
-
 def predict(data):
     features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
     input_values = features.input_values.to(device)
     attention_mask = features.attention_mask.to(device)
     with torch.no_grad():
         logits = model(input_values, attention_mask=attention_mask).logits
-    pred_ids = torch.argmax(logits, dim=-1)
-    return processor.batch_decode(pred_ids)
 
+    decoded_results = []
+    for logit in logits:
+        pred_ids = torch.argmax(logit, dim=-1)
+        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
+        vocab_size = logit.size()[-1]
+        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
+        gpt_input = torch.cat((torch.tensor([tokenizer.cls_token_id]).to(device), pred_ids[pred_ids > 0]), 0)
+        gpt_prob = torch.nn.functional.softmax(gpt_model(gpt_input).logits, dim=-1)[:voice_prob.size()[0], :]
+        comb_pred_ids = torch.argmax(gpt_prob * voice_prob, dim=-1)
+        decoded_results.append(processor.decode(comb_pred_ids))
+
+    return decoded_results
 ```
 
 Predict
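
The heart of the change is the new decoding loop: instead of returning the raw CTC argmax, each utterance's non-blank frames are re-scored by ckiplab/gpt2-base-chinese, and the acoustic and language-model distributions are multiplied element-wise before the final argmax (token id 0 is treated as the CTC blank/padding, which is what `pred_ids.ge(1)` and `pred_ids > 0` assume). The sketch below replays that fusion rule on dummy tensors so the shapes are easy to follow; the 10-token vocabulary, the random logits, and the stand-in for the GPT forward pass are illustrative assumptions, not part of the commit, which additionally relies on the acoustic model and the LM sharing one vocabulary.

```
import torch

# Replay of the fusion rule used in predict() above, on dummy tensors.
# Illustrative assumptions (not from the commit): a 10-token shared
# vocabulary, random acoustic logits, and a random stand-in LM.
torch.manual_seed(0)
frames, vocab_size = 6, 10

logit = torch.randn(frames, vocab_size)   # per-frame acoustic logits
pred_ids = torch.argmax(logit, dim=-1)    # greedy CTC ids per frame

# Keep only frames whose greedy id is non-blank (id 0 = blank/pad).
mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
voice_prob = torch.nn.functional.softmax(
    torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1
)

# Stand-in for gpt_model(gpt_input).logits: one LM distribution per kept
# frame, already truncated to voice_prob's length as in the commit.
gpt_prob = torch.nn.functional.softmax(
    torch.randn(voice_prob.size(0), vocab_size), dim=-1
)

# Element-wise product re-ranks every frame by acoustic * LM probability.
comb_pred_ids = torch.argmax(gpt_prob * voice_prob, dim=-1)
print(comb_pred_ids)
```

Multiplying the two softmax outputs weights the acoustic and language models equally; in the committed code the GPT distribution is computed from the greedy CTC ids and truncated to the number of voiced frames before the product is taken.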
 
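The hunk ends at the README's "Predict" heading. For reference, a minimal way to exercise the updated functions, sketched under two assumptions not shown in this diff: that load_file_to_data accepts a local audio path, and that "sample.wav" stands in for a real recording.

```
# Hypothetical usage; "sample.wav" is a placeholder, not part of the commit.
data = load_file_to_data("sample.wav")  # helper resamples the clip to 16 kHz
print(predict(data))                    # GPT-rescored transcription(s)
```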