#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import logging
from pprint import pformat
from typing import Callable, Dict, Union

import numpy as np
import pandas as pd
import torch
import yaml
from torch.optim.swa_utils import AveragedModel as torch_average_model


def load_dict_from_csv(csv, cols):
    """Read a tab-separated file and return a {cols[0]: cols[1]} mapping."""
    df = pd.read_csv(csv, sep="\t")
    output = dict(zip(df[cols[0]], df[cols[1]]))
    return output


def init_logger(filename, level="INFO"):
    """Create a logger that writes "[ LEVEL : time ] - message" lines to `filename`."""
    formatter = logging.Formatter(
        "[ %(levelname)s : %(asctime)s ] - %(message)s")
    logger = logging.getLogger(__name__ + "." + filename)
    logger.setLevel(getattr(logging, level))
    # Log results to stdout
    # stdhandler = logging.StreamHandler(sys.stdout)
    # stdhandler.setFormatter(formatter)
    # Dump log to file
    filehandler = logging.FileHandler(filename)
    filehandler.setFormatter(formatter)
    logger.addHandler(filehandler)
    # logger.addHandler(stdhandler)
    return logger


def init_obj(module, config, **kwargs):
    """Instantiate `config["type"]` from `module` (e.g. captioning.models.encoder),
    passing `config["args"]` updated with any **kwargs overrides."""
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    return getattr(module, config["type"])(**obj_args)


def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter="yaml"):
    """Pretty-print a dict, passing it line by line to `outputfun`.

    :param in_dict: dict to print
    :param outputfun: callable invoked once per line (e.g. a logger's
        `info` method), defaults to sys.stdout.write
    :param formatter: "yaml" or "pretty"
    """
    if formatter == "yaml":
        format_fun = yaml.dump
    elif formatter == "pretty":
        format_fun = pformat
    else:
        raise ValueError(f"Unknown formatter: {formatter}")
    for line in format_fun(in_dict).split("\n"):
        outputfun(line)


def merge_a_into_b(a, b):
    """Recursively merge dict `a` into dict `b`; values in `a` overwrite `b`."""
    for k, v in a.items():
        if isinstance(v, dict) and k in b:
            assert isinstance(
                b[k], dict
            ), "Cannot inherit key '{}' from base!".format(k)
            merge_a_into_b(v, b[k])
        else:
            b[k] = v


def load_config(config_file):
    """Load a YAML config; if it contains an "inherit_from" key, recursively
    load the referenced base config (resolved relative to `config_file`) and
    overwrite it with this file's values."""
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" in config:
        base_config_file = config["inherit_from"]
        base_config_file = os.path.join(
            os.path.dirname(config_file), base_config_file
        )
        assert not os.path.samefile(config_file, base_config_file), \
            "inherit from itself"
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config


def parse_config_or_kwargs(config_file, **kwargs):
    """Load `config_file`; passed **kwargs override the YAML values."""
    yaml_config = load_config(config_file)
    args = dict(yaml_config, **kwargs)
    return args
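
# Usage sketch of the config inheritance chain (file names and contents are
# illustrative, not part of the API). Given base.yaml:
#     optimizer:
#         type: Adam
#         args:
#             lr: 0.001
# and exp.yaml:
#     inherit_from: base.yaml
#     optimizer:
#         args:
#             lr: 0.0005
# the call
#     config = parse_config_or_kwargs("exp.yaml", seed=1)
# yields {"optimizer": {"type": "Adam", "args": {"lr": 0.0005}}, "seed": 1}:
# child values overwrite the base via merge_a_into_b, and kwargs overwrite both.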


def store_yaml(config, config_file):
    with open(config_file, "w") as con_writer:
        yaml.dump(config, con_writer, indent=4, default_flow_style=False)


class MetricImprover:

    def __init__(self, mode):
        assert mode in ("min", "max")
        self.mode = mode
        # min: lower -> better; max: higher -> better
        self.best_value = np.inf if mode == "min" else -np.inf

    def compare(self, x, best_x):
        return x < best_x if self.mode == "min" else x > best_x

    def __call__(self, x):
        if self.compare(x, self.best_value):
            self.best_value = x
            return True
        return False

    def state_dict(self):
        return self.__dict__

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
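
# Usage sketch (inside a hypothetical validation loop; `validate`, `model`
# and `num_epochs` are illustrative): the call returns True and remembers the
# value whenever the metric improves, so it doubles as a checkpointing or
# early-stopping criterion.
#
#     improver = MetricImprover("max")
#     for epoch in range(num_epochs):
#         score = validate(model)
#         if improver(score):
#             torch.save(model.state_dict(), "best.pth")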


def fix_batchnorm(model: torch.nn.Module):
    """Put every BatchNorm layer of `model` into eval mode, freezing its
    running statistics during further training."""
    def inner(module):
        class_name = module.__class__.__name__
        if class_name.find("BatchNorm") != -1:
            module.eval()
    model.apply(inner)


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Partially load `pretrained` (a state dict or a checkpoint path) into
    `model`, keeping only keys whose names and shapes match."""
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} does not exist!")
        return
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained)
        return
    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")
    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_dict = model.state_dict()
    # Keep only weights that exist in the model with matching shapes
    pretrained_dict = {
        k: v for k, v in state_dict.items() if (k in model_dict) and (
            model_dict[k].shape == v.shape)
    }
    output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)
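
# Usage sketch (the checkpoint path is illustrative): load the matching
# weights from a checkpoint saved as {"model": state_dict, ...} and report
# which keys were taken over.
#
#     logger = init_logger("train.log")
#     load_pretrained_model(model, "experiments/pretrained.pth", logger.info)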


class AveragedModel(torch_average_model):
    """AveragedModel that, unlike the parent class, also averages buffers
    (e.g. BatchNorm running statistics) in addition to parameters."""

    def update_parameters(self, model):
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        # buffers()[0] is the internal `n_averaged` counter, hence the [1:]
        for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
            device = b_swa.device
            b_model_ = b_model.detach().to(device)
            if self.n_averaged == 0:
                b_swa.detach().copy_(b_model_)
            else:
                b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1
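
# Usage sketch (a standard stochastic-weight-averaging loop; `train_one_epoch`
# and `num_epochs` are illustrative): keep a running average of `model` and
# evaluate with the averaged copy.
#
#     swa_model = AveragedModel(model)
#     for epoch in range(num_epochs):
#         train_one_epoch(model)
#         swa_model.update_parameters(model)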