Commit b765185
Pedro Cuenca committed
2 parent(s): d7be08c 7e6e1fe

Merge branch 'main' of github.com:borisdayma/dalle-mini into main

README.md CHANGED
@@ -1,42 +1,76 @@
- ## DALL-E Mini - Generate image from text
-
- ## Tentative Strategy of training (proposed by Luke and Suraj)
-
- ### Data:
- * [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m) Dataset (already loaded and preprocessed in TPU VM by Luke).
- * [YFCC100M Subset](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md)
- * [Coneptual Captions 3M](https://github.com/google-research-datasets/conceptual-captions)
-
- ### Architecture:
- * Use the Taming Transformers VQ-GAN (with 16384 tokens)
- * Use a seq2seq (language encoder --> image decoder) model with a pretrained non-autoregressive encoder (e.g. BERT) and an autoregressive decoder (like GPT).
-
- ### Remaining Architecture Questions:
- * Whether to freeze the text encoder?
- * Whether to finetune the VQ-GAN?
- * Which text encoder to use (e.g. BERT, RoBERTa, etc.)?
- * Hyperparameter choices for the decoder (e.g. positional embedding, initialization, etc.)
-
- ## TODO
-
- * experiment with flax/jax and setup of the TPU instance that we should get shortly
- * work on dataset loading - [see suggested datasets](https://discuss.huggingface.co/t/dall-e-mini-version/7324/4)
- * Optionally create the OpenAI YFCC100M subset (see [this post](https://discuss.huggingface.co/t/dall-e-mini-version/7324/30?u=boris))
- * work on text/image encoding
- * concatenate inputs (not sure if we need fixed length for text or use a special token separating text & image)
- * adapt training script
- * create inference function
- * integrate CLIP for better results (only if we have the time)
- * work on a demo (streamlit or colab or maybe just HF widget)
- * document (set up repo on model hub per instructions, start on README writeup…)
- * help with coordinating activities & progress
-
-
- ## Dependencies Installation
- You should create a new python virtual environment and install the project dependencies inside the virtual env. You need to use the `-f` (`--find-links`) option for `pip` to be able to find the appropriate `libtpu` required for the TPU hardware:
+ ---
+ title: Dalle Mini
+ emoji: 🎨
+ colorFrom: red
+ colorTo: blue
+ sdk: streamlit
+ app_file: app/app.py
+ pinned: false
+ ---
+
+ # DALL-E Mini
+
+ _Generate images from a text prompt_
+
+ TODO: add some cool example
+
+ ## [Create my own images with the demo →](TODO)
+
+ ## How does it work?
+
+ Refer to [our report](TODO).
+
+ ## Development
+
+ This section is for adventurous people who want to look into the code.
+
+ ### Dependencies Installation
+
+ The root folder and its associated `requirements.txt` are only for the app.
+
+ You will find the necessary requirements in each sub-section.
+
+ You should create a new Python virtual environment and install the project dependencies inside it. You need to use the `-f` (`--find-links`) option for `pip` so that it can find the appropriate `libtpu` build required for the TPU hardware.
+
+ Adapt the installation to your own hardware and follow the library installation instructions.

  ```
  $ pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  ```

  If you use `conda`, you can create the virtual env and install everything using: `conda env update -f environments.yaml`
+
+ ### Training of VQGAN
+
+ The VQGAN was trained using [taming-transformers](https://github.com/CompVis/taming-transformers).
+
+ We recommend using the latest version available.
+
+ ### Conversion of VQGAN to JAX
+
+ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
+
+ ### Training of Seq2Seq
+
+ Refer to the `seq2seq` folder (some parameters may have been hardcoded for convenience when training on our TPU VM).
+
+ ### Inference
+
+ Refer to the demo notebooks.
+ TODO: add links
+
+ ## Authors
+
+ - [Boris Dayma](https://github.com/borisdayma)
+ - [Suraj Patil](https://github.com/patil-suraj)
+ - [Pedro Cuenca](https://github.com/pcuenca)
+ - [Khalid Saifullah](https://github.com/khalidsaifullaah)
+ - [Tanishq Abraham](https://github.com/tmabraham)
+ - [Phúc Lê Khắc](https://github.com/lkhphuc)
+ - [Luke Melas](https://github.com/lukemelas)
+ - [Ritobrata Ghosh](https://github.com/ghosh-r)
+
+ ## Acknowledgements
+
+ - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
+ - Google Cloud team for providing access to TPUs
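
The new "Conversion of VQGAN to JAX" and "Inference" sections only point at external repos and notebooks. As a minimal sketch of how the converted checkpoint ends up being consumed: the `VQModel` class, the `flax-community/vqgan_f16_16384` checkpoint and the `decode_code` call are all taken from `app/app_gradio.py` added below; the random indices and running from the repository root are illustrative assumptions only.

```python
# Minimal sketch: load the JAX VQGAN produced via patil-suraj/vqgan-jax and
# decode a dummy grid of codebook indices back to pixels, mirroring what
# app/app_gradio.py does with the indices generated by the seq2seq model.
# Assumes it is run from the repository root so `dalle_mini` is importable.
import numpy as np
import jax.numpy as jnp

from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel

vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

# 256 indices (a 16x16 latent grid) drawn from the 16384-entry codebook;
# the real app obtains these from the BART decoder instead of at random.
indices = jnp.asarray(np.random.randint(0, 16384, size=(1, 256)))
images = vqgan.decode_code(indices)  # float images in [0, 1], one per batch element
print(images.shape)
```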
app/app_gradio.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # Uncomment to run on cpu
+ #import os
+ #os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+ import random
+
+ import jax
+ import flax.linen as nn
+ from flax.training.common_utils import shard
+ from flax.jax_utils import replicate, unreplicate
+
+ from transformers.models.bart.modeling_flax_bart import *
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
+
+
+ import requests
+ from PIL import Image, ImageDraw, ImageFont  # ImageDraw/ImageFont are needed by captioned_strip below
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+
+ from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+
+ import gradio as gr
+
+
+ # TODO: set those args in a config file
+ OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
+ OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
+ BOS_TOKEN_ID = 16384
+ BASE_MODEL = 'flax-community/dalle-mini'
+
+ class CustomFlaxBartModule(FlaxBartModule):
+     def setup(self):
+         # we keep shared to easily load pre-trained weights
+         self.shared = nn.Embed(
+             self.config.vocab_size,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+             dtype=self.dtype,
+         )
+         # a separate embedding is used for the decoder
+         self.decoder_embed = nn.Embed(
+             OUTPUT_VOCAB_SIZE,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+             dtype=self.dtype,
+         )
+         self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+         # the decoder has a different config
+         decoder_config = BartConfig(self.config.to_dict())
+         decoder_config.max_position_embeddings = OUTPUT_LENGTH
+         decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
+
+ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
+     def setup(self):
+         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
+         self.lm_head = nn.Dense(
+             OUTPUT_VOCAB_SIZE,
+             use_bias=False,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+         )
+         self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+
+ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
+     module_class = CustomFlaxBartForConditionalGenerationModule
+
+ # create our model
+ # FIXME: Save tokenizer to hub so we can load from there
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
+ model.config.force_bos_token_to_be_generated = False
+ model.config.forced_bos_token_id = None
+ model.config.forced_eos_token_id = None
+
+ vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
+
+ def custom_to_pil(x):
+     x = np.clip(x, 0., 1.)
+     x = (255*x).astype(np.uint8)
+     x = Image.fromarray(x)
+     if not x.mode == "RGB":
+         x = x.convert("RGB")
+     return x
+
+ def generate(input, rng, params):
+     return model.generate(
+         **input,
+         max_length=257,
+         num_beams=1,
+         do_sample=True,
+         prng_key=rng,
+         eos_token_id=50000,
+         pad_token_id=50000,
+         params=params,
+     )
+
+ def get_images(indices, params):
+     return vqgan.decode_code(indices, params=params)
+
+ def plot_images(images):
+     fig = plt.figure(figsize=(40, 20))
+     columns = 4
+     rows = 2
+     plt.subplots_adjust(hspace=0, wspace=0)
+
+     for i in range(1, columns*rows +1):
+         fig.add_subplot(rows, columns, i)
+         plt.imshow(images[i-1])
+         plt.gca().axes.get_yaxis().set_visible(False)
+     plt.show()
+
+ def stack_reconstructions(images):
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w,0))
+     return img
+
+ p_generate = jax.pmap(generate, "batch")
+ p_get_images = jax.pmap(get_images, "batch")
+
+ bart_params = replicate(model.params)
+ vqgan_params = replicate(vqgan.params)
+
+ # ## CLIP Scoring
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ print("Initialize FlaxCLIPModel")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ print("Initialize CLIPProcessor")
+
+ def hallucinate(prompt, num_images=64):
+     prompt = [prompt] * jax.device_count()
+     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
+     inputs = shard(inputs)
+
+     all_images = []
+     for i in range(num_images // jax.device_count()):
+         key = random.randint(0, int(1e7))
+         rng = jax.random.PRNGKey(key)
+         rngs = jax.random.split(rng, jax.local_device_count())
+         indices = p_generate(inputs, rngs, bart_params).sequences
+         indices = indices[:, :, 1:]
+
+         images = p_get_images(indices, vqgan_params)
+         images = np.squeeze(np.asarray(images), 1)
+         for image in images:
+             all_images.append(custom_to_pil(image))
+     return all_images
+
+ def clip_top_k(prompt, images, k=8):
+     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+     outputs = clip(**inputs)
+     logits = outputs.logits_per_text
+     scores = np.array(logits[0]).argsort()[-k:][::-1]
+     return [images[score] for score in scores]
+
+ def captioned_strip(images, caption):
+     increased_h = 0 if caption is None else 48
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h + increased_h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w, increased_h))
+
+     if caption is not None:
+         draw = ImageDraw.Draw(img)
+         font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
+         draw.text((20, 3), caption, (255,255,255), font=font)
+     return img
+
+ def run_inference(prompt, num_images=64, num_preds=8):
+     images = hallucinate(prompt, num_images=num_images)
+     images = clip_top_k(prompt, images, k=num_preds)
+     predictions_strip = captioned_strip(images, None)
+     return predictions_strip
+
+ gr.Interface(run_inference,
+     inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
+     outputs=gr.outputs.Image(label='Generated image'),
+     title='DALLE-mini - HuggingFace Community Week',
+     description='This is a demo of the DALLE-mini model trained with Jax/Flax on TPU v3-8s during the HuggingFace Community Week',
+     article="<p style='text-align: center'> DALLE-mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
+     layout='vertical',
+     theme='huggingface',
+     examples=[['an armchair in the shape of an avocado']],
+     server_port=8999).launch(share=True)
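
One non-obvious detail in the script above is the token bookkeeping around the custom BOS token. The tiny, purely illustrative check below shows how the constants fit together (the 16x16 grid interpretation is an assumption based on the f16 VQGAN operating on 256x256 images; it is not stated in the script itself).

```python
# Purely illustrative bookkeeping check for the constants used in app/app_gradio.py:
# the decoder generates BOS + 256 VQGAN code indices, hence max_length=257 in
# generate() and the indices[:, :, 1:] slice that strips the BOS token.
OUTPUT_VOCAB_SIZE = 16384 + 1
OUTPUT_LENGTH = 256 + 1
BOS_TOKEN_ID = 16384

assert OUTPUT_LENGTH == 257                    # matches max_length=257 in generate()
assert BOS_TOKEN_ID == OUTPUT_VOCAB_SIZE - 1   # BOS sits just past the 16384 codebook ids
assert (OUTPUT_LENGTH - 1) == 16 * 16          # 256 codes -> assumed 16x16 latent grid
```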
app/requirements.txt DELETED
@@ -1,11 +0,0 @@
- # Requirements for huggingface spaces
- -f https://storage.googleapis.com/jax-releases/jax_releases.html
- jax[cuda111]
- flax
- requests
- -e git+https://github.com/huggingface/transformers.git@master#egg=transformers
- -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
- flax
- jupyter
- wandb
- ftfy
dalle_mini/vqgan_jax/README.md ADDED
@@ -0,0 +1,5 @@
+ ## vqgan-jax
+
+ Files copied from [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax/tree/main/vqgan_jax)
+
+ Required for VQGAN Jax model.
dalle_mini/vqgan_jax/convert_pt_model_to_jax.py DELETED
@@ -1,109 +0,0 @@
- import re
-
- import jax.numpy as jnp
- from flax.traverse_util import flatten_dict, unflatten_dict
-
- import torch
-
- from modeling_flax_vqgan import VQModel
- from configuration_vqgan import VQGANConfig
-
-
- regex = r"\w+[.]\d+"
-
-
- def rename_key(key):
-     pats = re.findall(regex, key)
-     for pat in pats:
-         key = key.replace(pat, "_".join(pat.split(".")))
-     return key
-
-
- # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
- def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
-     # convert pytorch tensor to numpy
-     pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
-
-     random_flax_state_dict = flatten_dict(flax_model.params)
-     flax_state_dict = {}
-
-     remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
-         flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
-     )
-     add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
-         flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
-     )
-
-     # Need to change some parameters name to match Flax names so that we don't have to fork any layer
-     for pt_key, pt_tensor in pt_state_dict.items():
-         pt_tuple_key = tuple(pt_key.split("."))
-
-         has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
-         require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
-
-         if remove_base_model_prefix and has_base_model_prefix:
-             pt_tuple_key = pt_tuple_key[1:]
-         elif add_base_model_prefix and require_base_model_prefix:
-             pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
-
-         # Correctly rename weight parameters
-         if (
-             "norm" in pt_key
-             and (pt_tuple_key[-1] == "bias")
-             and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
-         ):
-             pt_tensor = pt_tensor[None, None, None, :]
-         elif (
-             "norm" in pt_key
-             and (pt_tuple_key[-1] == "bias")
-             and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
-         ):
-             pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
-             pt_tensor = pt_tensor[None, None, None, :]
-         elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
-             pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
-             pt_tensor = pt_tensor[None, None, None, :]
-         if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
-             pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
-         elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
-             # conv layer
-             pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
-             pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
-         elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
-             # linear layer
-             pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
-             pt_tensor = pt_tensor.T
-         elif pt_tuple_key[-1] == "gamma":
-             pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
-         elif pt_tuple_key[-1] == "beta":
-             pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
-
-         if pt_tuple_key in random_flax_state_dict:
-             if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
-                 raise ValueError(
-                     f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
-                     f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
-                 )
-
-         # also add unexpected weight so that warning is thrown
-         flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
-
-     return unflatten_dict(flax_state_dict)
-
-
- def convert_model(config_path, pt_state_dict_path, save_path):
-     config = VQGANConfig.from_pretrained(config_path)
-     model = VQModel(config)
-
-     state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
-     keys = list(state_dict.keys())
-     for key in keys:
-         if key.startswith("loss"):
-             state_dict.pop(key)
-             continue
-         renamed_key = rename_key(key)
-         state_dict[renamed_key] = state_dict.pop(key)
-
-     state = convert_pytorch_state_dict_to_flax(state_dict, model)
-     model.params = unflatten_dict(state)
-     model.save_pretrained(save_path)
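
For context, a rough sketch of how this (now removed) converter was invoked, based only on the `convert_model` signature above. All file paths are hypothetical placeholders, and the script imported `modeling_flax_vqgan` and `configuration_vqgan` as local modules, so it had to be run from its own directory.

```python
# Hypothetical invocation of the removed converter script (paths are placeholders,
# not files from this repository).
from convert_pt_model_to_jax import convert_model

convert_model(
    config_path="config.json",          # VQGANConfig of the trained model (placeholder)
    pt_state_dict_path="last.ckpt",     # PyTorch checkpoint from taming-transformers (placeholder)
    save_path="./vqgan_jax_checkpoint", # output directory for the converted Flax weights (placeholder)
)
```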
requirements.txt CHANGED
@@ -1,12 +1,13 @@
- # Note: install with the following command:
- # pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- # Otherwise it won't find the appropriate libtpu_nightly
+ # Requirements for huggingface spaces
+ -f https://storage.googleapis.com/jax-releases/jax_releases.html
+ jax[cuda111]
+ flax
  requests
- jax[tpu]>=0.2.16
  -e git+https://github.com/huggingface/transformers.git@master#egg=transformers
  -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
  flax
  jupyter
-
- # Inference
+ wandb
  ftfy
+ streamlit
+ gradio
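
After installing either set of requirements, a quick way to confirm which accelerator the installed `jax` build actually sees is shown below (a generic JAX check, not part of this repository).

```python
# Generic sanity check (not from this repo): report the JAX backend and devices
# visible after installing jax[cuda111] (Spaces) or jax[tpu] (TPU VM).
import jax

print(jax.default_backend())  # "cpu", "gpu", or "tpu"
print(jax.devices())          # list of available accelerator devices
```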
environment.yaml → seq2seq/environment.yaml RENAMED
File without changes
seq2seq/requirements.txt CHANGED
@@ -1,8 +1,15 @@
- datasets >= 1.1.3
- jax>=0.2.8
- jaxlib>=0.1.59
- flax>=0.3.4
- optax>=0.0.8
- tensorboard
- nltk
+ # Note: install with the following command:
+ # pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+ # Otherwise it won't find the appropriate libtpu_nightly
+ requests
+ jax[tpu]>=0.2.16
+ -e git+https://github.com/huggingface/transformers.git@master#egg=transformers
+ -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
+ flax
+ jupyter
  wandb
+ nltk
+ optax
+
+ # Inference
+ ftfy