|
from __gin__ import dynamic_registration |
|
import tasks |
|
|
|
import __main__ as eval_script |
|
from t5.data import mixtures |
|
from t5x import partitioning |
|
from t5x import utils |
|
|
|
# Base mT5 model configuration. %MODEL (used by eval_script.evaluate) is not
# defined in this file and is presumably provided by this include.
include "t5x/examples/t5/mt5/base.gin"

# ---------------------------------------------------------------------------
# Macros
# ---------------------------------------------------------------------------

# Task/mixture to evaluate. Presumably registered by the local `tasks` module
# imported at the top of this file -- confirm the name matches its registration.
MIXTURE_OR_TASK_NAME = "parliament"

# Required values with no default: gin raises an error if they are still
# unbound when used (T5X scripts typically accept them as --gin.NAME=... flags).
CHECKPOINT_PATH = %gin.REQUIRED
SPLIT = %gin.REQUIRED

# Directory for evaluation output. Relative path, so it resolves against the
# working directory of the eval process.
EVAL_OUTPUT_DIR = "./log/"

# Not read anywhere in this file; presumably consumed by the included
# base.gin. A rate of 0.0 disables dropout, as expected for evaluation.
DROPOUT_RATE = 0.0
|
|
|
# Arguments for `evaluate` in the script running as __main__ (imported at the
# top of this file as `eval_script`).
eval_script.evaluate.model = %MODEL  # not defined here; expected from the included base.gin
eval_script.evaluate.partitioner = @partitioning.PjitPartitioner()
eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()
eval_script.evaluate.restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
eval_script.evaluate.output_dir = %EVAL_OUTPUT_DIR
|
|
|
# Evaluation dataset.
utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
utils.DatasetConfig.split = %SPLIT
# Fixed per-feature lengths. A targets length of 2 is very short; presumably
# the targets are a single label token plus EOS -- TODO confirm against the
# task definition, since longer targets would not fit this length.
utils.DatasetConfig.task_feature_lengths = {"inputs": 512, "targets": 2}
utils.DatasetConfig.batch_size = 16
# Examples are evaluated in their stored order (no shuffling).
utils.DatasetConfig.shuffle = False
# NOTE(review): with shuffling disabled this seed likely has no effect on
# ordering; confirm whether anything else consumes it.
utils.DatasetConfig.seed = 42
|
|
|
# Partitioner configuration (eval_script.evaluate receives
# @partitioning.PjitPartitioner()). In T5X, num_partitions sets the
# model-parallel split of the device mesh -- confirm against the installed
# t5x version.
partitioning.PjitPartitioner:
  num_partitions = 2
|
|
|
# Checkpoint to evaluate. In T5X, 'specific' mode restores the checkpoint at
# exactly `path` rather than searching a directory for the latest one.
utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH
utils.RestoreCheckpointConfig.mode = 'specific'
|
|