# vjeronymo2's picture
# Adding model and checkpoint
import os
import tqdm
import torch
import datetime
import itertools
from multiprocessing import Pool
from collections import OrderedDict, defaultdict
def print_message(*s, condition=True):
    """Build a timestamped message, optionally print it, and return it.

    The positional arguments are converted with ``str`` and joined with single
    spaces, then prefixed with the current local time formatted as
    ``[Mon DD, HH:MM:SS]``.

    Args:
        *s: Values to include in the message.
        condition: When truthy, the message is printed and flushed immediately.

    Returns:
        The formatted message string, whether or not it was printed.
    """
    text = ' '.join(str(part) for part in s)
    stamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    msg = "[{}] {}".format(stamp, text)

    if condition:
        print(msg, flush=True)

    return msg
def timestamp():
    """Return the current local time as ``YYYY-MM-DD_HH.MM.SS``.

    The format contains no colons or spaces, only digits, dashes, dots and an
    underscore.
    """
    format_str = "%Y-%m-%d_%H.%M.%S"
    return datetime.datetime.now().strftime(format_str)
def file_tqdm(file):
    """Yield the lines of an open file while showing a progress bar in MiB.

    The bar's total is the on-disk size of ``file.name``; after each line is
    handed to the consumer the bar advances by ``len(line)``.

    Args:
        file: An open, iterable file object with a ``name`` attribute that
            points at an existing path.

    Yields:
        Each line of ``file``, unchanged.
    """
    print(f"#> Reading {file.name}")

    mib = 1024.0 * 1024.0

    with tqdm.tqdm(total=os.path.getsize(file.name) / mib, unit="MiB") as pbar:
        for line in file:
            yield line
            # NOTE(review): for text-mode files len(line) counts characters,
            # not bytes, so the bar may end short of the total on multi-byte
            # content.
            pbar.update(len(line) / mib)
def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer, arguments=None):
    """Serialize the training state to ``path`` with ``torch.save``.

    Args:
        path: Destination file for the checkpoint.
        epoch_idx: Stored under the ``'epoch'`` key.
        mb_idx: Stored under the ``'batch'`` key.
        model: Object exposing ``state_dict()``. If it has a ``module``
            attribute (distributed/data-parallel wrapper), the wrapped
            module's state is saved instead.
        optimizer: Object exposing ``state_dict()``.
        arguments: Optional extra payload stored under ``'arguments'``.
    """
    print(f"#> Saving a checkpoint to {path} ..")

    # Extract the model from a distributed/data-parallel wrapper, if any.
    if hasattr(model, 'module'):
        model = model.module

    checkpoint = {
        'epoch': epoch_idx,
        'batch': mb_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'arguments': arguments,
    }

    torch.save(checkpoint, path)
def load_checkpoint(path, model, optimizer=None, do_print=True):
    """Load a checkpoint into ``model`` (and optionally ``optimizer``).

    Keys of the stored model state that start with ``'module.'`` (as written
    by a data-parallel wrapper) have that prefix removed before loading.

    Args:
        path: Local file path, or an ``http:``/``https:`` URL.
        model: Module whose ``load_state_dict`` receives the model weights.
        optimizer: If given, its state is restored from the checkpoint's
            ``'optimizer_state_dict'`` entry.
        do_print: Whether to log progress and the stored epoch/batch.

    Returns:
        The checkpoint dict, with ``'model_state_dict'`` replaced by the
        prefix-stripped version.
    """
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    # NOTE: both loaders unpickle the file; only load checkpoints from
    # sources you trust.
    if path.startswith(("http:", "https:")):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        # Without this branch a URL download would be overwritten by a
        # local load of the same string.
        checkpoint = torch.load(path, map_location='cpu')

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in checkpoint['model_state_dict'].items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value

    checkpoint['model_state_dict'] = new_state_dict

    try:
        # Prefer an exact match; relax only when the keys do not line up.
        model.load_state_dict(new_state_dict)
    except RuntimeError:
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(new_state_dict, strict=False)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        print_message("#> checkpoint['epoch'] =", checkpoint['epoch'])
        print_message("#> checkpoint['batch'] =", checkpoint['batch'])

    return checkpoint
def create_directory(path):
    """Create ``path`` (including parents) unless it already exists.

    A message is logged in either case: a note when the path is already
    present, or a creation notice right before the directory is made.

    Args:
        path: Directory to create.
    """
    if os.path.exists(path):
        print_message("#> Note: Output directory", path, 'already exists\n\n')
    else:
        print_message("#> Creating directory", path, '\n\n')
        # exist_ok tolerates another process creating it between the check
        # and this call.
        os.makedirs(path, exist_ok=True)
# def batch(file, bsize):
# while True:
# L = [ujson.loads(file.readline()) for _ in range(bsize)]
# yield L
# return
def f7(seq):
    """Return the items of ``seq`` with duplicates removed, keeping first-seen order.

    Items must be hashable.

    Source: https://stackoverflow.com/a/480227/1493011
    """
    seen = set()
    # set.add returns None, so the `or` both tests membership and records
    # the item in a single expression.
    return [x for x in seq if not (x in seen or seen.add(x))]
def batch(group, bsize, provide_offset=False):
    """Yield consecutive slices of ``group`` holding at most ``bsize`` items.

    Args:
        group: A sliceable sequence supporting ``len``.
        bsize: Maximum number of items per slice; must be at least 1.
        provide_offset: When true, yield ``(offset, slice)`` pairs, where
            ``offset`` is the index of the slice's first item in ``group``.

    Yields:
        Each slice, or ``(offset, slice)`` if ``provide_offset`` is set.

    Raises:
        ValueError: If ``bsize`` is smaller than 1 (raised on first
            iteration). A non-positive size would otherwise never advance
            and loop forever.
    """
    if bsize < 1:
        raise ValueError(f"bsize must be >= 1, got {bsize}")

    offset = 0
    while offset < len(group):
        L = group[offset: offset + bsize]
        yield ((offset, L) if provide_offset else L)
        offset += len(L)
class dotdict(dict):
    """
    dot.notation access to dictionary attributes

    Attribute reads, writes and deletes are forwarded to the corresponding
    item operations, so a missing attribute raises ``KeyError`` (not
    ``AttributeError``).

    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def flatten(L):
    """Flatten one level of nesting: return a list of the items of each element of ``L``."""
    return [item for sub in L for item in sub]
def zipstar(L, lazy=False):
    """
    A much faster A, B, C = zip(*[(a, b, c), (a, b, c), ...])
    May return lists or tuples.

    Args:
        L: A sequence of equally indexable rows. The width is taken from the
            first row.
        lazy: Only used when rows have 100 or more columns; if true, the
            ``zip`` iterator is returned without being materialized.

    Returns:
        ``L`` itself when it is empty; a list of column lists when rows have
        fewer than 100 columns; otherwise ``zip(*L)``, as a list of tuples
        unless ``lazy`` is set.
    """
    if len(L) == 0:
        return L

    width = len(L[0])

    if width < 100:
        # For narrow rows, explicit indexing is used instead of zip(*L).
        return [[elem[idx] for elem in L] for idx in range(width)]

    L = zip(*L)

    return L if lazy else list(L)
def zip_first(L1, L2):
    """Zip ``L1`` with ``L2`` and check that nothing from ``L1`` was dropped.

    Args:
        L1: First iterable. If it is a tuple or list, its length must equal
            the length of the zipped result; other iterables are not checked.
        L2: Second iterable.

    Returns:
        ``list(zip(L1, L2))``.

    Raises:
        AssertionError: If ``L1`` is a tuple/list and ``L2`` is shorter.
    """
    length = len(L1) if isinstance(L1, (tuple, list)) else None

    L3 = list(zip(L1, L2))

    assert length in [None, len(L3)], "zip_first() failure: length differs!"

    return L3
def int_or_float(val):
    """Parse a numeric string as ``int`` when possible, otherwise as ``float``.

    Args:
        val: String to parse.

    Returns:
        A ``float`` if ``val`` contains a ``'.'`` or is not a valid integer
        literal (e.g. ``'1e-3'``, ``'inf'``); otherwise an ``int``.

    Raises:
        ValueError: If ``val`` is not a valid number at all.
    """
    if '.' in val:
        return float(val)

    try:
        return int(val)
    except ValueError:
        # Dot-less float spellings such as exponent notation end up here.
        return float(val)
def load_ranking(path, types=None, lazy=False):
    """Load ranked lists from ``path``.

    The file is first read with ``torch.load`` as a sequence of per-column
    tensors (the layout ``save_ranking`` writes) and transposed back into
    rows. If that fails, the file is parsed as tab-separated text instead.

    Args:
        path: File to read.
        types: Per-column converters for the text format. Defaults to
            applying ``int_or_float`` to every column.
        lazy: Forwarded to ``zipstar`` on the tensor path only.

    Returns:
        The rows of the ranking. For the text format this is a list with one
        list of converted fields per line.
    """
    print_message(f"#> Loading the ranked lists from {path} ..")

    try:
        # NOTE: torch.load unpickles the file; only load trusted rankings.
        lists = torch.load(path)
        lists = zipstar([column.tolist() for column in tqdm.tqdm(lists)], lazy=lazy)
    except Exception:
        # Not a torch-serialized ranking: fall back to the TSV format.
        # torch.load can fail with several unrelated exception types on a
        # plain-text file, hence the broad catch.
        if types is None:
            types = itertools.cycle([int_or_float])

        with open(path) as f:
            lists = [[typ(x) for typ, x in zip_first(types, line.strip().split('\t'))]
                     for line in file_tqdm(f)]

    return lists
def save_ranking(ranking, path):
    """Save ``ranking`` to ``path`` as a list of per-column tensors.

    Args:
        ranking: Sequence of equally wide rows; it is transposed into columns
            with ``zipstar`` before each column is converted with
            ``torch.tensor``.
        path: Destination file for ``torch.save``.

    Returns:
        The list of column tensors that was written.
    """
    columns = zipstar(ranking)
    lists = [torch.tensor(column) for column in columns]

    torch.save(lists, path)

    return lists
def groupby_first_item(lst):
    """Group the rows of ``lst`` by their first element.

    Args:
        lst: Iterable of non-empty sequences.

    Returns:
        A ``defaultdict(list)`` mapping each first element to the list of
        remainders, in input order. A remainder of exactly one item is stored
        as that bare item; otherwise it is stored as a list.
    """
    groups = defaultdict(list)

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest
        groups[first].append(rest)

    return groups
def process_grouped_by_first_item(lst):
    """
    Requires items in list to already be grouped by first item.

    Streams ``(first, remainders)`` pairs, one per run of rows sharing the
    same first element. A remainder of exactly one item is stored as that
    bare item; otherwise it is stored as a list.

    Args:
        lst: Iterable of non-empty sequences, grouped by first element.

    Yields:
        ``(key, values)`` for every group, including the last one.

    Returns:
        The full ``defaultdict`` of groups, as the generator's return value
        (available via ``StopIteration.value``).

    Raises:
        AssertionError: If a first element reappears after its group ended.
    """
    groups = defaultdict(list)

    started = False
    last_group = None

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest

        if started and first != last_group:
            yield (last_group, groups[last_group])
            assert first not in groups, f"{first} seen earlier --- violates precondition."

        groups[first].append(rest)

        last_group = first
        started = True

    # The loop only emits a group when the next one begins, so the final
    # group has to be flushed here.
    if started:
        yield (last_group, groups[last_group])

    return groups
def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks
    Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    Source: https://docs.python.org/3/library/itertools.html#itertools-recipes

    Returns an iterator of ``n``-tuples; the last tuple is padded with
    ``fillvalue`` when the input length is not a multiple of ``n``.
    """
    # The same iterator object repeated n times: zip_longest pulls n
    # consecutive items from it for every output tuple.
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)
# see https://stackoverflow.com/a/45187287
class NullContextManager(object):
    """Context manager that does nothing except hand back a given value.

    Useful as a stand-in where a ``with`` statement is required but no
    resource needs managing (the standard library offers
    ``contextlib.nullcontext`` for the same purpose).
    """

    def __init__(self, dummy_resource=None):
        # Value returned by __enter__, i.e. bound by `with ... as name`.
        self.dummy_resource = dummy_resource

    def __enter__(self):
        return self.dummy_resource

    def __exit__(self, *args):
        # Nothing to release; returning None lets exceptions propagate.
        pass
def load_batch_backgrounds(args, qids):
    """Build one background string per query id.

    For each qid, the entries of ``args.qid2backgrounds[qid]`` are resolved
    to texts and joined with ``' [SEP] '``. Integer entries are looked up in
    ``args.collection``; any other entries are looked up in
    ``args.collectionX`` with ``''`` as the default for missing keys.

    Args:
        args: Object with ``qid2backgrounds``, ``collection`` and
            ``collectionX`` attributes.
        qids: Iterable of query ids present in ``args.qid2backgrounds``.

    Returns:
        ``None`` if ``args.qid2backgrounds`` is ``None``; otherwise a list of
        joined background strings, aligned with ``qids``.
    """
    if args.qid2backgrounds is None:
        return None

    qbackgrounds = []

    for qid in qids:
        back = args.qid2backgrounds[qid]

        # The type of the first entry decides which collection is used.
        if len(back) and isinstance(back[0], int):
            x = [args.collection[pid] for pid in back]
        else:
            x = [args.collectionX.get(pid, '') for pid in back]

        qbackgrounds.append(' [SEP] '.join(x))

    return qbackgrounds