Yak-hbdx's picture
uploaded TransfoRNA repo
0b11a42 verified
raw
history blame
No virus
1.63 kB
import logging
import os
import sys
import warnings
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig
from transforna import compute_cv, infer_benchmark, infer_tcga, train
# Suppress every warning category for the whole process (the default
# category for the filter is the base ``Warning`` class).
warnings.simplefilter("ignore")
# Logger for this module; messages are routed through the logging config
# that is active at run time.
logger = logging.getLogger(name=__name__)
def add_config_to_sys_path(hydra_cfg=None) -> None:
    """Append the directory of the primary file-based config source to ``sys.path``.

    Takes the first entry of ``hydra_cfg.runtime.config_sources`` whose
    ``schema`` is ``"file"`` and appends its ``path`` to ``sys.path``, so that
    Python modules stored next to the YAML configs become importable
    (presumably needed by ``hydra.utils.instantiate`` when resolving
    ``_target_`` entries — confirm against the config files).

    Args:
        hydra_cfg: Hydra runtime config to inspect. Defaults to
            ``HydraConfig.get()``, which requires an active Hydra run; pass an
            object explicitly to use this outside ``@hydra.main``.

    Repeated calls (e.g. once per job in a ``--multirun`` sweep, which runs
    jobs in the same process) do not add duplicate ``sys.path`` entries.
    If no file-based config source exists, a warning is logged and
    ``sys.path`` is left unchanged instead of raising a bare ``IndexError``.
    """
    if hydra_cfg is None:
        hydra_cfg = HydraConfig.get()

    config_path = next(
        (
            source["path"]
            for source in hydra_cfg.runtime.config_sources
            if source["schema"] == "file"
        ),
        None,
    )
    if config_path is None:
        logging.getLogger(__name__).warning(
            "No file-based Hydra config source found; sys.path left unchanged."
        )
        return

    # Avoid growing sys.path with the same entry on every invocation.
    if config_path not in sys.path:
        sys.path.append(config_path)
# transforna can be called from anywhere by pointing Hydra at a config directory:
#   python -m transforna --config-dir /path/to/configs
@hydra.main(config_path='../conf', config_name="main_config")
def main(cfg: DictConfig) -> object:
    """Run training, cross-validation or inference as selected by ``cfg``.

    Before dispatching, the ``train_config`` and ``model_config`` nodes are
    instantiated and replaced in ``cfg`` by the ``__dict__`` of the resulting
    objects, and ``model_config["model_input"]`` is set to ``cfg["model_name"]``.

    Dispatch:
      * ``cfg["inference"]`` truthy and ``cfg["task"] == "tcga"`` -> ``infer_tcga``
      * ``cfg["inference"]`` truthy, any other task -> ``infer_benchmark``
      * ``cfg["cross_val"]`` truthy -> ``compute_cv``
      * otherwise -> ``train``

    Returns:
        The value returned by ``infer_tcga`` / ``infer_benchmark`` in
        inference mode; ``None`` for training and cross-validation.
    """
    add_config_to_sys_path()

    # Hydra's per-run output folder, forwarded to the training entry points.
    output_dir = HydraConfig.get().runtime.output_dir
    # NOTE(review): depending on the Hydra version / `hydra.job.chdir`, the cwd
    # may already be the run's output folder rather than the launch dir — confirm.
    path = os.getcwd()

    # Replace the config nodes with plain attribute dicts of the instantiated objects.
    cfg['train_config'] = instantiate(cfg['train_config']).__dict__
    cfg['model_config'] = instantiate(cfg['model_config']).__dict__
    # Record in the model config which model is being used.
    cfg['model_config']["model_input"] = cfg["model_name"]

    if cfg["inference"]:
        logger.info("Started inference on %s", cfg['task'])
        if cfg['task'] == 'tcga':
            return infer_tcga(cfg, path=path)
        return infer_benchmark(cfg, path=path)

    if cfg["cross_val"]:
        compute_cv(cfg, path, output_dir=output_dir)
    else:
        train(cfg, path=path, output_dir=output_dir)
    return None
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator parses the command
    # line and composes the config passed to main().
    main()