File size: 601 Bytes
004e84f
 
 
 
47a6a98
004e84f
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#!/usr/bin/env python3
from hubert_for_sequence_classification import FlaxHubertForSequenceClassification, FlaxHubertModel
import numpy as np

# The save/reload round-trip below works around a bug, see
# https://github.com/huggingface/transformers/issues/12532 :
# load the checkpoint through the base model class with from_pt=True, write
# it to the current directory, then reload it from there with the
# sequence-classification class.
FlaxHubertModel.from_pretrained(
    "facebook/hubert-large-ll60k", from_pt=True
).save_pretrained("./")
model = FlaxHubertForSequenceClassification.from_pretrained("./")

# Dummy batch for a smoke test: 2 rows of 1024 float32 ones.
dummy_input = np.ones((2, 1024), dtype=np.float32)

outputs = model(dummy_input)
logits = outputs.logits

# expected output shape: (batch_size, 2)
print("output shape", logits.shape)