# SolLlama / datamodule_finetune_sl.py
import lightning as L
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding
# class CustomLlamaDatasetAbraham(Dataset):
# def __init__(self, smiles_str, tokenizer, max_seq_length):
# # self.keys = df.iloc[:, 0] # 1D array
# # self.labels = df.iloc[:, 1:] # 2D array
# self.keys = smiles_str
# self.tokenizer = tokenizer
# self.max_seq_length = max_seq_length
# def __len__(self):
# return 1
# def fn_token_encode(self, smiles):
# return self.tokenizer(
# smiles,
# truncation=True,
# padding="max_length",
# max_length=self.max_seq_length,
# )
# def __getitem__(self, idx):
# local_encoded = self.fn_token_encode(self.keys)
# return {
# "input_ids": torch.tensor(local_encoded["input_ids"]),
# "attention_mask": torch.tensor(local_encoded["attention_mask"]),
# # "labels": torch.tensor(self.labels.iloc[idx]),
# "labels": None,
# }
class CustomLlamaDatasetAbraham(Dataset):
def __init__(self, df, tokenizer, max_seq_length):
        self.keys = df.iloc[:, 0]  # 1D Series of SMILES strings
# self.labels = df.iloc[:, 1:] # 2D array
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
def __len__(self):
return self.keys.shape[0]
def fn_token_encode(self, smiles):
return self.tokenizer(
smiles,
truncation=True,
padding="max_length",
max_length=self.max_seq_length,
)
def __getitem__(self, idx):
local_encoded = self.fn_token_encode(self.keys.iloc[idx])
return {
"input_ids": torch.tensor(local_encoded["input_ids"]),
"attention_mask": torch.tensor(local_encoded["attention_mask"]),
"labels": torch.tensor(local_encoded["input_ids"]), # this one does not matter for sl
}
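# A minimal usage sketch (not part of the original module), assuming a Llama-style
# tokenizer has already been loaded elsewhere:
#   df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]})
#   dataset = CustomLlamaDatasetAbraham(df, tokenizer, max_seq_length=128)
#   item = dataset[0]  # dict with input_ids / attention_mask / labels tensors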
class CustomFinetuneDataModule(L.LightningDataModule):
def __init__(
self,
solute_or_solvent,
tokenizer,
max_seq_length,
batch_size_train,
batch_size_valid,
num_device,
):
super().__init__()
self.solute_or_solvent = solute_or_solvent
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.batch_size_train = batch_size_train
self.batch_size_valid = batch_size_valid
self.data_collator = DataCollatorWithPadding(self.tokenizer)
        self.num_device = num_device  # also used as num_workers for the DataLoaders
    def prepare_data(self):
        # self.list_df = load_abraham(self.solute_or_solvent)
        # Read a single SMILES string from disk; strip the trailing newline so it
        # is not tokenized as part of the sequence.
        with open('./smiles_str.txt', 'r') as file:
            smiles_str = file.readline().strip()
        self.smiles_str = pd.DataFrame([smiles_str])
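        # Assumed input format for ./smiles_str.txt: a single line holding one
        # SMILES string, e.g. "CC(=O)Oc1ccccc1C(=O)O".
        # Note: Lightning calls prepare_data on a single process, so attributes set
        # here are not broadcast across devices; setup() is the usual place for
        # assignments in multi-device runs.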
    def setup(self, stage=None):
        # self.train_df, self.valid_df, self.test_df = self.list_df
        # Inference-only setup: there are no train/validation splits, so only the
        # test/predict dataloaders can be used with this module.
        self.train_df = None
        self.valid_df = None
        self.test_df = self.smiles_str
def train_dataloader(self):
return DataLoader(
dataset=CustomLlamaDatasetAbraham(
self.train_df, self.tokenizer, self.max_seq_length,
),
batch_size=self.batch_size_train,
num_workers=self.num_device,
collate_fn=self.data_collator,
shuffle=True,
)
def val_dataloader(self):
return DataLoader(
dataset=CustomLlamaDatasetAbraham(
self.valid_df, self.tokenizer, self.max_seq_length,
),
batch_size=self.batch_size_valid,
num_workers=self.num_device,
collate_fn=self.data_collator,
shuffle=False,
)
def test_dataloader(self):
return DataLoader(
dataset=CustomLlamaDatasetAbraham(
self.test_df, self.tokenizer, self.max_seq_length,
),
batch_size=self.batch_size_valid,
num_workers=self.num_device,
collate_fn=self.data_collator,
shuffle=False,
)
    # Prediction reuses test_df.
def predict_dataloader(self):
return DataLoader(
dataset=CustomLlamaDatasetAbraham(
self.test_df, self.tokenizer, self.max_seq_length,
),
batch_size=self.batch_size_valid,
num_workers=self.num_device,
collate_fn=self.data_collator,
shuffle=False,
)
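# A minimal sketch of running prediction with this DataModule. The tokenizer name
# and the finetuned LightningModule (`model`) are assumptions, not part of this file:
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/solllama-tokenizer")  # hypothetical path
#   dm = CustomFinetuneDataModule(
#       solute_or_solvent="solute",
#       tokenizer=tokenizer,
#       max_seq_length=128,
#       batch_size_train=16,
#       batch_size_valid=16,
#       num_device=1,
#   )
#   trainer = L.Trainer(accelerator="auto", devices=1, logger=False)
#   predictions = trainer.predict(model, datamodule=dm)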