lewtun HF staff committed on
Commit
adb6f04
1 Parent(s): aaf8abe

Add model file and checkpoint

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. hubert_sd.ckpt +3 -0
  3. model.py +55 -0
.gitattributes CHANGED
@@ -15,3 +15,4 @@
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
18
+ hubert_sd.ckpt filter=lfs diff=lfs merge=lfs -text
hubert_sd.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a6414fe26484dd73d1696a3f8dffa1499747307eb7541ea715ffea4c678fba8
3
+ size 31526316
model.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is just an example of what people would submit for inference.
3
+ """
4
+
5
+ import os
6
+ from typing import Dict
7
+
8
+ import torch
9
+ from s3prl.downstream.runner import Runner
10
+
11
+
12
class PreTrainedModel(Runner):
    def __init__(self, path=""):
        """
        Load the speaker-diarization checkpoint and initialize the s3prl
        Runner in CPU inference mode.

        Args:
            path: Directory containing ``hubert_sd.ckpt``. Defaults to the
                current working directory.
        """
        ckp_file = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        ckp = torch.load(ckp_file, map_location="cpu")
        # Point the saved args back at this checkpoint and force CPU
        # inference so the model runs without a GPU.
        ckp["Args"].init_ckpt = ckp_file
        ckp["Args"].mode = "inference"
        ckp["Args"].device = "cpu"

        # Cooperative super() call instead of the hard-coded
        # Runner.__init__(self, ...) — same behavior, idiomatic and safe if
        # the inheritance hierarchy ever changes.
        super().__init__(ckp["Args"], ckp["Config"])

    def __call__(self, inputs) -> Dict[str, str]:
        """
        Run speaker-diarization inference on a raw waveform.

        Args:
            inputs (np.array): Raw audio waveform, by default at 16 kHz.

        Returns:
            A dict such as ``{"frames": "XXX"}`` describing the frames in
            which one, both, or neither speaker is speaking.
        """
        # Put every sub-module (upstream, featurizer, downstream) in eval
        # mode so dropout/batch-norm are deterministic.
        for entry in self.all_entries:
            entry.model.eval()

        # The s3prl upstream expects a list of 1-D float tensors.
        inputs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            features = self.upstream.model(inputs)
            features = self.featurizer.model(inputs, features)
            preds = self.downstream.model.inference(features, [])
        return preds[0]
44
+
45
+
46
+ """
47
+ import io
48
+ import soundfile as sf
49
+ from urllib.request import urlopen
50
+
51
+ model = PreTrainedModel()
52
+ url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
53
+ data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
54
+ print(model(data))
55
+ """