# chatlawv1/trlx/tests/test_minibatch.py
import unittest
from dataclasses import dataclass, is_dataclass
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from trlx.pipeline import MiniBatchIterator
from trlx.pipeline.offline_pipeline import (
ILQLRolloutStorage,
ILQLSeq2SeqRolloutStorage,
PromptPipeline,
)
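# Batch element type used throughout these tests: a plain dataclass of
# rollout-style tensors (queries, responses, logprobs, values, rewards).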
@dataclass
class DataclassBatch:
query_tensors: torch.Tensor
response_tensors: torch.Tensor
logprobs: torch.Tensor
values: torch.Tensor
rewards: torch.Tensor
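# Dataset of random tensors. It also inherits DataclassBatch so the full
# tensors stay accessible as attributes (test_minibatch_content reads them
# back via getattr); __getitem__ returns one DataclassBatch per sample.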
class DummyDataset(Dataset, DataclassBatch):
def __init__(self, num_samples):
self.query_tensors = torch.randn(num_samples, 64)
self.response_tensors = torch.randn(num_samples, 64)
self.logprobs = torch.randn(num_samples, 1)
self.values = torch.randn(num_samples, 1)
self.rewards = torch.randn(num_samples, 1)
def __len__(self):
return len(self.query_tensors)
def __getitem__(self, idx) -> DataclassBatch:
return DataclassBatch(
query_tensors=self.query_tensors[idx],
response_tensors=self.response_tensors[idx],
logprobs=self.logprobs[idx],
values=self.values[idx],
rewards=self.rewards[idx],
)
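# Collate a list of per-sample DataclassBatch items into a single batched
# DataclassBatch by stacking each field along a new leading dimension.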
def collate_fn(batch):
return DataclassBatch(
query_tensors=torch.stack([sample.query_tensors for sample in batch]),
response_tensors=torch.stack([sample.response_tensors for sample in batch]),
logprobs=torch.stack([sample.logprobs for sample in batch]),
values=torch.stack([sample.values for sample in batch]),
rewards=torch.stack([sample.rewards for sample in batch]),
)
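# Shared assertion helper: handles both dataclass batches (via __dict__) and
# dict-like batches, checking that the leading dimension of every field
# equals the expected minibatch size.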
class BaseTestMiniBatchIterator(unittest.TestCase):
def check_mini_batch(self, mb, expected_mini_batch_size):
if is_dataclass(mb):
mb = mb.__dict__
for key, value in mb.items():
self.assertEqual(value.size(0), expected_mini_batch_size)
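# Tests for MiniBatchIterator driven by a plain DataLoader of DataclassBatch items.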
class TestMiniBatchDL(BaseTestMiniBatchIterator):
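# Sanity check: DataclassBatch is a dataclass and holds only tensors.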
def test_batch(self):
batch = DataclassBatch(
torch.tensor([1]), torch.tensor([2]), torch.tensor([3]), torch.tensor([4]), torch.tensor([5])
)
self.assertTrue(is_dataclass(batch))
self.assertTrue(all(isinstance(v, torch.Tensor) for v in batch.__dict__.values()))
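# 32 samples, dataloader batch size 8, mb_size 4 / num_mb 2: every dataloader
# batch should split cleanly into two minibatches of 4.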
def test_minibatch_iterator(self):
# Create Dummy Dataset and DataLoader
dummy_dataset = DummyDataset(32)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
for minibatches in iterator:
for minibatch in minibatches:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 4)
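# mb_size 3 does not divide the dataloader batch size 8: each batch should
# split into minibatches of 3, 3 and 2, with only the last one short.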
def test_minibatch_iterator_with_undivisible_mbsize(self):
# Create Dummy Dataset and DataLoader
dummy_dataset = DummyDataset(32)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
iterator = MiniBatchIterator(dummy_dataloader, mb_size=3, num_mb=3)
for minibatches in iterator:
for minibatch in minibatches[:-1]:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 3)
# last minibatch has only 2 samples
minibatch = minibatches[-1]
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 2)
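# 36 samples with batch size 8 leave a final dataloader batch of 4, so the
# fifth iteration should yield only 2 of the requested 4 minibatches.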
def test_minibatch_iterator_with_remainder(self):
# Create Dummy Dataset and DataLoader
dummy_dataset = DummyDataset(36)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)
for i in range(4):
minibatches = next(iterator)
for minibatch in minibatches[:-1]:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 2)
# last iteration has only 2 minibatches
minibatches = next(iterator)
self.assertEqual(len(minibatches), 2)
for minibatch in minibatches:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 2)
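# Dataset smaller than the dataloader batch size: the single short batch of 6
# yields fewer minibatches than requested, and the iterator is then exhausted.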
def test_minibatch_iterator_with_smaller_dataset(self):
# Create Dummy Dataset and DataLoader with size smaller than batch size
dummy_dataset = DummyDataset(6)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)
minibatches = next(iterator)
for minibatch in minibatches:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
with self.assertRaises(StopIteration):
minibatches = next(iterator)
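# With shuffle=False the minibatches must reproduce the dataset tensors in
# order, so each minibatch is compared against the matching slice of the
# original data.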
def test_minibatch_content(self):
dummy_dataset = DummyDataset(32)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)
iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
idx = 0
for minibatches in iterator:
for minibatch in minibatches:
for key in minibatch.__dict__.keys():
original_data = getattr(dummy_dataset, key)
start_idx = idx * minibatch.__dict__[key].size(0)
end_idx = start_idx + minibatch.__dict__[key].size(0)
expected_data = original_data[start_idx:end_idx]
# Check if the tensor content in the minibatch is consistent with the original dataset
self.assertTrue(torch.all(torch.eq(minibatch.__dict__[key], expected_data)))
idx += 1
# Test if the iterator covered all the samples in the dataset
self.assertEqual(idx * iterator.mb_size, len(dummy_dataset))
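# Tests for MiniBatchIterator over a PromptPipeline loader, whose batches are
# dict-like (input_ids / attention_mask) rather than dataclasses.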
class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
def test_minibatch_iterator_with_prompt_pipeline(self):
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Create prompts
prompts = ["This is a test prompt."] * 32
prompt_pipeline = PromptPipeline(prompts, max_prompt_length=20, tokenizer=tokenizer)
prompt_dataloader = prompt_pipeline.create_loader(batch_size=8, shuffle=True)
iterator = MiniBatchIterator(prompt_dataloader, mb_size=4, num_mb=2)
for minibatches in iterator:
for minibatch in minibatches:
self.assertTrue("input_ids" in minibatch)
self.assertTrue("attention_mask" in minibatch)
self.assertTrue(isinstance(minibatch["input_ids"], torch.Tensor))
self.assertTrue(isinstance(minibatch["attention_mask"], torch.Tensor))
self.check_mini_batch(minibatch, 4)
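# Tests for MiniBatchIterator over ILQL rollout storages.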
class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
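# Build random rollout-like tensors shared by both ILQL tests.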
def create_dummy_tensors(self, num_samples):
input_ids = torch.randint(0, 100, (num_samples, 10))
attention_mask = torch.randint(0, 2, (num_samples, 10))
rewards = torch.randn(num_samples, 1)
states_ixs = torch.randint(0, 100, (num_samples, 1))
actions_ixs = torch.randint(0, 100, (num_samples, 1))
dones = torch.randint(0, 2, (num_samples, 1), dtype=torch.bool)
return input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
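# 32 rollouts, batch size 8, mb_size 4 / num_mb 2: every batch from the ILQL
# rollout loader should split into exactly two minibatches of 4.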
def test_minibatch_iterator_with_ilql_rollout_storage(self):
# Create dummy data
input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
# Create ILQLRolloutStorage instance
ilql_rollout_storage = ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)
ilql_dataloader = ilql_rollout_storage.create_loader(batch_size=8)
iterator = MiniBatchIterator(ilql_dataloader, mb_size=4, num_mb=2)
for minibatches in iterator:
self.assertEqual(len(minibatches), 2)
for minibatch in minibatches:
self.check_mini_batch(minibatch, expected_mini_batch_size=4)
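# Same check for the seq2seq variant, which additionally carries decoder_input_ids.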
def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
# Create dummy data
input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
decoder_input_ids = torch.randint(0, 100, (32, 10))
# Create ILQLSeq2SeqRolloutStorage instance
ilql_seq2seq_rollout_storage = ILQLSeq2SeqRolloutStorage(
input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
)
ilql_seq2seq_dataloader = ilql_seq2seq_rollout_storage.create_loader(batch_size=8)
iterator = MiniBatchIterator(ilql_seq2seq_dataloader, mb_size=4, num_mb=2)
for minibatches in iterator:
self.assertEqual(len(minibatches), 2)
for minibatch in minibatches:
self.check_mini_batch(minibatch, expected_mini_batch_size=4)
if __name__ == "__main__":
unittest.main()