Update README.md
Browse files
README.md
CHANGED
@@ -37,7 +37,7 @@ It achieves Unweighed Average Recall (UAR) values of .8001 and .8084 on the deve
|
|
37 |
- **Paper [optional]:** [More Information Needed]
|
38 |
|
39 |
|
40 |
-
##
|
41 |
|
42 |
The following snippet illustrates the usage of the model.
|
43 |
```python
|
@@ -49,13 +49,15 @@ import librosa
|
|
49 |
checkpoint = "chrlukas/flattery_prediction_speech"
|
50 |
processor = AutoFeatureExtractor.from_pretrained(checkpoint)
|
51 |
model = Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint)
|
|
|
52 |
|
53 |
# predict flattery in a sentence
|
54 |
example_file = 'example.wav'
|
55 |
# audio must be resampled to 16Hz
|
56 |
y, _ = librosa.load(test_file, sr=16000)
|
57 |
inp = processor(y, sampling_rate=16000, return_tensors='pt')
|
58 |
-
|
|
|
59 |
prediction = sigmoid(logits).item()
|
60 |
flattery = prediction >= 0.5
|
61 |
print(f'Flattery detected? {flattery}')
|
|
|
37 |
- **Paper [optional]:** [More Information Needed]
|
38 |
|
39 |
|
40 |
+
## Usage
|
41 |
|
42 |
The following snippet illustrates the usage of the model.
|
43 |
```python
|
|
|
49 |
checkpoint = "chrlukas/flattery_prediction_speech"
|
50 |
processor = AutoFeatureExtractor.from_pretrained(checkpoint)
|
51 |
model = Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint)
|
52 |
+
model.eval()
|
53 |
|
54 |
# predict flattery in a sentence
|
55 |
example_file = 'example.wav'
|
56 |
# audio must be resampled to 16Hz
|
57 |
y, _ = librosa.load(test_file, sr=16000)
|
58 |
inp = processor(y, sampling_rate=16000, return_tensors='pt')
|
59 |
+
with torch.no_grad():
|
60 |
+
logits = model(**inp).logits
|
61 |
prediction = sigmoid(logits).item()
|
62 |
flattery = prediction >= 0.5
|
63 |
print(f'Flattery detected? {flattery}')
|