import logging
import os
import re
import time

import cv2
import numpy as np
import torch
import yaml
from matplotlib import colors
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import ConcatDataset
class CharsetMapper(object):
    """A simple class to map ids into strings.

    It works only when the character set is a 1:1 mapping between individual
    characters and individual ids.
    """

    def __init__(self,
                 filename='',
                 max_length=30,
                 null_char=u'\u2591'):
        """Creates a lookup table.

        Args:
            filename: path to a charset file which maps characters to ids.
            max_length: the max length of ids and string.
            null_char: a unicode character used to replace the '<null>'
                character; the default value is a light shade block '░'.
        """
        self.null_char = null_char
        self.max_length = max_length
        self.label_to_char = self._read_charset(filename)
        self.char_to_label = dict(map(reversed, self.label_to_char.items()))
        self.num_classes = len(self.label_to_char)
    def _read_charset(self, filename):
        """Reads a charset definition from a tab-separated text file.

        Args:
            filename: a path to the charset file.

        Returns:
            a dictionary mapping integer ids to unicode characters.
        """
        pattern = re.compile(r'(\d+)\t(.+)')
        charset = {}
        self.null_label = 0
        charset[self.null_label] = self.null_char
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                m = pattern.match(line)
                assert m, f'Incorrect charset file. line #{i}: {line}'
                label = int(m.group(1)) + 1  # shift by one: id 0 is reserved for null
                char = m.group(2)
                charset[label] = char
        return charset
    def trim(self, text):
        assert isinstance(text, str)
        return text.replace(self.null_char, '')

    def get_text(self, labels, length=None, padding=True, trim=False):
        """Returns the string corresponding to a sequence of character ids."""
        length = length if length else self.max_length
        labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
        if padding:
            labels = labels + [self.null_label] * (length - len(labels))
        text = ''.join([self.label_to_char[label] for label in labels])
        if trim:
            text = self.trim(text)
        return text
    def get_labels(self, text, length=None, padding=True, case_sensitive=False):
        """Returns the character ids corresponding to a string."""
        length = length if length else self.max_length
        if padding:
            text = text + self.null_char * (length - len(text))
        if not case_sensitive:
            text = text.lower()
        labels = [self.char_to_label[char] for char in text]
        return labels
    def pad_labels(self, labels, length=None):
        length = length if length else self.max_length
        return labels + [self.null_label] * (length - len(labels))
    # These must be properties: digit_labels and alphabet_labels access them
    # as attributes (self.digits, self.alphabets), not as method calls.
    @property
    def digits(self):
        return '0123456789'

    @property
    def digit_labels(self):
        return self.get_labels(self.digits, padding=False)

    @property
    def alphabets(self):
        all_chars = list(self.char_to_label.keys())
        valid_chars = []
        for c in all_chars:
            if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
                valid_chars.append(c)
        return ''.join(valid_chars)

    @property
    def alphabet_labels(self):
        return self.get_labels(self.alphabets, padding=False)
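
# Usage sketch for CharsetMapper (the charset path is illustrative; the real
# file is a tab-separated "id<TAB>char" listing, one entry per line):
#
#   charset = CharsetMapper('data/charset_36.txt', max_length=25)
#   labels = charset.get_labels('hello', padding=True)  # padded id sequence
#   charset.get_text(labels, trim=True)                 # -> 'hello'
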
class Timer(object):
    """A simple timer that tracks data-loading time and model-running time
    separately: call tic() before the loop, toc_data() once a batch is ready,
    and toc_running() after the model step.
    """

    def __init__(self):
        self.data_time = 0.
        self.data_diff = 0.
        self.data_total_time = 0.
        self.data_call = 0
        self.running_time = 0.
        self.running_diff = 0.
        self.running_total_time = 0.
        self.running_call = 0

    def tic(self):
        self.start_time = time.time()
        self.running_time = self.start_time

    def toc_data(self):
        self.data_time = time.time()
        self.data_diff = self.data_time - self.running_time
        self.data_total_time += self.data_diff
        self.data_call += 1

    def toc_running(self):
        self.running_time = time.time()
        self.running_diff = self.running_time - self.data_time
        self.running_total_time += self.running_diff
        self.running_call += 1

    def total_time(self):
        return self.data_total_time + self.running_total_time

    def average_time(self):
        return self.average_data_time() + self.average_running_time()

    def average_data_time(self):
        return self.data_total_time / (self.data_call or 1)

    def average_running_time(self):
        return self.running_total_time / (self.running_call or 1)
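
# Usage sketch of the tic/toc protocol in a training loop (loader and
# train_step are illustrative names, not part of this module):
#
#   timer = Timer()
#   timer.tic()
#   for batch in loader:
#       timer.toc_data()      # accumulates time spent fetching the batch
#       train_step(batch)
#       timer.toc_running()   # accumulates time spent in the model
#   print(timer.average_data_time(), timer.average_running_time())
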
class Logger(object):
    _handle = None
    _root = None

    @staticmethod
    def init(output_dir, name, phase):
        fmt = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
              '%(message)s'.format(name)
        logging.basicConfig(level=logging.INFO, format=fmt)

        os.makedirs(output_dir, exist_ok=True)
        config_path = os.path.join(output_dir, f'{phase}.txt')
        Logger._handle = logging.FileHandler(config_path)
        Logger._root = logging.getLogger()

    @staticmethod
    def enable_file():
        if Logger._handle is None or Logger._root is None:
            raise Exception('Invoke Logger.init() first!')
        Logger._root.addHandler(Logger._handle)

    @staticmethod
    def disable_file():
        if Logger._handle is None or Logger._root is None:
            raise Exception('Invoke Logger.init() first!')
        Logger._root.removeHandler(Logger._handle)
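
# Usage sketch ('workdir', 'abinet', and 'train' are illustrative arguments):
#
#   Logger.init('workdir', name='abinet', phase='train')
#   Logger.enable_file()     # mirror log records to workdir/train.txt
#   logging.info('hello')
#   Logger.disable_file()    # back to console-only logging
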
class Config(object):

    def __init__(self, config_path, host=True):
        def __dict2attr(d, prefix=''):
            # Flatten the nested YAML dict into underscore-joined attributes,
            # e.g. {'dataset': {'train': ...}} -> self.dataset_train.
            for k, v in d.items():
                if isinstance(v, dict):
                    __dict2attr(v, f'{prefix}{k}_')
                else:
                    if k == 'phase':
                        assert v in ['train', 'test']
                    if k == 'stage':
                        assert v in ['pretrain-vision', 'pretrain-language',
                                     'train-semi-super', 'train-super']
                    self.__setattr__(f'{prefix}{k}', v)

        assert os.path.exists(config_path), f'{config_path} does not exist!'
        with open(config_path) as file:
            config_dict = yaml.load(file, Loader=yaml.FullLoader)
        with open('configs/template.yaml') as file:
            default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
        __dict2attr(default_config_dict)
        __dict2attr(config_dict)  # the user config overrides template defaults
        self.global_workdir = os.path.join(self.global_workdir, self.global_name)
    def __getattr__(self, item):
        attr = self.__dict__.get(item)
        if attr is None:
            # Collect every flattened attribute sharing the prefix back into
            # a dict, e.g. cfg.dataset -> {'train': ..., 'test': ...}.
            attr = dict()
            prefix = f'{item}_'
            for k, v in self.__dict__.items():
                if k.startswith(prefix):
                    n = k[len(prefix):]  # strip only the leading prefix
                    attr[n] = v
            return attr if len(attr) > 0 else None
        else:
            return attr

    def __repr__(self):
        s = 'ModelConfig(\n'
        for i, (k, v) in enumerate(sorted(vars(self).items())):
            s += f'\t({i}): {k} = {v}\n'
        s += ')'
        return s
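
# Usage sketch (the YAML path and keys are illustrative; nested keys become
# underscore-joined attributes, and a bare prefix returns the group as a dict):
#
#   config = Config('configs/train_abinet.yaml')
#   config.dataset_train_batch_size   # a leaf value
#   config.dataset_train              # {'batch_size': ..., ...}
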
def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
    # normalize mask to [0, 1]
    mask = (mask - mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
    # compare against the spatial dims only: mask is HxW, image is HxWx3
    if mask.shape != image.shape[:2]:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
    # colorize the mask
    color_map = plt.get_cmap(cmap)
    mask = color_map(mask)[:, :, :3]
    # convert float to uint8
    mask = (mask * 255).astype(dtype=np.uint8)
    # set the basic color
    basic_color = np.array(colors.to_rgb(color)) * 255
    basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
    basic_color = basic_color.astype(dtype=np.uint8)
    # blend with the basic color
    blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1 - color_alpha, 0)
    # blend with the mask
    blended_img = cv2.addWeighted(blended_img, alpha, mask, 1 - alpha, 0)
    return blended_img
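
# Usage sketch (shapes are illustrative; expects a uint8 HxWx3 image and a
# float 2D attention map, which is resized to the image if needed):
#
#   image = np.zeros((32, 128, 3), dtype=np.uint8)
#   attn = np.random.rand(8, 32).astype(np.float32)
#   overlay = blend_mask(image, attn, alpha=0.5, cmap='jet')
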
def onehot(label, depth, device=None):
    """
    Args:
        label: tensor of integer class ids, shape (n1, n2, ...)
        depth: the number of classes
    Returns:
        onehot: float tensor of shape (n1, n2, ..., depth)
    """
    if not isinstance(label, torch.Tensor):
        label = torch.tensor(label, device=device)
    onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
    onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
    return onehot
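
# Usage sketch:
#
#   onehot([1, 0, 2], depth=3)
#   # tensor([[0., 1., 0.],
#   #         [1., 0., 0.],
#   #         [0., 0., 1.]])
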
class MyDataParallel(nn.DataParallel):

    def gather(self, outputs, target_device):
        r"""
        Gathers tensors from different GPUs on a specified device
        (-1 means the CPU). Unlike the stock nn.DataParallel gather, this
        also passes strings and numbers through and concatenates lists of
        strings across devices.
        """
        def gather_map(outputs):
            out = outputs[0]
            if isinstance(out, (str, int, float)):
                return out
            if isinstance(out, list) and isinstance(out[0], str):
                return [o for out in outputs for o in out]
            if isinstance(out, torch.Tensor):
                return torch.nn.parallel._functions.Gather.apply(
                    target_device, self.dim, *outputs)
            if out is None:
                return None
            if isinstance(out, dict):
                if not all((len(out) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')
                return type(out)(((k, gather_map([d[k] for d in outputs]))
                                  for k in out))
            return type(out)(map(gather_map, zip(*outputs)))

        # Recursive function calls like this create reference cycles.
        # Setting the function to None clears the refcycle.
        try:
            res = gather_map(outputs)
        finally:
            gather_map = None
        return res
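
# Usage sketch (a drop-in replacement for nn.DataParallel, useful when the
# wrapped model returns dicts mixing tensors with strings; 'model' and
# 'images' are illustrative):
#
#   model = MyDataParallel(model, device_ids=[0, 1])
#   outputs = model(images)  # per-GPU outputs are gathered recursively
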
class MyConcatDataset(ConcatDataset):

    def __getattr__(self, k):
        # Delegate unknown attributes (e.g. the charset) to the first sub-dataset.
        return getattr(self.datasets[0], k)