Joshua Lochner committed
Commit 721bf64
1 Parent(s): 09cabec

Revert model input size back to 512 tokens

Files changed (2)
  1. src/model.py +3 -5
  2. src/train.py +0 -1
src/model.py CHANGED
@@ -106,15 +106,13 @@ def get_model_tokenizer(model_name_or_path, cache_dir=None, no_cuda=False):
     model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
     tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir)
+        model_name_or_path, cache_dir=cache_dir)
 
     # Ensure model and tokenizer contain the custom tokens
     CustomTokens.add_custom_tokens(tokenizer)
     model.resize_token_embeddings(len(tokenizer))
 
-    # TODO add this back: means that different models will have different training data
-    # Currently we only send 512 tokens to the model each time...
-    # Adjust based on dimensions of model
-    tokenizer.model_max_length = model.config.d_model
+    # TODO find a way to adjust based on model's input size
+    # print('tokenizer.model_max_length', tokenizer.model_max_length)
 
     return model, tokenizer
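
For context, a minimal sketch of what the reverted code path now does, assuming a T5-style checkpoint ('t5-small' below is an assumption, not necessarily the checkpoint used in this repository). With the max_length / model_max_length overrides removed, the tokenizer falls back to the default stored with the checkpoint, which is 512 tokens for the standard T5 models:

from transformers import AutoConfig, AutoTokenizer

# Illustrative only: 't5-small' is an assumed checkpoint name.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
config = AutoConfig.from_pretrained('t5-small')

# With the override removed, the checkpoint default applies.
print(tokenizer.model_max_length)  # 512

# The reverted line set model_max_length from config.d_model, which differs
# per model size (512 for t5-small, 768 for t5-base), so each model would
# have seen differently truncated training data.
print(config.d_model)              # 512 for t5-small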
src/train.py CHANGED
@@ -298,7 +298,6 @@ def main():
     from model import get_model_tokenizer
     model, tokenizer = get_model_tokenizer(
         model_args.model_name_or_path, model_args.cache_dir, training_args.no_cuda)
-    # max_tokenizer_length = model.config.d_model
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
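
The dataset preprocessing that follows in train.py is not part of this diff, but the effect of the revert on it can be sketched under the assumption of a standard tokenizer(..., truncation=True) call: with no explicit max_length, sequences are truncated to tokenizer.model_max_length, i.e. 512 tokens after this change.

from transformers import AutoTokenizer

# Assumed checkpoint name; the actual preprocessing code is not shown here.
tokenizer = AutoTokenizer.from_pretrained('t5-small')

batch = tokenizer(
    ['a very long transcript ' * 1000],  # deliberately longer than 512 tokens
    truncation=True,                     # no max_length given, so truncation
    return_tensors='pt')                 # falls back to model_max_length

assert batch['input_ids'].shape[1] <= tokenizer.model_max_length  # <= 512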