File size: 885 Bytes
9dfa4de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# A reimplemented version in public environments by Xiao Fu and Mu Hu

import json
import yaml
import logging
import os
import numpy as np
import sys

def load_loss_scheme(loss_config):
    """Read a loss-scheme configuration file and return its parsed YAML content.

    Args:
        loss_config: Path to the configuration file to parse.

    Returns:
        The object produced by ``yaml.safe_load`` (standard Python types such as
        dict/list/str/int); ``None`` if the file contains no YAML document.
    """
    # Open in binary mode so PyYAML detects the stream encoding (UTF-8 / UTF-16
    # via BOM) itself, instead of text mode silently using the platform's
    # locale-dependent default encoding.
    with open(loss_config, 'rb') as f:
        loss_scheme = yaml.safe_load(f)
    return loss_scheme


# Set DEBUG to a truthy value to lower the root logger's threshold to DEBUG;
# otherwise it is set to INFO.
DEBUG = 0

# getLogger() with no name returns the root logger, so this configuration applies
# to every record that propagates up to the root.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)

# One stream handler (stderr by default) that prefixes each message with the
# timestamp, source file:line and level name.
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s')
strhdlr = logging.StreamHandler()
strhdlr.setFormatter(formatter)
logger.addHandler(strhdlr)



def count_parameters(model):
    """Return the summed ``numel()`` of every parameter of *model* with ``requires_grad`` set.

    Args:
        model: Any object whose ``parameters()`` yields items exposing
            ``requires_grad`` and ``numel()`` (e.g. a torch module — assumed).

    Returns:
        The total number of trainable elements; 0 if there are none.
    """
    trainable = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable += param.numel()
    return trainable

def check_path(path):
    """Create directory *path* (including missing parents) if nothing exists there yet.

    If *path* already exists — as a directory or as any other file-system
    entry — the function returns without doing anything.
    """
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)