jimbozhang committed on
Commit
db12bbd
1 Parent(s): 9cd3efa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -7
README.md CHANGED
@@ -31,20 +31,21 @@ pip install -r requirements.txt
31
  >>> from ced_model.feature_extraction_ced import CedFeatureExtractor
32
  >>> from ced_model.modeling_ced import CedForAudioClassification
33
 
34
- >>> model_id = "mispeech/ced-base"
35
- >>> feature_extractor = CedFeatureExtractor.from_pretrained(model_id)
36
- >>> model = CedForAudioClassification.from_pretrained(model_id)
37
 
38
  >>> import torchaudio
39
  >>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
40
-
41
  >>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
 
 
42
  >>> with torch.no_grad():
43
  ...     logits = model(**inputs).logits
44
 
45
- >>> import torch
46
- >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
47
- >>> model.config.id2label[predicted_class_ids]
48
  'Finger snapping'
49
  ```
50
 
 
31
  >>> from ced_model.feature_extraction_ced import CedFeatureExtractor
32
  >>> from ced_model.modeling_ced import CedForAudioClassification
33
 
34
+ >>> model_name = "mispeech/ced-base"
35
+ >>> feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
36
+ >>> model = CedForAudioClassification.from_pretrained(model_name)
37
 
38
  >>> import torchaudio
39
  >>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
40
+ >>> assert sampling_rate == 16000
41
  >>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
42
+
43
+ >>> import torch
44
  >>> with torch.no_grad():
45
  ...     logits = model(**inputs).logits
46
 
47
+ >>> predicted_class_id = torch.argmax(logits, dim=-1).item()
48
+ >>> model.config.id2label[predicted_class_id]
 
49
  'Finger snapping'
50
  ```
51