#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import logging
from typing import Callable, Dict, Union
import yaml
import torch
from torch.optim.swa_utils import AveragedModel as torch_average_model
import numpy as np
import pandas as pd
from pprint import pformat
def load_dict_from_csv(csv, cols):
    """Read a tab-separated file and return a mapping from the first
    given column to the second."""
    df = pd.read_csv(csv, sep="\t")
    output = dict(zip(df[cols[0]], df[cols[1]]))
    return output
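# Usage sketch (the file and column names are hypothetical):
#     durations = load_dict_from_csv("data/durations.tsv",
#                                    ("audio_id", "duration"))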
def init_logger(filename, level="INFO"):
formatter = logging.Formatter(
"[ %(levelname)s : %(asctime)s ] - %(message)s")
logger = logging.getLogger(__name__ + "." + filename)
logger.setLevel(getattr(logging, level))
# Log results to std
# stdhandler = logging.StreamHandler(sys.stdout)
# stdhandler.setFormatter(formatter)
# Dump log to file
filehandler = logging.FileHandler(filename)
filehandler.setFormatter(formatter)
logger.addHandler(filehandler)
# logger.addHandler(stdhandler)
return logger
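# Usage sketch (hypothetical log file path):
#     logger = init_logger("exp/train.log", level="DEBUG")
#     logger.info("training started")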
def init_obj(module, config, **kwargs):
    """Instantiate the class named by config["type"] from `module`
    (e.g. captioning.models.encoder) with config["args"], overridden
    by any extra keyword arguments."""
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    return getattr(module, config["type"])(**obj_args)
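# Usage sketch (the class name and arguments are assumptions):
#     import captioning.models.encoder as encoder_module
#     config = {"type": "RnnEncoder", "args": {"embed_dim": 256}}
#     encoder = init_obj(encoder_module, config, dropout=0.2)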
def pprint_dict(in_dict, outputfun=print, formatter="yaml"):
    """Pretty-print a dict, emitting one formatted line at a time.

    :param in_dict: dict to print
    :param outputfun: callable invoked with each line, defaults to print
    :param formatter: "yaml" or "pretty"
    """
    if formatter == "yaml":
        format_fun = yaml.dump
    elif formatter == "pretty":
        format_fun = pformat
    else:
        raise ValueError(f"unknown formatter: {formatter}")
    for line in format_fun(in_dict).split("\n"):
        outputfun(line)
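# Usage sketch (assuming a logger from init_logger above):
#     pprint_dict({"lr": 0.001, "epochs": 10}, logger.info)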
def merge_a_into_b(a, b):
# merge dict a into dict b. values in a will overwrite b.
for k, v in a.items():
if isinstance(v, dict) and k in b:
assert isinstance(
b[k], dict
), "Cannot inherit key '{}' from base!".format(k)
merge_a_into_b(v, b[k])
else:
b[k] = v
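# Example of the recursive merge:
#     a = {"optimizer": {"args": {"lr": 0.0001}}}
#     b = {"optimizer": {"type": "Adam", "args": {"lr": 0.001}}}
#     merge_a_into_b(a, b)
#     # b is now {"optimizer": {"type": "Adam", "args": {"lr": 0.0001}}}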
def load_config(config_file):
with open(config_file, "r") as reader:
config = yaml.load(reader, Loader=yaml.FullLoader)
if "inherit_from" in config:
base_config_file = config["inherit_from"]
base_config_file = os.path.join(
os.path.dirname(config_file), base_config_file
)
        assert not os.path.samefile(config_file, base_config_file), \
            "a config must not inherit from itself"
base_config = load_config(base_config_file)
del config["inherit_from"]
merge_a_into_b(config, base_config)
return base_config
return config
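# Example of config inheritance (hypothetical YAML files): exp.yaml containing
#     inherit_from: base.yaml
#     optimizer:
#         args: {lr: 0.0005}
# loaded on top of base.yaml containing
#     optimizer:
#         type: Adam
#         args: {lr: 0.001}
# yields an Adam optimizer config with lr 0.0005.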
def parse_config_or_kwargs(config_file, **kwargs):
yaml_config = load_config(config_file)
# passed kwargs will override yaml config
args = dict(yaml_config, **kwargs)
return args
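# Usage sketch (hypothetical config path and overrides):
#     args = parse_config_or_kwargs("config/train.yaml", seed=1, batch_size=32)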
def store_yaml(config, config_file):
with open(config_file, "w") as con_writer:
yaml.dump(config, con_writer, indent=4, default_flow_style=False)
class MetricImprover:
def __init__(self, mode):
assert mode in ("min", "max")
self.mode = mode
# min: lower -> better; max: higher -> better
self.best_value = np.inf if mode == "min" else -np.inf
def compare(self, x, best_x):
return x < best_x if self.mode == "min" else x > best_x
def __call__(self, x):
if self.compare(x, self.best_value):
self.best_value = x
return True
return False
def state_dict(self):
return self.__dict__
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
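# Usage sketch:
#     improver = MetricImprover("min")
#     if improver(val_loss):  # True when val_loss is the best seen so far
#         save_checkpoint(model)  # hypothetical helper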
def fix_batchnorm(model: torch.nn.Module):
    """Set all BatchNorm layers to eval mode so their running statistics
    are frozen, e.g. when fine-tuning with small batches."""
    def inner(module):
        class_name = module.__class__.__name__
        if "BatchNorm" in class_name:
            module.eval()
    model.apply(inner)
def load_pretrained_model(model: torch.nn.Module,
pretrained: Union[str, Dict],
output_fn: Callable = sys.stdout.write):
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"pretrained checkpoint {pretrained} does not exist!")
        return
if hasattr(model, "load_pretrained"):
model.load_pretrained(pretrained)
return
if isinstance(pretrained, dict):
state_dict = pretrained
else:
state_dict = torch.load(pretrained, map_location="cpu")
if "model" in state_dict:
state_dict = state_dict["model"]
    model_dict = model.state_dict()
    # keep only checkpoint keys that exist in the model with matching shapes
    pretrained_dict = {
        k: v for k, v in state_dict.items()
        if (k in model_dict) and (model_dict[k].shape == v.shape)
    }
    output_fn(f"Loading pretrained keys {list(pretrained_dict.keys())}")
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict=True)
class AveragedModel(torch_average_model):
    """Variant of torch's AveragedModel that also averages buffers
    (e.g. BatchNorm running statistics) instead of merely copying them."""

    def update_parameters(self, model):
for p_swa, p_model in zip(self.parameters(), model.parameters()):
device = p_swa.device
p_model_ = p_model.detach().to(device)
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
self.n_averaged.to(device)))
        # skip the first buffer, which is the internal n_averaged counter
        for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
device = b_swa.device
b_model_ = b_model.detach().to(device)
if self.n_averaged == 0:
b_swa.detach().copy_(b_model_)
else:
b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
self.n_averaged.to(device)))
self.n_averaged += 1
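# Usage sketch (hypothetical training loop):
#     swa_model = AveragedModel(model)
#     for epoch in range(num_epochs):
#         train_one_epoch(model)  # hypothetical helper
#         swa_model.update_parameters(model)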