import json
import logging
import math
import pickle
import random
from typing import List, Iterator, Callable
from torch import Tensor as T
logger = logging.getLogger()
def read_serialized_data_from_files(paths: List[str]) -> List:
    """Load pickled lists of samples from the given paths and concatenate them."""
    results = []
for i, path in enumerate(paths):
with open(path, "rb") as reader:
logger.info("Reading file %s", path)
data = pickle.load(reader)
results.extend(data)
logger.info("Aggregated data size: {}".format(len(results)))
logger.info("Total data size: {}".format(len(results)))
return results
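
# Illustrative usage sketch (not invoked anywhere in this module): round-trips a
# small sample list through pickle and reads it back. The path "samples.pkl" is
# a hypothetical example file, not something this codebase creates.
def _example_read_serialized_data() -> None:
    with open("samples.pkl", "wb") as f:
        pickle.dump([{"id": 1}, {"id": 2}], f)
    samples = read_serialized_data_from_files(["samples.pkl"])
    assert len(samples) == 2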
def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List:
    """Load JSON lists of samples, optionally up-sampling each file, and concatenate them."""
    results = []
if upsample_rates is None:
upsample_rates = [1] * len(paths)
assert len(upsample_rates) == len(
paths
), "up-sample rates parameter doesn't match input files amount"
for i, path in enumerate(paths):
with open(path, "r", encoding="utf-8") as f:
logger.info("Reading file %s" % path)
data = json.load(f)
upsample_factor = int(upsample_rates[i])
data = data * upsample_factor
results.extend(data)
logger.info("Aggregated data size: {}".format(len(results)))
return results
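
# Illustrative usage sketch: with an up-sample rate of 2, every record from that
# file appears twice in the aggregated result. "train.json" is a hypothetical
# file name used only for this example.
def _example_read_json_with_upsampling() -> None:
    with open("train.json", "w", encoding="utf-8") as f:
        json.dump([{"question": "capital of France", "answer": "Paris"}], f)
    samples = read_data_from_json_files(["train.json"], upsample_rates=[2])
    assert len(samples) == 2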
class ShardedDataIterator(object):
"""
General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of
the data.
Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size.
It fills the extra sample by just taking first samples in a shard.
It can also optionally enforce identical batch size for all iterations (might be useful for DP mode).
"""
def __init__(
self,
data: list,
shard_id: int = 0,
num_shards: int = 1,
batch_size: int = 1,
        shuffle: bool = True,
shuffle_seed: int = 0,
offset: int = 0,
strict_batch_size: bool = False,
):
self.data = data
total_size = len(data)
self.shards_num = max(num_shards, 1)
self.shard_id = max(shard_id, 0)
samples_per_shard = math.ceil(total_size / self.shards_num)
self.shard_start_idx = self.shard_id * samples_per_shard
self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size)
        # with strict batches the final short batch is padded back to full size, so it
        # still counts as an iteration; otherwise the remainder is dropped
        if strict_batch_size:
            self.max_iterations = math.ceil(samples_per_shard / batch_size)
        else:
            self.max_iterations = samples_per_shard // batch_size
logger.debug(
"samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d",
samples_per_shard,
self.shard_start_idx,
self.shard_end_idx,
self.max_iterations,
)
self.iteration = offset # to track in-shard iteration status
self.shuffle = shuffle
self.batch_size = batch_size
self.shuffle_seed = shuffle_seed
self.strict_batch_size = strict_batch_size
def total_data_len(self) -> int:
return len(self.data)
def iterate_data(self, epoch: int = 0) -> Iterator[List]:
if self.shuffle:
# to be able to resume, same shuffling should be used when starting from a failed/stopped iteration
epoch_rnd = random.Random(self.shuffle_seed + epoch)
epoch_rnd.shuffle(self.data)
        # when resuming mid-epoch, self.iteration already holds the batch offset, so the
        # range() below skips the batches that were already served; the padding loop at
        # the end still runs up to the absolute self.max_iterations
shard_samples = self.data[self.shard_start_idx : self.shard_end_idx]
for i in range(
self.iteration * self.batch_size, len(shard_samples), self.batch_size
):
items = shard_samples[i : i + self.batch_size]
if self.strict_batch_size and len(items) < self.batch_size:
logger.debug("Extending batch to max size")
items.extend(shard_samples[0 : self.batch_size - len(items)])
self.iteration += 1
yield items
        # some shards may be done iterating while others are still on their last batch;
        # re-serve the first batch so every rank performs the same number of steps
        while self.iteration < self.max_iterations:
            logger.debug("Filling incomplete shard=%d", self.shard_id)
self.iteration += 1
batch = shard_samples[0 : self.batch_size]
yield batch
        logger.debug(
            "Finished iterating, iteration=%d, shard=%d", self.iteration, self.shard_id
        )
# reset the iteration status
self.iteration = 0
def get_iteration(self) -> int:
return self.iteration
def apply(self, visitor_func: Callable):
for sample in self.data:
visitor_func(sample)
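
# Illustrative sketch of how the iterator keeps DDP ranks in lock-step: two shards
# over ten samples with strict_batch_size=True yield the same number of equally
# sized batches on every rank. Purely a usage example; nothing calls it.
def _example_sharded_iteration() -> None:
    data = list(range(10))
    for shard_id in (0, 1):
        it = ShardedDataIterator(
            data,
            shard_id=shard_id,
            num_shards=2,
            batch_size=4,
            shuffle=False,
            strict_batch_size=True,
        )
        batches = list(it.iterate_data(epoch=0))
        # ceil(5 / 4) == 2 iterations per shard; the short second batch is padded to 4
        assert len(batches) == 2
        assert all(len(b) == 4 for b in batches)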
def normalize_question(question: str) -> str:
    # strip a single trailing question mark, if present (endswith also handles "")
    if question.endswith("?"):
        question = question[:-1]
    return question
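
# Example behaviour of normalize_question (a single trailing "?" is stripped,
# everything else is left untouched):
#   normalize_question("Who wrote Hamlet?")  -> "Who wrote Hamlet"
#   normalize_question("Define entropy")     -> "Define entropy"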
class Tensorizer(object):
"""
    Component for all text-to-model-input conversions and related utility methods
"""
# Note: title, if present, is supposed to be put before text (i.e. optional title + document body)
def text_to_tensor(
self, text: str, title: str = None, add_special_tokens: bool = True
):
raise NotImplementedError
def get_pair_separator_ids(self) -> T:
raise NotImplementedError
def get_pad_id(self) -> int:
raise NotImplementedError
def get_attn_mask(self, tokens_tensor: T):
raise NotImplementedError
def is_sub_word_id(self, token_id: int):
raise NotImplementedError
def to_string(self, token_ids, skip_special_tokens=True):
raise NotImplementedError
def set_pad_to_max(self, pad: bool):
raise NotImplementedError
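
# Minimal sketch of a concrete Tensorizer, assuming a plain whitespace vocabulary;
# a real implementation would typically wrap a HuggingFace tokenizer instead. Every
# name below (WhitespaceTensorizer, its vocab dict) is hypothetical and exists only
# to illustrate the interface above; only the methods needed for the illustration
# are overridden.
import torch

class WhitespaceTensorizer(Tensorizer):
    def __init__(self, vocab: dict, max_length: int = 32, pad_id: int = 0):
        self.vocab = vocab  # token string -> integer id
        self.max_length = max_length
        self.pad_id = pad_id
        self.pad_to_max = True

    def text_to_tensor(
        self, text: str, title: str = None, add_special_tokens: bool = True
    ):
        # the optional title is prepended to the document body, as noted above
        full_text = (title + " " + text) if title else text
        ids = [self.vocab.get(tok, self.pad_id) for tok in full_text.split()]
        ids = ids[: self.max_length]
        if self.pad_to_max:
            ids = ids + [self.pad_id] * (self.max_length - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def get_pad_id(self) -> int:
        return self.pad_id

    def get_attn_mask(self, tokens_tensor: T):
        # attend to every position that is not padding
        return tokens_tensor != self.pad_id

    def set_pad_to_max(self, pad: bool):
        self.pad_to_max = pad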