Pedro Cuenca commited on
Commit
de74f11
·
1 Parent(s): 9c0e5c9

fix typos and update requirements

Browse files
seq2seq/requirements.txt CHANGED
@@ -4,3 +4,5 @@ jaxlib>=0.1.59
4
  flax>=0.3.4
5
  optax>=0.0.8
6
  tensorboard
 
 
 
4
  flax>=0.3.4
5
  optax>=0.0.8
6
  tensorboard
7
+ nltk
8
+ wandb
seq2seq/run_seq2seq_flax.py CHANGED
@@ -19,7 +19,7 @@ Script adapted from run_summarization_flax.py
19
  """
20
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
 
22
- import logging
23
  import os
24
  import sys
25
  import time
@@ -60,7 +60,7 @@ from transformers.file_utils import is_offline_mode
60
 
61
  import wandb
62
 
63
- logger = logging.getLogger(__name__)
64
 
65
  try:
66
  nltk.data.find("tokenizers/punkt")
@@ -389,7 +389,7 @@ def main():
389
  data_files["validation"] = data_args.validation_file
390
  if data_args.test_file is not None:
391
  data_files["test"] = data_args.test_file
392
- dataset = load_dataset"csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
393
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
394
  # https://huggingface.co/docs/datasets/loading_datasets.html.
395
 
@@ -411,7 +411,7 @@ def main():
411
 
412
 
413
  # Create a custom model and initialize it randomly
414
- model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
415
 
416
  # Use pre-trained weights for encoder
417
  model.params['model']['encoder'] = base_model.params['model']['encoder']
 
19
  """
20
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
 
22
+ import logging as pylogging # To avoid collision with transformers.utils.logging
23
  import os
24
  import sys
25
  import time
 
60
 
61
  import wandb
62
 
63
+ logger = pylogging.getLogger(__name__)
64
 
65
  try:
66
  nltk.data.find("tokenizers/punkt")
 
389
  data_files["validation"] = data_args.validation_file
390
  if data_args.test_file is not None:
391
  data_files["test"] = data_args.test_file
392
+ dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
393
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
394
  # https://huggingface.co/docs/datasets/loading_datasets.html.
395
 
 
411
 
412
 
413
  # Create a custom model and initialize it randomly
414
+ model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
415
 
416
  # Use pre-trained weights for encoder
417
  model.params['model']['encoder'] = base_model.params['model']['encoder']