|
""" |
|
Main command to run the Diverse Genomic Embedding Benchmark (DGEB) on a model.

Example command to run DGEB:

    python run_dgeb.py -m facebook/esm2_t6_8M_UR50D
|
""" |
|
|
|
import argparse |
|
import logging |
|
import os |
|
|
|
import dgeb |
|
|
|
# Configure the root logger to emit INFO-level (and above) records, and create
# this script's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Task and model names reported by dgeb, fetched once at import time; main()
# embeds them in the --tasks and --model help strings.
ALL_TASK_NAMES = dgeb.get_all_task_names()
ALL_MODEL_NAMES = dgeb.get_all_model_names()
|
|
|
|
|
def _parse_devices(raw):
    """Convert the ``--devices`` string into a list of integer device ids.

    Args:
        raw: Comma separated device ids, e.g. ``"0,1"``.

    Returns:
        The device ids as a list of ints, in the order given.

    Raises:
        ValueError: If any entry cannot be parsed as an integer.
    """
    try:
        return [int(device) for device in raw.split(",")]
    except ValueError as err:
        raise ValueError("Devices must be comma separated list of integers") from err


def _parse_layers(raw):
    """Convert the ``--layers`` string into the value passed to ``dgeb.get_model``.

    Args:
        raw: ``None``, an empty string, ``"mid"``, ``"last"``, or a comma
            separated list of integers.

    Returns:
        ``raw`` unchanged when it is falsy or one of the keywords ``"mid"`` /
        ``"last"``; otherwise a list of ints.

    Raises:
        ValueError: If ``raw`` is neither a keyword nor a list of integers.
    """
    if not raw or raw in ("mid", "last"):
        return raw
    try:
        return [int(layer) for layer in raw.split(",")]
    except ValueError as err:
        # The keywords are accepted too, so say so in the message.
        raise ValueError(
            "Layers must be 'mid', 'last', or a comma separated list of integers."
        ) from err


def main():
    """Parse command-line arguments and run the DGEB evaluation on one model.

    Builds the argument parser, applies the requested verbosity, validates
    the model / devices / layers arguments, creates the output folder, loads
    the model through ``dgeb.get_model`` and runs either the requested tasks
    or every task returned for the model's modality.

    Raises:
        ValueError: If no model is given, if devices or layers cannot be
            parsed, or if a requested task's modality differs from the
            model's modality.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default=None,
        help=f"Model to evaluate. Choose from {ALL_MODEL_NAMES}",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        # Strip whitespace so "-t 'a, b'" resolves to the same names as "-t a,b".
        type=lambda s: [name.strip() for name in s.split(",")],
        default=None,
        help=f"Comma separated tasks to evaluate on. Choose from {ALL_TASK_NAMES} or do not specify to evaluate on all tasks",
    )
    parser.add_argument(
        "-l",
        "--layers",
        type=str,
        default=None,
        help="Layer to evaluate. Comma separated list of integers or 'mid' and 'last'. Default is 'mid,last'",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="0",
        help="Comma separated list of GPU device ids to use. Default is 0 (if GPUs are detected).",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default=None,
        help="Output directory for results. Will default to results/model_name if not set.",
    )
    parser.add_argument(
        "-v", "--verbosity", type=int, default=2, help="Verbosity level"
    )
    parser.add_argument(
        "-b", "--batch_size", type=int, default=64, help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=1024,
        help="Maximum sequence length for model, default is 1024.",
    )
    parser.add_argument(
        "--pool_type",
        type=str,
        default="mean",
        help="Pooling type for model, choose from mean, max, cls, last. Default is mean.",
    )

    args = parser.parse_args()

    # Map the verbosity flag onto a logging level. Values outside 0-3 leave
    # the logger level untouched, as before.
    verbosity_levels = {
        0: logging.CRITICAL,
        1: logging.WARNING,
        2: logging.INFO,
        3: logging.DEBUG,
    }
    level = verbosity_levels.get(args.verbosity)
    if level is not None:
        # The package is imported as `dgeb`; the previous logger name "geb"
        # is not an ancestor of any "dgeb.*" logger, so the flag had no effect.
        # NOTE(review): assumes the library's loggers live under "dgeb" - confirm.
        logging.getLogger("dgeb").setLevel(level)

    if args.model is None:
        raise ValueError("Please specify a model using the -m or --model argument")

    devices = _parse_devices(args.devices)
    layers = _parse_layers(args.layers)

    # Default the output location to results/<last path component of the model name>.
    model_name = args.model.split("/")[-1]
    output_folder = args.output_folder
    if output_folder is None:
        output_folder = os.path.join("results", model_name)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_folder, exist_ok=True)
    logger.info("Results will be saved to %s", output_folder)

    model = dgeb.get_model(
        model_name=args.model,
        layers=layers,
        devices=devices,
        max_seq_length=args.max_seq_len,
        batch_size=args.batch_size,
        pool_type=args.pool_type,
    )

    all_tasks_for_modality = dgeb.get_tasks_by_modality(model.modality)

    if args.tasks:
        task_list = dgeb.get_tasks_by_name(args.tasks)
        # Every requested task must share the model's modality.
        if not all(task.metadata.modality == model.modality for task in task_list):
            raise ValueError(f"Tasks must be one of {all_tasks_for_modality}")
    else:
        task_list = all_tasks_for_modality

    evaluation = dgeb.DGEB(tasks=task_list)
    # NOTE(review): output_folder is created and logged above but is not passed
    # to run(); confirm whether run() should receive it.
    _ = evaluation.run(model)
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|