# t5-parliament-categorisation / eval_categorisation_base.gin
# Last change: commit 08a2592 by pere ("updated batch run").
from __gin__ import dynamic_registration
import tasks
import __main__ as eval_script
from t5.data import mixtures
from t5x import partitioning
from t5x import utils
include "t5x/examples/t5/mt5/base.gin"
# ----------------------------------------------------------------------------
# Macros
# ----------------------------------------------------------------------------
# Required: these have no default and must be bound on the command line.
CHECKPOINT_PATH = %gin.REQUIRED  # passed via commandline
SPLIT = %gin.REQUIRED  # passed via commandline

# Directory that receives the evaluation output. It is a relative path, so it
# resolves against the directory the eval script is launched from.
EVAL_OUTPUT_DIR = "./log/"

# Boilerplate: presumably required by the included model config (t5x model
# gin files typically declare `DROPOUT_RATE = %gin.REQUIRED`). It is not
# expected to influence evaluation results.
# TODO(review): confirm against t5x/examples/t5/mt5/base.gin.
DROPOUT_RATE = 0.0

# SeqIO task to evaluate; presumably registered by `import tasks` above.
MIXTURE_OR_TASK_NAME = "parliament"
# Arguments for `evaluate` in the launching script. `eval_script` is
# `__main__` (see imports), presumably t5x's eval.py.
# Bindings inside a `name:` block must be indented, or gin fails to parse it.
eval_script.evaluate:
  model = %MODEL  # defined by the included model gin file (see `include`)
  dataset_cfg = @utils.DatasetConfig()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  output_dir = %EVAL_OUTPUT_DIR
  partitioner = @partitioning.PjitPartitioner()
# Evaluation data: which SeqIO task and split to read, feature lengths, and
# batching. The order is fixed (shuffle = False).
# NOTE(review): "targets": 2 presumably fits a single category-label token
# plus EOS. Longer labels would not fit within this length; confirm against
# the task's vocabulary.
utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = {"inputs": 512, "targets": 2}
  split = %SPLIT
  batch_size = 16
  shuffle = False
  seed = 42
# Partitioning for pjit: split the model across 2 partitions (model-parallel
# shards, per the t5x PjitPartitioner documentation).
partitioning.PjitPartitioner:
  num_partitions = 2
# Checkpoint to evaluate, taken from CHECKPOINT_PATH. Mode 'specific' restores
# the checkpoint at that path (t5x also documents other restore modes; see
# utils.RestoreCheckpointConfig).
utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'