yizhilll committed on
Commit 91f804c
1 Parent(s): 7a16d1f

Update README.md

Files changed (1)
  1. README.md +24 -6
README.md CHANGED
@@ -53,30 +53,46 @@ Larger models trained with more data are on the way.
  # Model Usage
  
  ```python
- from transformers import Wav2Vec2Processor
+ from transformers import Wav2Vec2FeatureExtractor
  from transformers import AutoModel
  import torch
  from torch import nn
+ import torchaudio.transforms as T
  from datasets import load_dataset
  
+
+ # loading our model weights
+ model = AutoModel.from_pretrained("m-a-p/MERT-v0", trust_remote_code=True)
+ # loading the corresponding preprocessor config
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0", trust_remote_code=True)
+
  # load demo audio and set processor
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  dataset = dataset.sort("id")
  sampling_rate = dataset.features["audio"].sampling_rate
- processor = Wav2Vec2Processor.from_pretrained("m-a-p/MERT-v0")
  
- # loading our model weights
- model = AutoModel.from_pretrained("m-a-p/MERT-v0")
+ resample_rate = processor.sampling_rate
+ # make sure the sample rate is aligned
+ if resample_rate != sampling_rate:
+     print(f'setting rate from {sampling_rate} to {resample_rate}')
+     resampler = T.Resample(sampling_rate, resample_rate)
+ else:
+     resampler = None
  
  # audio file is decoded on the fly
- inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+ if resampler is None:
+     input_audio = dataset[0]["audio"]["array"]
+ else:
+     input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
+
+ inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
  with torch.no_grad():
      outputs = model(**inputs, output_hidden_states=True)
  
  # take a look at the output shape, there are 13 layers of representation
  # each layer performs differently in different downstream tasks, you should choose empirically
  all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
- print(all_layer_hidden_states.shape) # [13 layer, 292 timestep, 768 feature_dim]
+ print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
  
  # for utterance level classification tasks, you can simply reduce the representation in time
  time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
@@ -86,6 +102,8 @@ print(time_reduced_hidden_states.shape) # [13, 768]
  aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
  weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
  print(weighted_avg_hidden_states.shape) # [768]
+
+
  ```
  
  # Citation
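
In the updated snippet, the 13 layer representations are averaged over time and then collapsed with an untrained `nn.Conv1d`, and the comments note that the most useful layer varies by downstream task. When a probe is trained on top of the frozen model, a learnable softmax-weighted combination of the layers is a common alternative. The sketch below illustrates that idea; `WeightedLayerProbe`, `num_classes`, and the random stand-in tensor are illustrative assumptions, not part of the MERT-v0 README or API.

```python
import torch
from torch import nn

# Stand-in for `time_reduced_hidden_states` from the snippet above, shape [13, 768].
time_reduced_hidden_states = torch.randn(13, 768)

class WeightedLayerProbe(nn.Module):
    """Learnable softmax-weighted layer combination followed by a linear classifier."""
    def __init__(self, num_layers=13, hidden_dim=768, num_classes=10):
        super().__init__()
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))  # one learnable scalar per layer
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, layer_features):                                # [num_layers, hidden_dim]
        weights = torch.softmax(self.layer_weights, dim=0)            # [num_layers]
        pooled = (weights.unsqueeze(-1) * layer_features).sum(dim=0)  # [hidden_dim]
        return self.classifier(pooled)                                # [num_classes]

probe = WeightedLayerProbe()
logits = probe(time_reduced_hidden_states)
print(logits.shape)  # torch.Size([10])
```

Trained jointly with the downstream objective, the softmax weights also give a rough indication of which layers matter most for the task.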