File size: 2,323 Bytes
e67043b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import configparser
import logging


def get_conf(conf_file, server_name):
    conf = configparser.ConfigParser()
    conf.read(conf_file)
    sql_server = conf[server_name]
    return sql_server


def get_parser():
    parser = argparse.ArgumentParser(description="Instruction Induction.")

    parser.add_argument("--db_conf", type=str, default="../database/configs/config.ini")

    """ 
    parser.add_argument("--train_data", type=str,
                        default="./data/raw/train/rules.json")
    parser.add_argument("--eval_data", type=str,
                        default="./data/raw/execute/zhenzhi.json")

    parser.add_argument("--data_save", type=str,
                        default="./result/{}/data/")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--runlog", type=str,
                        default="./result/{}/exp_runtime.log")
    parser.add_argument("--logdir", type=str,
                        default="./result/{}/logdir/")
    parser.add_argument("--model_save", type=str,
                        default="./result/{}/model/")

    parser.add_argument("--gen_sample", type=int, default=20)
    parser.add_argument("--gen_demo", type=int, default=16)
    parser.add_argument("--gen_prompt_per_sample", type=int, default=5)
    parser.add_argument("--gen_model", type=str, default="text-davinci-003")
    parser.add_argument("--gen_max_tokens", type=int, default=200)

    parser.add_argument("--eval_sample", type=int, default=20)
    parser.add_argument("--eval_model", type=str, default="text-davinci-003")
    parser.add_argument("--eval_max_tokens", type=int, default=1000)

    parser.add_argument("--storage_budget", type=int, default=500) # limit storage space of built indexes
    """

    return parser


def set_logger(log_file):
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s: - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # log to file
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    # log to console
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)

    logger.addHandler(ch)
    logger.addHandler(fh)