|
|
try: |
|
|
import cPickle as pickle |
|
|
except: |
|
|
import pickle |
|
|
import ast |
|
|
import re |
|
|
import inspect |
|
|
import os |
|
|
import logging |
|
|
import numpy as np |
|
|
|
|
|
def cross_entropy_npy(a, b, eps=1E-9):
    """Elementwise binary log-likelihood of targets ``a`` under predictions ``b``.

    Computes ``a * log(b + eps) + (1 - a) * log(1 - b + eps)``. Note the sign:
    this is the *negative* of the usual binary cross-entropy loss, so for
    ``a`` and ``b`` in [0, 1] the result is non-positive (up to the tiny
    ``eps`` offset). Negate it to obtain a loss.

    Parameters
    ----------
    a : array_like
        Targets, presumably in [0, 1] (not validated).
    b : array_like
        Predicted probabilities, presumably in [0, 1] (not validated).
    eps : float, optional
        Small constant added inside both logarithms so that ``b == 0`` or
        ``b == 1`` does not produce ``log(0)``. Defaults to ``1E-9``.

    Returns
    -------
    ndarray or scalar
        The broadcast elementwise result; no reduction is performed.
    """
    return a * np.log(b + eps) + (1 - a) * np.log(1 - b + eps)
|
|
|
|
|
|
|
|
def safe_eval(expr):
    """Parse a string as a Python literal; return any other value unchanged.

    Parameters
    ----------
    expr : str or object
        If ``expr`` is a string, it is parsed with ``ast.literal_eval``, which
        accepts only literal syntax. This includes ``str`` subclasses such as
        ``numpy.str_``. Any other value is returned as-is.

    Returns
    -------
    object
        The parsed literal, or ``expr`` itself when it is not a string.

    Raises
    ------
    ValueError, SyntaxError
        Propagated from ``ast.literal_eval`` for malformed or non-literal text.
    """
    # Use isinstance rather than an exact ``type(expr) is str`` check.
    # Otherwise str subclasses (e.g. numpy.str_) would silently pass through
    # unparsed.
    if isinstance(expr, str):
        return ast.literal_eval(expr)
    return expr
|
|
|
|
|
|
|
|
def logging_config(folder=None, name=None,
                   level=logging.INFO,
                   console_level=logging.DEBUG):
    """Reset the root logger so it logs to a file and to the console.

    Every handler currently attached to the root logger is removed and
    closed. Then a ``FileHandler`` writing ``<folder>/<name>.log`` and a
    ``StreamHandler`` are attached, both sharing one formatter. The chosen
    log path is printed.

    Parameters
    ----------
    folder : str or None
        Directory for the log file; created if it does not exist.
        Defaults to ``<cwd>/<name>``.
    name : str or None
        Base name of the log file, without the ``.log`` suffix. Defaults to
        the file name of the calling script, minus directory and extension.
    level : int
        Level set on the root logger and on the file handler.
    console_level : int
        Level set on the console handler. The root logger itself is set to
        ``level``, so records below ``level`` never reach the console
        handler. A ``console_level`` lower than ``level`` therefore has no
        visible effect.

    Returns
    -------
    str
        The folder that contains the log file.
    """
    if name is None:
        # Frame 1 is our caller. Keep only the base name of its source file,
        # e.g. "/home/me/exp.v2/train.py" -> "train". Splitting the whole path
        # on its first '.' breaks for "./train.py" (empty name) and for
        # dotted directories. Keeping directories also makes os.path.join
        # below ignore `folder` for absolute paths.
        caller_file = inspect.stack()[1][1]
        name = os.path.splitext(os.path.basename(caller_file))[0]
    if folder is None:
        folder = os.path.join(os.getcwd(), name)
    if not os.path.exists(folder):
        os.makedirs(folder)

    # Iterate over a copy: removeHandler mutates logging.root.handlers, and
    # mutating a list while iterating it skips every other element. Closing
    # each handler keeps repeated calls from leaking open log files.
    for handler in list(logging.root.handlers):
        logging.root.removeHandler(handler)
        handler.close()

    logpath = os.path.join(folder, name + ".log")
    print("All Logs will be saved to %s" %logpath)

    logging.root.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    logfile = logging.FileHandler(logpath)
    logfile.setLevel(level)
    logfile.setFormatter(formatter)
    logging.root.addHandler(logfile)

    logconsole = logging.StreamHandler()
    logconsole.setLevel(console_level)
    logconsole.setFormatter(formatter)
    logging.root.addHandler(logconsole)

    return folder
|
|
|
|
|
|
|
|
def load_params(prefix, epoch):
    """Load an MXNet checkpoint and split it into argument and auxiliary params.

    Reads ``'<prefix>-<epoch:04d>.params'`` with ``mxnet.ndarray.load``.
    Each entry is routed by the part of its key before the first ``':'``:
    ``'arg:<name>'`` goes to the first returned dict, ``'aux:<name>'`` to the
    second, and entries with any other prefix are dropped.

    Parameters
    ----------
    prefix : str
        Checkpoint path prefix.
    epoch : int
        Epoch number, zero-padded to four digits in the file name.

    Returns
    -------
    arg_params : dict
        Parameter name (prefix stripped) -> loaded value.
    aux_params : dict
        Same mapping, for auxiliary states.

    Raises
    ------
    ValueError
        If a key contains no ``':'`` and so cannot be split in two.
    """
    import mxnet.ndarray as nd

    checkpoint = nd.load('%s-%04d.params' % (prefix, epoch))
    # One destination dict per recognised key prefix.
    buckets = {'arg': {}, 'aux': {}}
    for key, value in checkpoint.items():
        kind, param_name = key.split(':', 1)
        target = buckets.get(kind)
        if target is not None:
            target[param_name] = value
    return buckets['arg'], buckets['aux']
|
|
|
|
|
|
|
|
def parse_ctx(ctx_args):
    """Parse a device string such as ``'gpu0,gpu1'`` into MXNet contexts.

    Each run of lowercase letters followed by optional digits becomes one
    ``mx.Context(device, index)``. A missing index defaults to 0, so
    ``'cpu'`` yields a single ``cpu`` context with index 0. Text between
    matches (e.g. commas) is skipped.

    Parameters
    ----------
    ctx_args : str
        Device specification string.

    Returns
    -------
    list of mx.Context
        One context per match, in order of appearance. The list is empty if
        nothing matches.
    """
    import mxnet as mx

    # Use a raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12). The pattern value
    # is unchanged.
    matches = re.findall(r'([a-z]+)(\d*)', ctx_args)
    return [mx.Context(device, int(num) if num else 0) for device, num in matches]
|
|
|