lewtun (HF staff) committed
Commit 17c79b3
1 Parent(s): 84cea97

Use example expert from s3prl

Files changed (1)
{{cookiecutter.repo_name}}/expert.py +64 -43
{{cookiecutter.repo_name}}/expert.py CHANGED
@@ -1,56 +1,77 @@
-from packaging import version

-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence

-import fairseq
-from s3prl.upstream.interfaces import UpstreamBase


-SAMPLE_RATE = 16000
-EXAMPLE_SEC = 5

-class UpstreamExpert(UpstreamBase):
-    def __init__(self, ckpt, **kwargs):
-        super().__init__(**kwargs)
-        assert version.parse(fairseq.__version__) > version.parse(
-            "0.10.2"
-        ), "Please install the fairseq master branch."

-        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
-            [ckpt]
         )
-        self.model = model[0]
-        self.task = task
-
-        if len(self.hooks) == 0:
-            module_name = "self.model.encoder.layers"
-            for module_id in range(len(eval(module_name))):
-                self.add_hook(
-                    f"{module_name}[{module_id}]",
-                    lambda input, output: input[0].transpose(0, 1),
-                )
-            self.add_hook("self.model.encoder", lambda input, output: output[0])
-
-    def forward(self, wavs):
-        if self.task.cfg.normalize:
-            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
-
-        device = wavs[0].device
-        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
-        wav_padding_mask = ~torch.lt(
-            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
-            wav_lengths.unsqueeze(1),
         )
-        padded_wav = pad_sequence(wavs, batch_first=True)

-        features, feat_padding_mask = self.model.extract_features(
-            padded_wav,
-            padding_mask=wav_padding_mask,
-            mask=None,
-        )
         return {
-            "default": features,
         }
+from collections import OrderedDict
+from typing import List, Union, Dict

 import torch.nn as nn
+from torch import Tensor
 from torch.nn.utils.rnn import pad_sequence

+HIDDEN_DIM = 8


+class UpstreamExpert(nn.Module):
+    def __init__(self, ckpt: str = None, model_config: str = None, **kwargs):
+        """
+        Args:
+            ckpt:
+                The checkpoint path for loading your pretrained weights.
+                Can be assigned by the -k option in run_downstream.py

+            model_config:
+                The config path for constructing your model.
+                Might not be needed if you also save that in your checkpoint file.
+                Can be assigned by the -g option in run_downstream.py
+        """
+        super().__init__()
+        self.name = "[Example UpstreamExpert]"

+        print(
+            f"{self.name} - You can use model_config to construct your customized model: {model_config}"
         )
+        print(f"{self.name} - You can use ckpt to load your pretrained weights: {ckpt}")
+        print(
+            f"{self.name} - If you store the pretrained weights and model config in a single file, "
+            "you can just choose one argument (ckpt or model_config) to pass. It's up to you!"
         )

+        # The model needs to be an nn.Module for finetuning; it is not required for representation extraction
+        self.model1 = nn.Linear(1, HIDDEN_DIM)
+        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
+
+    def get_downsample_rates(self, key: str) -> int:
+        """
+        Since we do not do any downsampling in this example upstream,
+        all keys' corresponding representations have a downsample rate of 1.
+        """
+        return 1
+
+    def forward(self, wavs: List[Tensor]) -> Dict[str, Union[Tensor, List[Tensor]]]:
+        """
+        When the returned Dict contains a List with more than one Tensor,
+        those Tensors should have the same shape so that a weighted sum can be trained on them.
+        """
+
+        wavs = pad_sequence(wavs, batch_first=True).unsqueeze(-1)
+        # wavs: (batch_size, max_len, 1)
+
+        hidden = self.model1(wavs)
+        # hidden: (batch_size, max_len, hidden_dim)
+
+        feature = self.model2(hidden)
+        # feature: (batch_size, max_len, hidden_dim)
+
+        # The "hidden_states" key will be used as the default in many cases
+        # The other keys in this example are presented for the SUPERB Challenge
         return {
+            "hidden_states": [hidden, feature],
+            "PR": [hidden, feature],
+            "ASR": [hidden, feature],
+            "QbE": [hidden, feature],
+            "SID": [hidden, feature],
+            "ASV": [hidden, feature],
+            "SD": [hidden, feature],
+            "ER": [hidden, feature],
+            "SF": [hidden, feature],
+            "SE": [hidden, feature],
+            "SS": [hidden, feature],
+            "secret": [hidden, feature],
         }
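
For reference, a minimal sketch of how the new example expert can be exercised, mirroring how s3prl's downstream runner calls an upstream. The import path and the dummy waveforms are assumptions for illustration only:

import torch

from expert import UpstreamExpert  # assumed path: the templated {{cookiecutter.repo_name}}/expert.py

# The example expert does not actually read ckpt or model_config, so None is fine here.
expert = UpstreamExpert(ckpt=None, model_config=None)

# s3prl passes a batch of variable-length mono waveforms as a list of 1-D tensors.
wavs = [torch.randn(16000), torch.randn(8000)]

outputs = expert(wavs)
hidden, feature = outputs["hidden_states"]
print(hidden.shape)  # torch.Size([2, 16000, 8]): (batch_size, max_len, HIDDEN_DIM)
print(expert.get_downsample_rates("hidden_states"))  # 1

Both tensors under "hidden_states" share one shape, so a downstream model can train a weighted sum over them, as the forward docstring requires.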