import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss
            self.dict[char] = i + 1

        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)

    def encode(self, text, batch_max_length=25):
        """ Convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text: text index for CTCLoss. [batch_size, batch_max_length]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]

        # The padding index (0) does not affect the CTC loss calculation.
        batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
        for i, t in enumerate(text):
            indexed = [self.dict[char] for char in t]
            batch_text[i][:len(indexed)] = torch.LongTensor(indexed)
        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """ Convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            t = text_index[index, :]

            char_list = []
            for i in range(l):
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # remove repeated characters and blanks
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
        return texts
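
# Illustrative usage sketch (not part of the original module): a minimal
# round trip through CTCLabelConverter with a toy character set 'abc'.
# The function name and the hand-made raw prediction are hypothetical.
def _demo_ctc_converter():
    converter = CTCLabelConverter('abc')
    # encode shifts character indices by 1 (0 = blank) and zero-pads each row:
    # 'ab' -> [1, 2, 0, ...], 'cab' -> [3, 1, 2, 0, ...]
    batch_text, length = converter.encode(['ab', 'cab'])
    # decode collapses repeats and drops blanks, so a raw CTC path like
    # [1, 1, 0, 2] reads back as 'ab'.
    raw = torch.LongTensor([[1, 1, 0, 2]])
    assert converter.decode(raw, torch.IntTensor([4])) == ['ab']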

class CTCLabelConverterForBaiduWarpctc(object):
    """ Convert between text-label and text-index for Baidu warp-ctc """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss
            self.dict[char] = i + 1

        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)

    def encode(self, text, batch_max_length=25):
        """ Convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, text_index, length):
        """ Convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # remove repeated characters and blanks
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts
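
# Illustrative usage sketch (not part of the original module): unlike
# CTCLabelConverter, the Baidu warp-ctc variant returns one flat,
# concatenated index tensor plus per-sample lengths, which is the label
# layout warp-ctc expects. The function name is hypothetical, and decode
# is normally applied to raw predictions; labels without repeats or blanks
# happen to survive the CTC collapse rules unchanged.
def _demo_warpctc_converter():
    converter = CTCLabelConverterForBaiduWarpctc('abc')
    flat_text, length = converter.encode(['ab', 'cab'])
    # flat_text == IntTensor([1, 2, 3, 1, 2]); length == IntTensor([2, 3])
    # decode walks the flat tensor using the per-sample lengths.
    assert converter.decode(flat_text, length) == ['ab', 'cab']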

class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        # [GO] for the start token of the attention decoder; [s] for the end-of-sentence token.
        list_token = ['[GO]', '[s]']  # ['[s]', '[UNK]', '[PAD]', '[GO]']
        list_character = list(character)
        self.character = list_token + list_character

        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i

    def encode(self, text, batch_max_length=25):
        """ Convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text: the input of the attention decoder. [batch_size x (max_length + 2)]
                +1 for the [GO] token and +1 for the [s] token.
                text[:, 0] is the [GO] token, and text is padded with the [GO] token after the [s] token.
            length: the length of the attention decoder's output, which also counts the [s] token,
                e.g. [3, 7, ...]. [batch_size]
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at the end of each sentence.
        # batch_max_length = max(length)  # not allowed in a multi-GPU setting
        batch_max_length += 1
        # Additional +1 for [GO] at the first step; batch_text is padded with the [GO] token after the [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            chars = list(t) + ['[s]']
            indexed = [self.dict[char] for char in chars]
            batch_text[i][1:1 + len(indexed)] = torch.LongTensor(indexed)  # batch_text[:, 0] = [GO] token (index 0)
        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """ Convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts
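
# Illustrative usage sketch (not part of the original module): encoding for
# the attention decoder with a toy character set 'abc'. decode joins every
# index in the row, so callers are expected to cut the string at the first
# '[s]' themselves; the function and variable names below are hypothetical.
def _demo_attn_converter():
    converter = AttnLabelConverter('abc')
    # self.character == ['[GO]', '[s]', 'a', 'b', 'c'], so 'ab' encodes to
    # [0 ('[GO]'), 2 ('a'), 3 ('b'), 1 ('[s]'), 0, ...] with length == [3].
    batch_text, length = converter.encode(['ab'])
    # Skip the leading [GO] column (as a training target would) and trim at [s].
    pred = converter.decode(batch_text[:, 1:], length)[0]
    assert pred[:pred.find('[s]')] == 'ab'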

class Averager(object):
    """ Compute the average of torch.Tensor values; used for loss averaging. """

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
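
# Illustrative usage sketch (not part of the original module): tracking a
# running mean loss with Averager. The loss values are hand-made stand-ins
# for tensors returned by a criterion; the function name is hypothetical.
def _demo_averager():
    avg = Averager()
    for value in (2.0, 4.0):
        avg.add(torch.tensor(value))  # accumulates element count and running sum
    assert avg.val() == 3.0  # (2.0 + 4.0) / 2
    avg.reset()  # typically called once per logging interval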