mariagrandury committed
Commit b8db3fd • 1 Parent(s): 5670b4e
Set up
Files changed:
- .gitignore +2 -0
- requirements.txt +12 -0
- test_setup.py +17 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+venv/
+data_cache/
requirements.txt
ADDED
@@ -0,0 +1,12 @@
+librosa
+ffmpeg
+dataclasses
+pathlib
+tqdm
+numpy
+tensorflow
+datasets
+transformers
+flax
+jax
+optax
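A quick way to sanity-check these requirements after installation is to import the third-party packages and print their versions; the snippet below is a hypothetical helper, not part of this commit. dataclasses and pathlib are already in the Python standard library on 3.7+, and ffmpeg is typically needed as a system binary for audio decoding rather than as an importable module, so they are left out of the check.

# Hypothetical environment check (assumption, not part of the commit):
# import the installable packages from requirements.txt and report versions.
import importlib

packages = ["librosa", "tqdm", "numpy", "tensorflow", "datasets",
            "transformers", "flax", "jax", "optax"]

for name in packages:
    module = importlib.import_module(name)
    print(f"{name}: {getattr(module, '__version__', 'unknown')}")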
test_setup.py
ADDED
@@ -0,0 +1,17 @@
+# from https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#tpu-vm
+from transformers import FlaxRobertaModel, RobertaTokenizerFast
+from datasets import load_dataset
+import jax
+
+dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
+
+dummy_input = next(iter(dataset))["text"]
+
+tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
+input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
+
+model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
+
+# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
+model(input_ids)
+print("hello!")
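test_setup.py follows the TPU VM example linked in its first comment; a minimal companion sketch (assumed here, not included in the commit) can confirm that JAX actually sees the TPU backend before the forward pass is attempted.

# Hypothetical TPU visibility check (assumption, not part of this commit).
import jax

# On a correctly configured TPU VM this reports the "tpu" backend and the
# attached TPU cores; on a CPU-only machine it falls back to "cpu".
print("backend:", jax.default_backend())
print("device count:", jax.device_count())
print("devices:", jax.devices())

If the backend prints as "cpu" on a TPU VM, the JAX TPU installation is usually the first thing to revisit before running test_setup.py.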