|
import logging |
|
import os |
|
import sys |
|
import warnings |
|
|
|
import hydra |
|
from hydra.core.hydra_config import HydraConfig |
|
from hydra.utils import instantiate |
|
from omegaconf import DictConfig |
|
|
|
from transforna import compute_cv, infer_benchmark, infer_tcga, train |
|
|
|
# Silence every warning category process-wide (an "ignore" filter for all Warning subclasses).
warnings.simplefilter("ignore")

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
def add_config_to_sys_path(hydra_cfg=None):
    """Append the first file-based Hydra config source directory to ``sys.path``.

    Presumably this lets Python modules stored alongside the config files be
    imported (e.g. by ``instantiate``) — confirm against the config layout.

    Args:
        hydra_cfg: The Hydra runtime config to inspect. Defaults to
            ``HydraConfig.get()``, i.e. the config of the currently running
            Hydra job.

    Returns:
        The config directory that was found (and is now on ``sys.path``).

    Raises:
        RuntimeError: If no config source with schema ``"file"`` exists.
    """
    if hydra_cfg is None:
        hydra_cfg = HydraConfig.get()

    config_path = next(
        (
            source["path"]
            for source in hydra_cfg.runtime.config_sources
            if source["schema"] == "file"
        ),
        None,
    )
    if config_path is None:
        raise RuntimeError(
            "No file-based config source found in the Hydra runtime configuration"
        )

    # Avoid piling up duplicate entries if this is called more than once
    # in the same process.
    if config_path not in sys.path:
        sys.path.append(config_path)
    return config_path
|
|
|
|
|
|
|
@hydra.main(config_path='../conf', config_name="main_config")
def main(cfg: DictConfig):
    """Prepare the composed Hydra config and dispatch to inference, CV or training.

    The routine that runs is selected by config flags:

    * ``cfg["inference"]`` truthy: ``infer_tcga`` when ``cfg["task"] == "tcga"``,
      otherwise ``infer_benchmark``; the inference result is returned.
    * else ``cfg["cross_val"]`` truthy: ``compute_cv``.
    * else: ``train``.

    Args:
        cfg: Config composed by Hydra from ``../conf/main_config``.

    Returns:
        The inference result in inference mode, otherwise ``None``.
    """
    add_config_to_sys_path()

    output_dir = HydraConfig.get().runtime.output_dir
    path = os.getcwd()

    # Replace the train/model config sections with the attribute dicts of the
    # objects Hydra instantiates from them.
    cfg['train_config'] = instantiate(cfg['train_config']).__dict__
    cfg['model_config'] = instantiate(cfg['model_config']).__dict__
    cfg['model_config']["model_input"] = cfg["model_name"]

    if cfg["inference"]:
        logger.info("Started inference on %s", cfg['task'])
        if cfg['task'] == 'tcga':
            return infer_tcga(cfg, path=path)
        return infer_benchmark(cfg, path=path)

    if cfg["cross_val"]:
        compute_cv(cfg, path, output_dir=output_dir)
    else:
        train(cfg, path=path, output_dir=output_dir)
|
|
|
# Script entry point. main() is called with no arguments: the @hydra.main
# decorator composes the config and supplies `cfg`.
if __name__ == "__main__":
    main()
|
|