import os
import datetime
import itertools

import torch
import tqdm

from multiprocessing import Pool
from collections import OrderedDict, defaultdict


def print_message(*s, condition=True):
    s = ' '.join([str(x) for x in s])
    msg = "[{}] {}".format(datetime.datetime.now().strftime("%b %d, %H:%M:%S"), s)

    if condition:
        print(msg, flush=True)

    return msg


def timestamp():
    format_str = "%Y-%m-%d_%H.%M.%S"
    result = datetime.datetime.now().strftime(format_str)
    return result


def file_tqdm(file):
    """Yield lines from an open text file while showing a MiB-based progress bar."""
    print(f"#> Reading {file.name}")

    with tqdm.tqdm(total=os.path.getsize(file.name) / 1024.0 / 1024.0, unit="MiB") as pbar:
        for line in file:
            yield line
            pbar.update(len(line) / 1024.0 / 1024.0)

        pbar.close()


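# Illustrative use of file_tqdm (hypothetical path; assumes a plain-text file opened in text mode):
#
#   with open('collection.tsv') as f:
#       for line in file_tqdm(f):
#           ...  # handle the line

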
def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer, arguments=None):
    print(f"#> Saving a checkpoint to {path} ..")

    if hasattr(model, 'module'):
        model = model.module  # unwrap DataParallel/DistributedDataParallel

    checkpoint = {}
    checkpoint['epoch'] = epoch_idx
    checkpoint['batch'] = mb_idx
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['arguments'] = arguments

    torch.save(checkpoint, path)


def load_checkpoint(path, model, optimizer=None, do_print=True):
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    if path.startswith("http:") or path.startswith("https:"):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        checkpoint = torch.load(path, map_location='cpu')

    # Strip the 'module.' prefix that DataParallel/DistributedDataParallel adds to parameter names.
    state_dict = checkpoint['model_state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if k[:7] == 'module.':
            name = k[7:]
        new_state_dict[name] = v

    checkpoint['model_state_dict'] = new_state_dict

    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception:
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        print_message("#> checkpoint['epoch'] =", checkpoint['epoch'])
        print_message("#> checkpoint['batch'] =", checkpoint['batch'])

    return checkpoint


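# Illustrative save/load round trip (hypothetical path, model, and optimizer; any torch.nn.Module
# and torch.optim optimizer should fit these signatures):
#
#   save_checkpoint('checkpoint.pt', epoch_idx=0, mb_idx=1000, model=model, optimizer=optimizer)
#   checkpoint = load_checkpoint('checkpoint.pt', model, optimizer=optimizer)

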
def create_directory(path):
    if os.path.exists(path):
        print('\n')
        print_message("#> Note: Output directory", path, 'already exists\n\n')
    else:
        print('\n')
        print_message("#> Creating directory", path, '\n\n')
        os.makedirs(path)


def f7(seq):
    """
    Deduplicate `seq` while preserving the original order.
    Source: https://stackoverflow.com/a/480227/1493011
    """

    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]


def batch(group, bsize, provide_offset=False):
    """Yield `group` in consecutive chunks of up to `bsize` items, optionally with their starting offsets."""
    offset = 0
    while offset < len(group):
        L = group[offset: offset + bsize]
        yield ((offset, L) if provide_offset else L)
        offset += len(L)
    return


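# Illustrative use of batch: with seven items and bsize=3 this yields [0, 1, 2], [3, 4, 5], [6];
# with provide_offset=True it yields (0, [0, 1, 2]), (3, [3, 4, 5]), (6, [6]).
#
#   for chunk in batch(list(range(7)), bsize=3):
#       ...

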
class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """

    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


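# Illustrative use of dotdict: attribute access maps to item access.
#
#   args = dotdict({'bsize': 32})
#   args.lr = 3e-6
#   assert args.bsize == 32 and args['lr'] == 3e-6

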
def flatten(L):
    """Flatten one level of nesting: a list of iterables becomes a single list."""
    return [x for y in L for x in y]


def zipstar(L, lazy=False):
    """
    A much faster A, B, C = zip(*[(a, b, c), (a, b, c), ...])
    May return lists or tuples.
    """

    if len(L) == 0:
        return L

    width = len(L[0])

    # For narrow records, building the transposed lists directly is faster than zip(*L).
    if width < 100:
        return [[elem[idx] for elem in L] for idx in range(width)]

    L = zip(*L)

    return L if lazy else list(L)


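# Illustrative use of zipstar: transpose a list of records into per-column lists.
#
#   ids, labels, scores = zipstar([(1, 'a', 0.5), (2, 'b', 0.3)])
#   # ids == [1, 2], labels == ['a', 'b'], scores == [0.5, 0.3]

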
def zip_first(L1, L2):
    length = len(L1) if type(L1) in [tuple, list] else None

    L3 = list(zip(L1, L2))

    assert length in [None, len(L3)], "zip_first() failure: length differs!"

    return L3


def int_or_float(val):
    if '.' in val:
        return float(val)

    return int(val)


def load_ranking(path, types=None, lazy=False):
    print_message(f"#> Loading the ranked lists from {path} ..")

    try:
        lists = torch.load(path)
        lists = zipstar([l.tolist() for l in tqdm.tqdm(lists)], lazy=lazy)
    except Exception:
        # Not a saved tensor file; fall back to parsing a tab-separated text ranking.
        if types is None:
            types = itertools.cycle([int_or_float])

        with open(path) as f:
            lists = [[typ(x) for typ, x in zip_first(types, line.strip().split('\t'))]
                     for line in file_tqdm(f)]

    return lists


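# Illustrative use of load_ranking (hypothetical path and column layout): for a tab-separated
# ranking whose columns are, say, qid, pid, rank, and score, explicit types could be passed as:
#
#   ranking = load_ranking('ranking.tsv', types=[int, int, int, float])

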
def save_ranking(ranking, path):
    lists = zipstar(ranking)
    lists = [torch.tensor(l) for l in lists]

    torch.save(lists, path)

    return lists


def groupby_first_item(lst):
    """Group the tuples in `lst` by their first item; single trailing values are unwrapped from their 1-tuples."""
    groups = defaultdict(list)

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest
        groups[first].append(rest)

    return groups


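# Illustrative use of groupby_first_item:
#
#   groupby_first_item([(1, 'a'), (1, 'b'), (2, 'c')])
#   # -> {1: ['a', 'b'], 2: ['c']} (as a defaultdict)

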
def process_grouped_by_first_item(lst):
    """
    Requires items in list to already be grouped by first item.
    Yields a (group, items) pair whenever the arrival of a new first item completes the previous group.
    """

    groups = defaultdict(list)

    started = False
    last_group = None

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest

        if started and first != last_group:
            yield (last_group, groups[last_group])
            assert first not in groups, f"{first} seen earlier --- violates precondition."

        groups[first].append(rest)

        last_group = first
        started = True

    return groups


def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks
    Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    Source: https://docs.python.org/3/library/itertools.html#itertools-recipes
    """

    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)


class NullContextManager(object):
    """A do-nothing context manager that simply returns the wrapped resource (cf. contextlib.nullcontext)."""

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        return self.dummy_resource

    def __exit__(self, *args):
        pass


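# Illustrative use of NullContextManager: stand in for a real context manager when a feature
# is turned off (the `use_amp` flag and `amp_context` object here are hypothetical).
#
#   with (amp_context if use_amp else NullContextManager()):
#       ...

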
def load_batch_backgrounds(args, qids):
    if args.qid2backgrounds is None:
        return None

    qbackgrounds = []

    for qid in qids:
        back = args.qid2backgrounds[qid]

        # Integer entries are passage IDs into args.collection; otherwise, keys are looked up in args.collectionX.
        if len(back) and type(back[0]) == int:
            x = [args.collection[pid] for pid in back]
        else:
            x = [args.collectionX.get(pid, '') for pid in back]

        x = ' [SEP] '.join(x)
        qbackgrounds.append(x)

    return qbackgrounds