hubert-dementia-screening / run_hubert_classifier.py
patrickvonplaten's picture
Update run_hubert_classifier.py
47a6a98
#!/usr/bin/env python3
from hubert_for_sequence_classification import FlaxHubertForSequenceClassification, FlaxHubertModel
import numpy as np
# Smoke test: build a HuBERT sequence classifier on top of the pretrained
# "facebook/hubert-large-ll60k" encoder and run one dummy batch through it.

pretrained_id = "facebook/hubert-large-ll60k"
local_dir = "./"

# Save/reload workaround for a transformers bug
# (https://github.com/huggingface/transformers/issues/12532):
# 1. Convert the base model from its PyTorch weights (from_pt=True).
# 2. Write the resulting Flax checkpoint into the current directory.
# 3. Load the classifier from that local checkpoint.
# `model` is rebound on purpose so the base model is not kept referenced.
# NOTE(review): the classification head is presumably absent from the saved
# base checkpoint and therefore freshly initialized. If so, the logits below
# are not meaningful predictions. Confirm against
# hubert_for_sequence_classification.
model = FlaxHubertModel.from_pretrained(pretrained_id, from_pt=True)
model.save_pretrained(local_dir)
model = FlaxHubertForSequenceClassification.from_pretrained(local_dir)

# Constant all-ones float32 batch of shape (batch_size, num_samples).
# Presumably this stands in for raw audio samples; TODO confirm expected input.
batch_size, num_samples = 2, 1024
dummy_input = np.ones((batch_size, num_samples), dtype=np.float32)

# Logits come out as (batch_size, 2), i.e. (2, 2) for this batch.
logits = model(dummy_input).logits
print("output shape", logits.shape)