update train file
Browse files- src/gptneo_cosmos.py +2 -2
src/gptneo_cosmos.py
CHANGED
@@ -34,7 +34,7 @@ logger.setLevel(logging.INFO)
|
|
34 |
tokenizer=GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B',pad_token='<|endoftext|>')
|
35 |
|
36 |
dataset=load_dataset('cosmos_qa')
|
37 |
-
num_choices=
|
38 |
|
39 |
def preprocess(example):
|
40 |
example['context&question']=example['context']+example['question']
|
@@ -232,4 +232,4 @@ if jax.process_index() == 0:
|
|
232 |
push_to_hub=True,
|
233 |
commit_message=f"Cosmos:Saving weights of epoch {epoch} at step {idx}",)
|
234 |
|
235 |
-
summary_writer.flush()
|
|
|
34 |
tokenizer=GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B',pad_token='<|endoftext|>')
|
35 |
|
36 |
dataset=load_dataset('cosmos_qa')
|
37 |
+
num_choices=4
|
38 |
|
39 |
def preprocess(example):
|
40 |
example['context&question']=example['context']+example['question']
|
|
|
232 |
push_to_hub=True,
|
233 |
commit_message=f"Cosmos:Saving weights of epoch {epoch} at step {idx}",)
|
234 |
|
235 |
+
summary_writer.flush()
|