cakiki commited on
Commit
fc0235a
1 Parent(s): e9be75f

Fix config usage

Browse files
Files changed (1) hide show
  1. clone_sentdex_model_tokenizer.py +11 -4
clone_sentdex_model_tokenizer.py CHANGED
@@ -2,12 +2,19 @@
2
  from transformers import AutoTokenizer, TFAutoModelForCausalLM, AutoModelForCausalLM, GPT2Config
3
  import tensorflow as tf
4
 
5
- config = GPT2Config.from_pretrained("Sentdex/GPyT")
 
 
 
 
 
 
 
6
  tokenizer = AutoTokenizer.from_pretrained("Sentdex/GPyT")
7
- tf_model = TFAutoModelForCausalLM.from_pretrained("Sentdex/GPyT")
8
- pytorch_model = AutoModelForCausalLM.from_pretrained("Sentdex/GPyT")
9
 
10
  config.save_pretrained('./')
11
  tokenizer.save_pretrained(save_directory='./')
12
- tf_model.save_pretrained(save_directory='./', saved_model=True)
13
  pytorch_model.save_pretrained(save_directory='./')
 
2
  from transformers import AutoTokenizer, TFAutoModelForCausalLM, AutoModelForCausalLM, GPT2Config
3
  import tensorflow as tf
4
 
5
+ task_specific_params = {
6
+ "text-generation": {
7
+ "do_sample": False,
8
+ "max_length": 50
9
+ }
10
+ }
11
+
12
+ config = GPT2Config.from_pretrained("Sentdex/GPyT", _name_or_path='prophetikai/code-gpt', use_cache=True, task_specific_params=task_specific_params)
13
  tokenizer = AutoTokenizer.from_pretrained("Sentdex/GPyT")
14
+ tf_model = TFAutoModelForCausalLM.from_pretrained("Sentdex/GPyT", config=config)
15
+ pytorch_model = AutoModelForCausalLM.from_pretrained("Sentdex/GPyT", config=config)
16
 
17
  config.save_pretrained('./')
18
  tokenizer.save_pretrained(save_directory='./')
19
+ tf_model.save_pretrained(save_directory='./', saved_model=True, version='sentdex')
20
  pytorch_model.save_pretrained(save_directory='./')