"""Smoke test: push one streamed OSCAR sample through a dummy Flax RoBERTa model."""

import jax  # not referenced below; kept because the original script imports it
from datasets import load_dataset
from transformers import FlaxRobertaModel, RobertaTokenizerFast

# Open the English deduplicated OSCAR train split in streaming mode and take
# the text of the first record as a throwaway input.
dataset = load_dataset(
    "oscar",
    "unshuffled_deduplicated_en",
    split="train",
    streaming=True,
)
dummy_input = next(iter(dataset))["text"]

# Tokenize to NumPy arrays, then keep only the first 10 positions along the
# second axis of the id matrix.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
encoded = tokenizer(dummy_input, return_tensors="np")
input_ids = encoded.input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# Forward pass; the result is expected to be a `FlaxBaseModelOutputWithPooling`.
z = model(input_ids)