#!/usr/bin/env python3
"""Smoke-test FlaxHubertForSequenceClassification on a dummy batch.

Converts the PyTorch ``facebook/hubert-large-ll60k`` checkpoint to Flax,
runs the sequence-classification model on a constant input batch, and prints
the shape of the resulting logits.
"""
import tempfile

import numpy as np

from hubert_for_sequence_classification import (
    FlaxHubertForSequenceClassification,
    FlaxHubertModel,
)

PRETRAINED_CHECKPOINT = "facebook/hubert-large-ll60k"


def main() -> None:
    """Load the classifier via a save/reload round trip and print the logits shape."""
    # Workaround for https://github.com/huggingface/transformers/issues/12532:
    # instead of loading the PyTorch checkpoint straight into the
    # classification class, convert it into the base FlaxHubertModel first,
    # save those Flax weights, then reload them into
    # FlaxHubertForSequenceClassification.
    # The round trip goes through a temporary directory so the script does not
    # write (or overwrite) config/weight files in the current working directory.
    base_model = FlaxHubertModel.from_pretrained(PRETRAINED_CHECKPOINT, from_pt=True)
    with tempfile.TemporaryDirectory() as tmp_dir:
        base_model.save_pretrained(tmp_dir)
        model = FlaxHubertForSequenceClassification.from_pretrained(tmp_dir)

    # Batch of 2 constant inputs of 1024 float32 values each; identical to the
    # nested-list literal np.array(2 * [1024 * [1.0]], dtype=np.float32).
    # NOTE(review): presumably raw waveform samples; confirm against the model.
    dummy_input = np.ones((2, 1024), dtype=np.float32)
    logits = model(dummy_input).logits

    # Expected shape: (batch_size, 2); presumably 2 is the config's label count.
    print("output shape", logits.shape)


if __name__ == "__main__":
    main()