import argparse
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from src.dataset import TokenizerDataset
from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
import pandas as pd
# class CustomBERTModel(nn.Module):
# def __init__(self, vocab_size, output_dim, pre_trained_model_path):
# super(CustomBERTModel, self).__init__()
# hidden_size = 768
# self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=12, attn_heads=12, dropout=0.1)
# checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
# if isinstance(checkpoint, dict):
# self.bert.load_state_dict(checkpoint)
# elif isinstance(checkpoint, BERT):
# self.bert = checkpoint
# else:
# raise TypeError(f"Expected state_dict or BERT instance, got {type(checkpoint)} instead.")
# self.fc = nn.Linear(hidden_size, output_dim)
# def forward(self, sequence, segment_info):
# sequence = sequence.to(next(self.parameters()).device)
# segment_info = segment_info.to(sequence.device)
# if sequence.size(0) == 0 or sequence.size(1) == 0:
# raise ValueError("Input sequence tensor has 0 elements. Check data preprocessing.")
# x = self.bert(sequence, segment_info)
# print(f"BERT output shape: {x.shape}")
# if x.size(0) == 0 or x.size(1) == 0:
# raise ValueError("BERT output tensor has 0 elements. Check input dimensions.")
# cls_embeddings = x[:, 0]
# logits = self.fc(cls_embeddings)
# return logits
# class CustomBERTModel(nn.Module):
# def __init__(self, vocab_size, output_dim, pre_trained_model_path):
# super(CustomBERTModel, self).__init__()
# hidden_size = 764 # Ensure this is 768
# self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=12, attn_heads=12, dropout=0.1)
# # Load the pre-trained model's state_dict
# checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
# if isinstance(checkpoint, dict):
# self.bert.load_state_dict(checkpoint)
# else:
# raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")
# # Fully connected layer with input size 768
# self.fc = nn.Linear(hidden_size, output_dim)
# def forward(self, sequence, segment_info):
# sequence = sequence.to(next(self.parameters()).device)
# segment_info = segment_info.to(sequence.device)
# x = self.bert(sequence, segment_info)
# print(f"BERT output shape: {x.shape}") # Should output (batch_size, seq_len, 768)
# cls_embeddings = x[:, 0] # Extract CLS token embeddings
# print(f"CLS Embeddings shape: {cls_embeddings.shape}") # Should output (batch_size, 768)
# logits = self.fc(cls_embeddings) # Should now pass a tensor of size (batch_size, 768) to `fc`
# return logits
# for test
class CustomBERTModel(nn.Module):
    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super(CustomBERTModel, self).__init__()
        self.hidden = 764  # Must match the hidden size of the pre-trained checkpoint
        self.bert = BERT(vocab_size=vocab_size, hidden=self.hidden, n_layers=12, attn_heads=12, dropout=0.1)
        # Load the pre-trained model's state_dict
        checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            self.bert.load_state_dict(checkpoint)
        else:
            raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")
        # Classification head applied to the CLS token embedding
        self.fc = nn.Linear(self.hidden, output_dim)

    def forward(self, sequence, segment_info):
        x = self.bert(sequence, segment_info)   # (batch_size, seq_len, hidden)
        cls_embeddings = x[:, 0]                # CLS token embedding: (batch_size, hidden)
        logits = self.fc(cls_embeddings)        # (batch_size, output_dim)
        return logits
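
# Illustrative sketch (not used elsewhere in this script): shows how the CLS-token
# classification head above consumes an encoder output of assumed shape
# (batch_size, seq_len, hidden). Tensors here are dummies, not real data.
def _example_cls_head(batch_size=4, seq_len=128, hidden=764, output_dim=2):
    encoder_out = torch.randn(batch_size, seq_len, hidden)  # stand-in for self.bert(...)
    fc = nn.Linear(hidden, output_dim)
    cls_embeddings = encoder_out[:, 0]   # (batch_size, hidden)
    logits = fc(cls_embeddings)          # (batch_size, output_dim)
    return logits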
def preprocess_labels(label_csv_path):
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading dataset file: {e}")
        return None
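
# Hypothetical sketch of the label extraction above, applied to an in-memory frame
# with the 'last_hint_class' column this function expects (column name taken from the code).
def _example_labels():
    df = pd.DataFrame({'last_hint_class': [0, 1, 1, 0]})
    return torch.tensor(df['last_hint_class'].values.astype(int), dtype=torch.long)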
def preprocess_data(data_path, vocab, max_length=128):
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    if len(sequences) == 0:
        raise ValueError(f"No sequences found in data file {data_path}. Check the file content.")

    tokenized_sequences = []
    for sequence in sequences:
        sequence = sequence.strip()
        if sequence:
            encoded = vocab.to_seq(sequence, seq_len=max_length)
            encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
            segment_label = [0] * max_length
            tokenized_sequences.append({
                'input_ids': torch.tensor(encoded),
                'segment_label': torch.tensor(segment_label)
            })

    if not tokenized_sequences:
        raise ValueError("Tokenization resulted in an empty list. Check the sequences and tokenization logic.")

    tokenized_sequences = [t for t in tokenized_sequences if len(t['input_ids']) == max_length]
    if not tokenized_sequences:
        raise ValueError("All tokenized sequences are of unexpected length. This suggests an issue with the tokenization logic.")

    input_ids = torch.cat([t['input_ids'].unsqueeze(0) for t in tokenized_sequences], dim=0)
    segment_labels = torch.cat([t['segment_label'].unsqueeze(0) for t in tokenized_sequences], dim=0)

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")
    return input_ids, segment_labels
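
# A small, self-contained illustration of the truncate-then-pad rule used in
# preprocess_data, assuming a [PAD] id of 0: every sequence comes out exactly max_length long.
def _example_pad_truncate(encoded, max_length=128, pad_id=0):
    return encoded[:max_length] + [pad_id] * (max_length - len(encoded))
# e.g. _example_pad_truncate([5, 6, 7], max_length=5) -> [5, 6, 7, 0, 0]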
def collate_fn(batch):
    inputs = []
    labels = []
    segment_labels = []

    for item in batch:
        if item is None:
            continue
        if isinstance(item, dict):
            inputs.append(item['input_ids'].unsqueeze(0))
            labels.append(item['label'].unsqueeze(0))
            segment_labels.append(item['segment_label'].unsqueeze(0))
        elif isinstance(item, (tuple, list)) and len(item) == 3:
            # TensorDataset (as built in main) yields (input_ids, segment_label, label) tuples
            input_ids, segment_label, label = item
            inputs.append(input_ids.unsqueeze(0))
            labels.append(label.unsqueeze(0))
            segment_labels.append(segment_label.unsqueeze(0))

    if len(inputs) == 0 or len(segment_labels) == 0:
        print("Empty batch encountered. Returning None to skip this batch.")
        return None

    try:
        inputs = torch.cat(inputs, dim=0)
        labels = torch.cat(labels, dim=0)
        segment_labels = torch.cat(segment_labels, dim=0)
    except Exception as e:
        print(f"Error concatenating tensors: {e}")
        return None

    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels
    }
def custom_collate_fn(batch):
    processed_batch = collate_fn(batch)
    if processed_batch is None or len(processed_batch['input']) == 0:
        # Return a valid batch with at least one element instead of an empty one
        return {
            'input': torch.zeros((1, 128), dtype=torch.long),
            'label': torch.zeros((1,), dtype=torch.long),
            'segment_label': torch.zeros((1, 128), dtype=torch.long)
        }
    return processed_batch
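
# Hypothetical usage sketch: custom_collate_fn exercised directly on one dict-style
# item to show the shape of the batch it returns (dummy tensors, not real data).
def _example_collate():
    item = {
        'input_ids': torch.zeros(128, dtype=torch.long),
        'segment_label': torch.zeros(128, dtype=torch.long),
        'label': torch.tensor(1, dtype=torch.long)
    }
    batch = custom_collate_fn([item])
    # batch['input']: (1, 128), batch['label']: (1,), batch['segment_label']: (1, 128)
    return batch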
def train_without_progress_status(trainer, epoch, shuffle):
    for epoch_idx in range(epoch):
        print(f"EP_train:{epoch_idx}:")
        for batch in trainer.train_data:
            if batch is None:
                continue
            # Check if batch is a string (indicating an issue)
            if isinstance(batch, str):
                print(f"Error: Received a string instead of a dictionary in batch: {batch}")
                raise ValueError(f"Unexpected string in batch: {batch}")
            # Validate the batch structure before passing it to the trainer
            if isinstance(batch, dict):
                # The keys below are the ones produced by custom_collate_fn
                if all(key in batch for key in ['input', 'segment_label', 'label']):
                    if all(isinstance(batch[key], torch.Tensor) for key in batch):
                        try:
                            print(f"Batch Structure: {batch}")  # Debugging batch before iteration
                            trainer.iteration(epoch_idx, batch)
                        except Exception as e:
                            print(f"Error during batch processing: {e}")
                            sys.stdout.flush()
                            raise e  # Propagate the exception for better debugging
                    else:
                        print(f"Error: Expected all values in batch to be tensors, but got: {batch}")
                        raise ValueError("Batch contains non-tensor values.")
                else:
                    print(f"Error: Batch missing expected keys. Batch keys: {batch.keys()}")
                    raise ValueError("Batch does not contain expected keys.")
            else:
                print(f"Error: Expected batch to be a dictionary but got {type(batch)} instead.")
                raise ValueError(f"Invalid batch structure: {batch}")
# def main(opt):
# # device = torch.device("cpu")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# vocab = Vocab(opt.vocab_file)
# vocab.load_vocab()
# input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
# labels = preprocess_labels(opt.dataset)
# if input_ids is None or segment_labels is None or labels is None:
# print("Error in preprocessing data. Exiting.")
# return
# dataset = TensorDataset(input_ids, segment_labels, torch.tensor(labels, dtype=torch.long))
# val_size = len(dataset) - int(0.8 * len(dataset))
# val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])
# train_dataloader = DataLoader(
# train_dataset,
# batch_size=32,
# shuffle=True,
# collate_fn=custom_collate_fn
# )
# val_dataloader = DataLoader(
# val_dataset,
# batch_size=32,
# shuffle=False,
# collate_fn=custom_collate_fn
# )
# custom_model = CustomBERTModel(
# vocab_size=len(vocab.vocab),
# output_dim=2,
# pre_trained_model_path=opt.pre_trained_model_path
# ).to(device)
# trainer = BERTFineTuneTrainer1(
# bert=custom_model.bert,
# vocab_size=len(vocab.vocab),
# train_dataloader=train_dataloader,
# test_dataloader=val_dataloader,
# lr=5e-5,
# num_labels=2,
# with_cuda=torch.cuda.is_available(),
# log_freq=10,
# workspace_name=opt.output_dir,
# log_folder_path=opt.log_folder_path
# )
# trainer.train(epoch=20)
# # os.makedirs(opt.output_dir, exist_ok=True)
# # output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model.pth')
# # torch.save(custom_model.state_dict(), output_model_file)
# # print(f'Model saved to {output_model_file}')
# os.makedirs(opt.output_dir, exist_ok=True)
# output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
# torch.save(custom_model, output_model_file)
# print(f'Model saved to {output_model_file}')
def main(opt):
    # Set device to GPU if available, otherwise use CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(torch.cuda.is_available())  # True if a GPU is available
    print(torch.cuda.device_count())

    # Load vocabulary
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess data and labels
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
    labels = preprocess_labels(opt.dataset)

    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # Transfer tensors to the correct device (GPU/CPU)
    input_ids = input_ids.to(device)
    segment_labels = segment_labels.to(device)
    labels = labels.to(device)  # preprocess_labels already returns a LongTensor

    # Create TensorDataset and split into train and validation sets (80/20)
    dataset = TensorDataset(input_ids, segment_labels, labels)
    val_size = len(dataset) - int(0.8 * len(dataset))
    val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])

    # Create DataLoaders for training and validation
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=custom_collate_fn
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    # Initialize custom BERT model and move it to the device
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer
    trainer = BERTFineTuneTrainer1(
        bert=custom_model.bert,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=5e-5,
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train the model
    trainer.train(epoch=20)

    # Save the full model to the specified output directory
    # (to save only the weights, use torch.save(custom_model.state_dict(), ...) instead)
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct_logs', help='Path to the folder for saving logs.')

    opt = parser.parse_args()
    main(opt)
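
# Example invocation (the script name and paths below are illustrative placeholders):
# python finetune_hint_classifier.py \
#     --dataset /path/to/er_train.csv \
#     --data_path /path/to/er.txt \
#     --vocab_file /path/to/vocab.txt \
#     --pre_trained_model_path /path/to/bert_trained.seq_encoder.model.ep68 \
#     --output_dir /path/to/output/hint_classification \
#     --log_folder_path /path/to/logs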