import glob
from dataclasses import dataclass
from typing import List, Optional

import pytorch_lightning as pl
import sentencepiece as spm
import torch
import webdataset as wds
from torch.functional import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader

import helpers

@dataclass
class Example:
    title_token_ids: List[int]
    description_token_ids: List[int]
    condition_token_ids: List[int]
    fact_token_ids: List[int]
    description: str
    title: str

class TensorDict:
    """Mixin for tensor dataclasses: field-wise detach() and cpu()."""

    def detach(self):
        detached_dict = {
            field: getattr(self, field).detach()
            if isinstance(getattr(self, field), torch.Tensor)
            else getattr(self, field)
            for field in self.__dataclass_fields__
        }
        return self.__class__(**detached_dict)

    def cpu(self):
        cpu_dict = {
            field: getattr(self, field).cpu()
            if isinstance(getattr(self, field), torch.Tensor)
            else getattr(self, field)
            for field in self.__dataclass_fields__
        }
        return self.__class__(**cpu_dict)

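# Illustrative only (not executed here): because Batched below is a dataclass
# that mixes in TensorDict, a whole batch can be detached and moved off the
# accelerator field by field, e.g.
#
#   cpu_batch = batch.detach().cpu()   # returns a new Batched instance
#
# Non-tensor fields (lists of strings, if re-enabled) pass through unchanged.
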
@dataclass
class Batched(TensorDict):
    # Source
    title_token_ids: torch.Tensor
    title_token_ids_mask: torch.Tensor
    # Attribute Fusion
    cond_title_token_ids: torch.Tensor
    cond_title_token_ids_mask: torch.Tensor
    # Knowledge Incorporation
    fact_token_ids: torch.Tensor
    fact_token_ids_mask: torch.Tensor
    title_fact_token_ids: torch.Tensor
    title_fact_token_ids_mask: torch.Tensor
    # Attribute Fusion + Knowledge Incorporation
    cond_title_fact_token_ids: torch.Tensor
    cond_title_fact_token_ids_mask: torch.Tensor
    # Target (kept commented out in this version)
    # description_token_ids: torch.Tensor
    # description_token_ids_mask: torch.Tensor
    # descriptions: List[str]
    # titles: List[str]

@dataclass
class EncodedBatch(TensorDict):
    context_encodings: torch.Tensor
    context_encodings_mask: torch.Tensor

@dataclass
class DecodedBatch:
    loss: float
    acc: float
    generated: List[str]
    descriptions: List[str]
    titles: List[str]

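# Rough data flow implied by these containers (an interpretation of the field
# names, not spelled out in this file): the collate function below produces a
# Batched of padded token ids and masks, an encoder turns that into an
# EncodedBatch of context encodings, and the decoding/generation step reports a
# DecodedBatch with loss, token accuracy, and generated/reference strings.
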
def from_processed(url: str, train=False):
    urls = sorted(glob.glob(url))

    def my_split_by_worker(urls):
        wi = torch.utils.data.get_worker_info()
        if wi is None:
            return urls
        else:
            return urls[wi.id::wi.num_workers]

    def my_split_by_node(urls):
        node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
        return urls[node_id::node_count]

    if train:
        return (
            wds.WebDataset(urls)
            # wds.WebDataset(urls, nodesplitter=my_split_by_node)
            # wds.WebDataset(urls, nodesplitter=wds.split_by_node)
            .shuffle(size=10000000, initial=100000)
            .decode()
            .map(lambda d: Example(**d["json"]))
        )
    else:
        # Debug leftovers (printed the first example, then exited), kept for reference:
        # print(list(wds.WebDataset(url).decode().map(lambda d: Example(**d["json"])))[0])
        # sys.exit()
        return list(wds.WebDataset(url).decode().map(lambda d: Example(**d["json"])))
        # return list(wds.WebDataset(urls, nodesplitter=my_split_by_node).decode().map(lambda d: Example(**d["json"])))
        # return list(wds.WebDataset(urls, nodesplitter=wds.split_by_node).decode().map(lambda d: Example(**d["json"])))

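# Usage sketch (paths mirror the __main__ block below): for training,
# from_processed returns a streaming, shuffled WebDataset; for evaluation it
# materializes the shard into a plain list of Example objects.
#
#   train_ds = from_processed("saved/processed/train-*.tar", train=True)
#   test_ds = from_processed("saved/processed/test.tar")
#
# The my_split_by_worker / my_split_by_node helpers shard the url list across
# DataLoader workers and distributed ranks; they are only wired in through the
# commented-out WebDataset calls above.
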
def get_collate_fn(text_vocab_size: int, max_seq_length: int):
    def collate_fn(examples: List[Example]) -> Batched:
        from kobe.data.vocab import BOS_ID, EOS_ID

        title_token_ids = pad_sequence(
            [
                torch.tensor(
                    [BOS_ID] + e.title_token_ids[: max_seq_length - 2] + [EOS_ID]
                )
                for e in examples
            ]
        )
        fact_token_ids = pad_sequence(
            [
                torch.tensor(
                    [BOS_ID] + e.fact_token_ids[: max_seq_length - 2] + [EOS_ID]
                )
                for e in examples
            ]
        )
        """
        description_token_ids = pad_sequence(
            [
                torch.tensor(
                    [BOS_ID] + e.description_token_ids[: max_seq_length - 2] + [EOS_ID]
                )
                for e in examples
            ]
        )
        """
        cond_title_token_ids = pad_sequence(
            [
                torch.tensor(
                    (
                        [BOS_ID]
                        + [
                            cond_id + text_vocab_size
                            for cond_id in e.condition_token_ids
                        ]
                        + e.title_token_ids
                    )[: max_seq_length - 1]
                    + [EOS_ID]
                )
                for e in examples
            ]
        )
        title_fact_token_ids = pad_sequence(
            [
                torch.tensor(
                    ([BOS_ID] + e.title_token_ids + [EOS_ID] + e.fact_token_ids)[
                        : max_seq_length - 1
                    ]
                    + [EOS_ID]
                )
                for e in examples
            ]
        )
        cond_title_fact_token_ids = pad_sequence(
            [
                torch.tensor(
                    (
                        [BOS_ID]
                        + [
                            cond_id + text_vocab_size
                            for cond_id in e.condition_token_ids
                        ]
                        + e.title_token_ids
                        + [EOS_ID]
                        + e.fact_token_ids
                    )[: max_seq_length - 1]
                    + [EOS_ID]
                )
                for e in examples
            ]
        )
        # descriptions = [e.description for e in examples]
        # titles = [e.title for e in examples]
        return Batched(
            title_token_ids=title_token_ids,
            title_token_ids_mask=(title_token_ids == 0).T,
            fact_token_ids=fact_token_ids,
            fact_token_ids_mask=(fact_token_ids == 0).T,
            cond_title_token_ids=cond_title_token_ids,
            cond_title_token_ids_mask=(cond_title_token_ids == 0).T,
            title_fact_token_ids=title_fact_token_ids,
            title_fact_token_ids_mask=(title_fact_token_ids == 0).T,
            cond_title_fact_token_ids=cond_title_fact_token_ids,
            cond_title_fact_token_ids_mask=(cond_title_fact_token_ids == 0).T,
            # description_token_ids="",
            # description_token_ids_mask=(description_token_ids == 0).T,
            # descriptions="",
            # titles="",
        )

    return collate_fn

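# Usage sketch (the vocab size is an assumption for bert-base-chinese, not a
# value taken from this file). Condition (attribute) ids are shifted past the
# text vocabulary so they occupy their own id range before being prepended to
# the title. pad_sequence pads with 0, so token tensors come out as
# (seq_len, batch) and the masks as (batch, seq_len), True at padding.
#
#   collate = get_collate_fn(text_vocab_size=21128, max_seq_length=512)
#   batch = collate(list_of_examples)
#   batch.title_token_ids.shape        # (max_title_len_in_batch, batch_size)
#   batch.title_token_ids_mask.shape   # (batch_size, max_title_len_in_batch)
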
class KobeDataModule(pl.LightningDataModule):
    def __init__(
        self,
        test_data: str,
        vocab_path: str,
        max_seq_length: int,
        batch_size: int,
        num_workers: int,
        train_data: Optional[str] = None,
        valid_data: Optional[str] = None,
    ):
        super().__init__()
        # Train/valid paths are stored for completeness but unused while the
        # corresponding dataloaders below are commented out.
        self.train_data = train_data
        self.valid_data = valid_data
        self.test_data = test_data
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.text_vocab_size = helpers.get_bert_vocab_size(vocab_path)

    def setup(self, stage=None):
        # Only the test split is materialized in this stripped-down module.
        if stage in (None, "test"):
            self.test = from_processed(self.test_data)

    """
    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )
    """

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )

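# A minimal sketch of how this module would plug into a Lightning run (the
# trainer and model objects are assumptions, not defined in this file):
#
#   dm = KobeDataModule(test_data="saved/processed/test.tar",
#                       vocab_path="bert-base-chinese",
#                       max_seq_length=512, batch_size=32, num_workers=8)
#   trainer.test(model, datamodule=dm)
#
# Lightning then calls dm.setup("test") and dm.test_dataloader() internally,
# which is also what the __main__ block below does by hand.
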
if __name__ == "__main__":
    dm = KobeDataModule(
        train_data="saved/processed/train-*.tar",
        valid_data="saved/processed/valid.tar",
        test_data="saved/processed/test.tar",
        vocab_path="bert-base-chinese",
        max_seq_length=512,
        batch_size=32,
        num_workers=8,
    )
    dm.setup("test")
    max_len = 0

    from tqdm import tqdm

    tqdm_iter = tqdm(dm.test_dataloader())
    for batch in tqdm_iter:
        max_len = max(max_len, batch.cond_title_fact_token_ids.shape[0])
        # description_token_ids is commented out of Batched above, so only the
        # source-side length is tracked here.
        # max_len = max(max_len, batch.description_token_ids.shape[0])
        tqdm_iter.set_description(f"max len = {max_len}")