|
import os |
|
import sys |
|
|
|
import hydra |
|
|
|
from refine import refine |
|
from repeated_sample import repeated_sample |
|
from funsearch import funsearch |
|
|
|
def _run_method(cfg):
    """Dispatch to the search method selected by ``cfg.method.name``.

    Any method name starting with ``'refine'`` is handled by ``refine``;
    ``'repeated_sample'`` and ``'funsearch'`` map to the functions of the
    same name.

    Raises:
        NotImplementedError: if the method name matches none of the above.
    """
    method_name = cfg.method.name
    if method_name.startswith('refine'):
        refine(cfg)
    elif method_name == 'repeated_sample':
        repeated_sample(cfg)
    elif method_name == 'funsearch':
        funsearch(cfg)
    else:
        raise NotImplementedError(f'Unknown method: {cfg.method.name}')


@hydra.main(config_path='configs', config_name='default', version_base=None)
def main(cfg):
    """Print the run configuration, prepare the working folder, and run the method.

    ``cfg`` is the config that Hydra composes from ``configs/default``. This
    function reads ``cfg.method.name``, ``cfg.model.name``, ``cfg.pde.name``,
    ``cfg.working_folder`` and ``cfg.redirect_stdout``.

    When ``cfg.redirect_stdout`` is truthy, everything printed while the
    method runs goes to ``<working_folder>/stdout.txt``. The file is closed
    and the original ``sys.stdout`` is restored when the method returns or
    raises.

    Raises:
        NotImplementedError: if ``cfg.method.name`` is not a known method.
    """
    import contextlib

    print(f'Method: {cfg.method.name}')
    print(f'Model name: {cfg.model.name}')
    print(f'PDE name: {cfg.pde.name}')

    print(f'Working folder: {cfg.working_folder}')
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(cfg.working_folder, exist_ok=True)

    with contextlib.ExitStack() as stack:
        if cfg.redirect_stdout:
            # Line-buffered so stdout.txt can be followed during long runs and
            # at most a partial line is lost if the process is killed; an
            # explicit encoding keeps the log independent of the locale.
            log_file = stack.enter_context(
                open(
                    os.path.join(cfg.working_folder, 'stdout.txt'),
                    'w',
                    buffering=1,
                    encoding='utf-8',
                )
            )
            stack.enter_context(contextlib.redirect_stdout(log_file))
        _run_method(cfg)
|
|
|
if __name__ == '__main__':
    # `main` is wrapped by hydra.main, which builds the config and supplies
    # it as `cfg`, so the entry point is invoked without arguments.
    main()
|
|