{ "cells": [ { "cell_type": "code", "execution_count": 1, "source": [ "from transformers import FlaxRobertaModel, RobertaTokenizerFast\n", "from datasets import load_dataset\n", "import jax\n", "\n", "dataset = load_dataset('oscar', \"unshuffled_deduplicated_en\", split='train', streaming=True)\n", "\n", "dummy_input = next(iter(dataset))[\"text\"]\n", "\n", "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n", "input_ids = tokenizer(dummy_input, return_tensors=\"np\").input_ids[:, :10]\n", "\n", "model = FlaxRobertaModel.from_pretrained(\"julien-c/dummy-unknown\")\n", "\n", "# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`\n", "model(input_ids)" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading: 0%| | 0.00/5.58k [00:00 512). Running this sequence through the model will result in indexing errors\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading: 0%| | 0.00/496 [00:00