# NOTE: the next six lines are residue from the file-hosting web page, not code.
# lmzjms's picture
# Upload 46 files
# 8121fee
# raw
# history blame
# 5.64 kB
# -*- coding: utf-8 -*-
#!/usr/bin/env python3
import os
import sys
import logging
from typing import Callable, Dict, Union
import yaml
import torch
from torch.optim.swa_utils import AveragedModel as torch_average_model
import numpy as np
import pandas as pd
from pprint import pformat
def load_dict_from_csv(csv, cols):
    """Build a lookup dict from two columns of a tab-separated file.

    :param csv: path (or anything ``pandas.read_csv`` accepts) of a
        tab-separated table with a header row
    :param cols: sequence whose first item names the key column and whose
        second item names the value column
    :return: dict mapping each entry of the key column to the entry of the
        value column on the same row; for repeated keys the last row wins
    """
    table = pd.read_csv(csv, sep="\t")
    key_column = table[cols[0]]
    value_column = table[cols[1]]
    return {key: value for key, value in zip(key_column, value_column)}
def init_logger(filename, level="INFO"):
    """Return a logger that writes formatted records to ``filename``.

    The logger is named ``__name__ + "." + filename``, so repeated calls with
    the same ``filename`` return the same logger object. A file handler is
    attached only if the logger does not already have one for that file;
    otherwise every repeated call would add another handler and each record
    would be written multiple times.

    :param filename: path of the log file
    :param level: name of a level attribute of the ``logging`` module
        (e.g. ``"INFO"``, ``"DEBUG"``)
    :return: the configured ``logging.Logger``
    """
    formatter = logging.Formatter(
        "[ %(levelname)s : %(asctime)s ] - %(message)s")
    logger = logging.getLogger(__name__ + "." + filename)
    logger.setLevel(getattr(logging, level))

    # FileHandler stores the absolute path in ``baseFilename``; compare
    # against that to detect a handler left by an earlier call.
    target = os.path.abspath(filename)
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        # Only a file handler is attached; nothing is echoed to stdout here.
        filehandler = logging.FileHandler(filename)
        filehandler.setFormatter(formatter)
        logger.addHandler(filehandler)
    return logger
def init_obj(module, config, **kwargs):
    """Instantiate ``config["type"]`` looked up on ``module``.

    :param module: object (typically a module such as
        ``captioning.models.encoder``) that exposes the target as an attribute
    :param config: mapping with a ``"type"`` entry (attribute name) and an
        ``"args"`` entry (keyword arguments for the call)
    :param kwargs: extra keyword arguments; they take precedence over
        ``config["args"]``
    :return: whatever calling the looked-up attribute returns
    """
    # Work on a copy so the caller's config is left untouched.
    call_args = config["args"].copy()
    for name, value in kwargs.items():
        call_args[name] = value
    factory = getattr(module, config["type"])
    return factory(**call_args)
def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
    """Render ``in_dict`` as text and pass it to ``outputfun`` line by line.

    :param in_dict: dict to print
    :param outputfun: callable invoked once per rendered line (the line is
        passed without its trailing newline), defaults to ``sys.stdout.write``
    :param formatter: ``'yaml'`` renders with ``yaml.dump``, ``'pretty'``
        renders with ``pprint.pformat``
    :raises ValueError: if ``formatter`` is neither ``'yaml'`` nor ``'pretty'``
    """
    if formatter == 'yaml':
        format_fun = yaml.dump
    elif formatter == 'pretty':
        format_fun = pformat
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported formatter {formatter!r}; expected 'yaml' or 'pretty'"
        )
    for line in format_fun(in_dict).split('\n'):
        outputfun(line)
def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b`` in place.

    Values in ``a`` overwrite those in ``b``. When a key is present in both
    and the value in ``a`` is a dict, the two sub-dicts are merged
    recursively instead of being replaced wholesale.

    :param a: dict providing the overriding values
    :param b: dict that is updated in place
    :raises AssertionError: if ``a`` holds a dict under a key whose existing
        value in ``b`` is not a dict
    """
    for key, value in a.items():
        if isinstance(value, dict) and key in b:
            # Raised explicitly (rather than via ``assert``) so the check is
            # still enforced when Python runs with -O.
            if not isinstance(b[key], dict):
                raise AssertionError(
                    "Cannot inherit key '{}' from base!".format(key))
            merge_a_into_b(value, b[key])
        else:
            b[key] = value
def load_config(config_file):
    """Load a YAML config file, resolving ``inherit_from`` recursively.

    If the loaded config has an ``inherit_from`` entry, it is interpreted as
    a path relative to the directory of ``config_file``. That base config is
    loaded first (itself resolving any further inheritance), and the current
    file's values are merged on top of it with ``merge_a_into_b``. The
    ``inherit_from`` entry is removed before merging.

    :param config_file: path of the YAML file to load
    :return: the (merged) configuration; an empty dict for an empty file
    :raises AssertionError: if the file names itself as its base
    """
    with open(config_file, "r") as reader:
        # NOTE(review): FullLoader accepts more than plain data types;
        # consider yaml.safe_load if configs can come from untrusted sources.
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if config is None:
        # An empty YAML document loads as None; treat it as "no settings"
        # instead of failing on the membership test below.
        config = {}
    if "inherit_from" in config:
        base_config_file = os.path.join(
            os.path.dirname(config_file), config["inherit_from"]
        )
        # Explicit raise so the guard is still active under ``python -O``.
        if os.path.samefile(config_file, base_config_file):
            raise AssertionError("inherit from itself")
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config
def parse_config_or_kwargs(config_file, **kwargs):
    """Load ``config_file`` and overlay keyword overrides on top of it.

    :param config_file: path handed to ``load_config``
    :param kwargs: entries that replace same-named top-level config entries
    :return: a new dict holding the loaded config updated with ``kwargs``
    """
    merged = dict(load_config(config_file))
    # Explicitly passed keyword arguments win over the YAML values.
    merged.update(kwargs)
    return merged
def store_yaml(config, config_file):
    """Write ``config`` to ``config_file`` as block-style YAML.

    :param config: object to serialise with ``yaml.dump``
    :param config_file: destination path; opened in write mode, so an
        existing file is overwritten
    """
    with open(config_file, "w") as yaml_out:
        yaml.dump(
            config,
            stream=yaml_out,
            indent=4,
            default_flow_style=False,
        )
class MetricImprover:
    """Track the best value of a metric and report when it improves.

    ``mode="min"`` treats lower values as better, ``mode="max"`` treats
    higher values as better. Calling the instance with a new value returns
    ``True`` (and records the value) only on a strict improvement.
    """

    def __init__(self, mode):
        """Start tracking with the worst possible best value for ``mode``.

        :param mode: ``"min"`` or ``"max"``
        :raises AssertionError: if ``mode`` is anything else
        """
        # Explicit raise so the validation is still active under -O.
        if mode not in ("min", "max"):
            raise AssertionError(
                f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        # min: lower -> better; max: higher -> better
        self.best_value = np.inf if mode == "min" else -np.inf

    def compare(self, x, best_x):
        """Return whether ``x`` is strictly better than ``best_x``."""
        return x < best_x if self.mode == "min" else x > best_x

    def __call__(self, x):
        """Record ``x`` if it beats the best value; return whether it did."""
        if self.compare(x, self.best_value):
            self.best_value = x
            return True
        return False

    def state_dict(self):
        """Return a snapshot of the tracker's attributes.

        A copy is returned so that a stored snapshot does not change when
        the tracker is updated afterwards.
        """
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Restore attributes from a dict produced by ``state_dict``."""
        self.__dict__.update(state_dict)
def fix_batchnorm(model: torch.nn.Module):
    """Put every BatchNorm-like submodule of ``model`` into eval mode.

    A submodule is selected when its class name contains ``"BatchNorm"``;
    all other submodules keep their current train/eval state.

    :param model: module whose submodules are visited via ``model.apply``
    """
    def _eval_if_batchnorm(layer):
        if "BatchNorm" in layer.__class__.__name__:
            layer.eval()

    model.apply(_eval_if_batchnorm)
def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Copy compatible weights from ``pretrained`` into ``model``.

    :param model: module to initialise
    :param pretrained: either a state dict or a path to a checkpoint file
        readable by ``torch.load``
    :param output_fn: callable used to report progress messages

    Behaviour, in order:

    * a path that does not exist is reported through ``output_fn`` and the
      function returns without touching ``model``;
    * if ``model`` has a ``load_pretrained`` attribute, loading is delegated
      to it entirely;
    * otherwise the checkpoint is read (onto CPU when loaded from disk),
      unwrapped from a top-level ``"model"`` entry if present, and only the
      entries whose name exists in ``model`` with an identical shape are
      copied over. Everything else in ``model`` keeps its current value.
    """
    given_as_dict = isinstance(pretrained, dict)
    if not (given_as_dict or os.path.exists(pretrained)):
        output_fn(f"pretrained {pretrained} not exist!")
        return

    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained)
        return

    if given_as_dict:
        checkpoint = pretrained
    else:
        # NOTE(review): torch.load unpickles the file; only use with
        # checkpoints from a trusted source.
        checkpoint = torch.load(pretrained, map_location="cpu")
    if "model" in checkpoint:
        checkpoint = checkpoint["model"]

    target_state = model.state_dict()
    pretrained_dict = {}
    for name, tensor in checkpoint.items():
        if name not in target_state:
            continue
        if target_state[name].shape != tensor.shape:
            continue
        pretrained_dict[name] = tensor
    output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")

    # Start from the model's own state so load_state_dict(strict=True)
    # always sees a complete set of keys.
    target_state.update(pretrained_dict)
    model.load_state_dict(target_state, strict=True)
class AveragedModel(torch_average_model):
    """Averaged model whose update also averages buffers, not only parameters.

    ``update_parameters`` folds the parameters *and* buffers of ``model``
    into this wrapper's running average. The first buffer of the wrapper is
    skipped when pairing buffers (``list(self.buffers())[1:]``); presumably
    that is the ``n_averaged`` counter registered by the parent class --
    confirm against the installed PyTorch version.
    """

    @staticmethod
    def _equal_weight_average(averaged, current, num_averaged):
        """Running mean: every model seen so far gets the same weight."""
        return averaged + (current - averaged) / (num_averaged + 1)

    def _resolve_avg_fn(self):
        """Return the per-tensor averaging function to use.

        The parent class may leave ``self.avg_fn`` as ``None`` when no
        function was supplied; in that case fall back to the equal-weight
        running mean so the update does not try to call ``None``.
        """
        avg_fn = getattr(self, "avg_fn", None)
        if avg_fn is not None:
            return avg_fn
        if getattr(self, "multi_avg_fn", None) is not None:
            raise NotImplementedError(
                "AveragedModel.update_parameters does not support "
                "multi_avg_fn; pass avg_fn instead")
        return self._equal_weight_average

    def _accumulate(self, averaged, current, avg_fn, first_update):
        """Fold one tensor of the live model into its averaged counterpart."""
        device = averaged.device
        incoming = current.detach().to(device)
        target = averaged.detach()
        if first_update:
            # Nothing averaged yet: start from the model's current value.
            target.copy_(incoming)
        else:
            target.copy_(
                avg_fn(target, incoming, self.n_averaged.to(device)))

    def update_parameters(self, model):
        """Update the running average with the current state of ``model``.

        :param model: module whose parameters and buffers are paired
            positionally with those of this wrapper
        """
        avg_fn = self._resolve_avg_fn()
        # n_averaged only changes at the very end, so evaluate this once.
        first_update = bool(self.n_averaged == 0)

        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            self._accumulate(p_swa, p_model, avg_fn, first_update)

        wrapped_buffers = list(self.buffers())[1:]
        for b_swa, b_model in zip(wrapped_buffers, model.buffers()):
            self._accumulate(b_swa, b_model, avg_fn, first_update)

        self.n_averaged += 1