# demo/dataset.py
import glob
from dataclasses import dataclass
from typing import List

import pytorch_lightning as pl
import torch
import webdataset as wds
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader

import helpers
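# A single preprocessed example: token ids for the title, description,
# condition (attributes) and fact, plus the raw title/description strings.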
@dataclass
class Example:
title_token_ids: List[int]
description_token_ids: List[int]
condition_token_ids: List[int]
fact_token_ids: List[int]
description: str
title: str
@dataclass
class TensorDict:
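    """Base class for dataclasses of tensors: detach()/cpu() return a copy with
    every torch.Tensor field detached or moved to CPU, leaving other fields untouched."""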
def detach(self):
detached_dict = {
field: getattr(self, field).detach()
if isinstance(getattr(self, field), torch.Tensor)
else getattr(self, field)
for field in self.__dataclass_fields__
}
return self.__class__(**detached_dict)
    def cpu(self):
        cpu_dict = {
            field: getattr(self, field).cpu()
            if isinstance(getattr(self, field), torch.Tensor)
            else getattr(self, field)
            for field in self.__dataclass_fields__
        }
        return self.__class__(**cpu_dict)
@dataclass
class Batched(TensorDict):
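    """One padded batch. Token-id tensors are laid out as (seq_len, batch);
    the *_mask tensors are (batch, seq_len) and True at padding positions."""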
# Source
title_token_ids: torch.Tensor
title_token_ids_mask: torch.Tensor
# Attribute Fusion
cond_title_token_ids: torch.Tensor
cond_title_token_ids_mask: torch.Tensor
# Knowledge Incorporation
fact_token_ids: torch.Tensor
fact_token_ids_mask: torch.Tensor
title_fact_token_ids: torch.Tensor
title_fact_token_ids_mask: torch.Tensor
# Attribute Fusion + Knowledge Incorporation
cond_title_fact_token_ids: torch.Tensor
cond_title_fact_token_ids_mask: torch.Tensor
# Target
#description_token_ids: torch.Tensor
#description_token_ids_mask: torch.Tensor
#descriptions: List[str]
#titles: List[str]
@dataclass
class EncodedBatch(TensorDict):
context_encodings: torch.Tensor
context_encodings_mask: torch.Tensor
@dataclass
class DecodedBatch:
loss: float
acc: float
generated: List[str]
descriptions: List[str]
titles: List[str]
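# Load preprocessed examples from WebDataset tar shards: a streaming, shuffled
# pipeline for training, otherwise a fully materialized list of Example objects.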
def from_processed(url: str, train=False):
urls = sorted(glob.glob(url))
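    # Shard-splitting helpers for multi-worker / multi-node loading; the active
    # pipeline below does not use them (see the commented-out variants).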
def my_split_by_worker(urls):
wi = torch.utils.data.get_worker_info()
if wi is None:
return urls
else:
return urls[wi.id::wi.num_workers]
def my_split_by_node(urls):
node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
return urls[node_id::node_count]
if train:
return (
wds.WebDataset(urls)
#wds.WebDataset(urls,nodesplitter=my_split_by_node)
#wds.WebDataset(urls,nodesplitter=wds.split_by_node)
.shuffle(size=10000000, initial=100000)
.decode()
.map(lambda d: Example(**d["json"]))
)
    else:
        return list(
            wds.WebDataset(url).decode().map(lambda d: Example(**d["json"]))
        )
#return list(wds.WebDataset(urls, nodesplitter=my_split_by_node).decode().map(lambda d: Example(**d["json"])))
#return list(wds.WebDataset(urls, nodesplitter=wds.split_by_node).decode().map(lambda d: Example(**d["json"])))
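# Build a collate_fn that pads every field to a common length and assembles the
# source-side inputs: title, fact, condition+title, title+fact, and
# condition+title+fact.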
def get_collate_fn(text_vocab_size: int, max_seq_length: int):
def collate_fn(examples: List[Example]) -> Batched:
from kobe.data.vocab import BOS_ID, EOS_ID
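        # Every sequence is wrapped in BOS/EOS and truncated to max_seq_length.
        # Condition (attribute) token ids are shifted by text_vocab_size so they
        # occupy an id range disjoint from the text vocabulary.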
title_token_ids = pad_sequence(
[
torch.tensor(
[BOS_ID] + e.title_token_ids[: max_seq_length - 2] + [EOS_ID]
)
for e in examples
]
)
fact_token_ids = pad_sequence(
[
torch.tensor(
[BOS_ID] + e.fact_token_ids[: max_seq_length - 2] + [EOS_ID]
)
for e in examples
]
)
"""
description_token_ids = pad_sequence(
[
torch.tensor(
[BOS_ID] + e.description_token_ids[: max_seq_length - 2] + [EOS_ID]
)
for e in examples
]
)
"""
cond_title_token_ids = pad_sequence(
[
torch.tensor(
(
[BOS_ID]
+ [
cond_id + text_vocab_size
for cond_id in e.condition_token_ids
]
+ e.title_token_ids
)[: max_seq_length - 1]
+ [EOS_ID]
)
for e in examples
]
)
title_fact_token_ids = pad_sequence(
[
torch.tensor(
([BOS_ID] + e.title_token_ids + [EOS_ID] + e.fact_token_ids)[
: max_seq_length - 1
]
+ [EOS_ID]
)
for e in examples
]
)
cond_title_fact_token_ids = pad_sequence(
[
torch.tensor(
(
[BOS_ID]
+ [
cond_id + text_vocab_size
for cond_id in e.condition_token_ids
]
+ e.title_token_ids
+ [EOS_ID]
+ e.fact_token_ids
)[: max_seq_length - 1]
+ [EOS_ID]
)
for e in examples
]
)
#descriptions = [e.description for e in examples]
#titles = [e.title for e in examples]
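        # pad_sequence pads with 0 and returns (seq_len, batch) tensors; the
        # masks are therefore transposed to (batch, seq_len) and are True at
        # padding positions.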
return Batched(
title_token_ids=title_token_ids,
title_token_ids_mask=(title_token_ids == 0).T,
fact_token_ids=fact_token_ids,
fact_token_ids_mask=(fact_token_ids == 0).T,
cond_title_token_ids=cond_title_token_ids,
cond_title_token_ids_mask=(cond_title_token_ids == 0).T,
title_fact_token_ids=title_fact_token_ids,
title_fact_token_ids_mask=(title_fact_token_ids == 0).T,
cond_title_fact_token_ids=cond_title_fact_token_ids,
cond_title_fact_token_ids_mask=(cond_title_fact_token_ids == 0).T,
#description_token_ids="",
#description_token_ids_mask=(description_token_ids == 0).T,
#descriptions="",
#titles="",
)
return collate_fn
class KobeDataModule(pl.LightningDataModule):
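    """LightningDataModule for the KOBE demo: only the test split is wired up
    here; the train/val dataloaders below are kept commented out."""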
def __init__(
self,
test_data: str,
vocab_path: str,
max_seq_length: int,
batch_size: int,
num_workers: int,
):
super().__init__()
self.test_data = test_data
self.max_seq_length = max_seq_length
self.batch_size = batch_size
self.num_workers = num_workers
self.text_vocab_size = helpers.get_bert_vocab_size(vocab_path)
"""
def train_dataloader(self):
return DataLoader(
self.train,
batch_size=self.batch_size,
num_workers=self.num_workers,
collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
)
def val_dataloader(self):
return DataLoader(
self.valid,
batch_size=self.batch_size,
num_workers=self.num_workers,
collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
)
"""
    def test_dataloader(self):
        # Accept either a path/glob to processed shards or a pre-built list of
        # Example objects.
        test_dataset = self.test_data
        if isinstance(test_dataset, str):
            test_dataset = from_processed(test_dataset)
        return DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )
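# Small smoke test: stream the test split and report the longest padded
# cond_title_fact sequence seen.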
if __name__ == "__main__":
dm = KobeDataModule(
        #train_data="saved/processed/train-*.tar",
        #valid_data="saved/processed/valid.tar",
test_data="saved/processed/test.tar",
vocab_path="bert-base-chinese",
max_seq_length=512,
batch_size=32,
num_workers=8,
)
dm.setup("test")
max_len = 0
from tqdm import tqdm
tqdm_iter = tqdm(dm.test_dataloader())
for batch in tqdm_iter:
max_len = max(max_len, batch.cond_title_fact_token_ids.shape[0])
        #max_len = max(max_len, batch.description_token_ids.shape[0])
        tqdm_iter.set_description(f"max len = {max_len}")