a43992899 committed on
Commit
b3b4638
1 Parent(s): ca2f72a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -6,7 +6,7 @@ inference: false
6
  A simple use case:
7
 
8
  ```shell
9
- from transformers import Wav2Vec2Processor, AutoModel
10
  import torch
11
  from torch import nn
12
  from datasets import load_dataset
@@ -15,10 +15,10 @@ from datasets import load_dataset
15
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
16
  dataset = dataset.sort("id")
17
  sampling_rate = dataset.features["audio"].sampling_rate
18
- processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
19
 
20
  # loading our model weights
21
- model = AutoModel.from_pretrained("m-a-p/MERT-v0")
22
 
23
  # audio file is decoded on the fly
24
  inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
@@ -36,6 +36,6 @@ print(time_reduced_hidden_states.shape) # [13, 768]
36
 
37
  # you can even use a learnable weighted average representation
38
  aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
39
- weighted_avg_hidden_states = aggregator(time_reduced_hidden_states).squeeze()
40
  print(weighted_avg_hidden_states.shape) # [768]
41
  ```
 
6
  A simple use case:
7
 
8
  ```shell
9
+ from transformers import Wav2Vec2Processor, HubertModel
10
  import torch
11
  from torch import nn
12
  from datasets import load_dataset
 
15
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
16
  dataset = dataset.sort("id")
17
  sampling_rate = dataset.features["audio"].sampling_rate
18
+ processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
19
 
20
  # loading our model weights
21
+ model = HubertModel.from_pretrained("m-a-p/MERT-v0")
22
 
23
  # audio file is decoded on the fly
24
  inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
 
36
 
37
  # you can even use a learnable weighted average representation
38
  aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
39
+ weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
40
  print(weighted_avg_hidden_states.shape) # [768]
41
  ```