mariagrandury committed on
Commit
b8db3fd
1 Parent(s): 5670b4e
Files changed (3) hide show
  1. .gitignore +2 -0
  2. requirements.txt +12 -0
  3. test_setup.py +17 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv/
2
+ data_cache/
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librosa
2
+ ffmpeg
3
+ dataclasses
4
+ pathlib
5
+ tqdm
6
+ numpy
7
+ tensorflow
8
+ datasets
9
+ transformers
10
+ flax
11
+ jax
12
+ optax
test_setup.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#tpu-vm
2
+ from transformers import FlaxRobertaModel, RobertaTokenizerFast
3
+ from datasets import load_dataset
4
+ import jax
5
+
6
+ dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
7
+
8
+ dummy_input = next(iter(dataset))["text"]
9
+
10
+ tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
11
+ input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
12
+
13
+ model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
14
+
15
+ # Run a forward pass; it should return a `FlaxBaseModelOutputWithPooling` object.
16
+ model(input_ids)
17
+ print("hello!")