In [None]:
!pip install torch transformers scikit-learn wandb accelerate tqdm
from IPython.display import clear_output
clear_output(wait=True)
print(".")

In [None]:
!apt-get update
!apt-get install zstd
!tar --use-compress-program=unzstd -xvf bert_streamed_dataset.tar.zst
clear_output(wait=True)
print(".")

In [None]:
import json

import torch
import wandb
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    DistilBertTokenizerFast,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Seed python/numpy/torch RNGs *before* building the model: from_pretrained randomly
# initialises the new classification head, and Trainer only seeds later, in its own
# constructor. 42 matches the split's random_state and Trainer's default seed.
set_seed(42)

# Initialize W&B
wandb.init(project="distilbert-ai-text-classification")

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Load the pre-trained DistilBERT tokenizer and a 2-class sequence-classification model.
# The fast (Rust-backed) tokenizer uses the same vocabulary and WordPiece scheme as the
# slow Python one, but tokenizes the whole corpus in a single call far more quickly.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
# Trailing ';' suppresses IPython's display of the (very long) model repr.
model.to(device);

In [None]:
# Load the JSONL dataset: one JSON object per line, each carrying 'text' and 'label'.
DATASET_PATH = 'bert_reddit_vs_synth_writing_prompts.jsonl'

# First pass only counts lines so the loading progress bar below can show a total.
with open(DATASET_PATH, 'r', encoding='utf-8') as infile:
    total_num_of_lines = sum(1 for _ in tqdm(infile, desc="Checking dataset size"))

# JSON Lines is UTF-8, so decode explicitly instead of relying on the platform default.
# Blank lines (e.g. a trailing empty line) are skipped rather than crashing json.loads.
with open(DATASET_PATH, 'r', encoding='utf-8') as infile:
    data = [
        json.loads(line)
        for line in tqdm(infile, desc="Loading dataset", total=total_num_of_lines)
        if line.strip()
    ]

# Extract texts and labels
# NOTE(review): labels are assumed to be 0/1 class ids (num_labels=2) — confirm against the data.
print("Extracting texts and labels")
texts = [entry['text'] for entry in data]
labels = [entry['label'] for entry in data]

# Tokenize the whole corpus in one call.
# NOTE(review): padding=True pads every example to the longest sequence in the corpus
# (capped by truncation=True at the model's max length); per-batch dynamic padding via
# DataCollatorWithPadding would save memory and compute.
print("Tokenizing text")
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# The tensors deliberately stay on the CPU: Trainer moves each batch to the training
# device itself, and keeping the full corpus off the GPU avoids exhausting its memory.

# Split ids, attention masks and labels in ONE call: a single shared shuffle keeps the
# three row-aligned by construction (same split the previous two seeded calls produced).
print("Splitting data into train and validation")
(train_inputs, val_inputs,
 train_attention_masks, val_attention_masks,
 train_labels, val_labels) = train_test_split(
    inputs['input_ids'], inputs['attention_mask'], labels,
    test_size=0.2, random_state=42)

# Create a PyTorch dataset
class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized, row-aligned model inputs.

    Args:
        input_ids: Indexable rows of token ids (in this notebook, a slice of the
            tokenizer's 2-D ``input_ids`` tensor).
        attention_masks: Indexable rows of attention masks aligned with ``input_ids``.
        labels: Sequence of class labels aligned with ``input_ids``; each value must be
            convertible to an integer class index.

    Raises:
        ValueError: if the three inputs do not have the same number of rows.

    Each item is a dict with the ``input_ids`` / ``attention_mask`` / ``labels`` keyword
    arguments that the sequence-classification model's forward pass accepts.
    """

    def __init__(self, input_ids, attention_masks, labels):
        # Misaligned inputs would silently pair texts with the wrong labels; fail early.
        if not len(input_ids) == len(attention_masks) == len(labels):
            raise ValueError(
                f"Length mismatch: {len(input_ids)} input_ids, "
                f"{len(attention_masks)} attention_masks, {len(labels)} labels"
            )
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_masks[idx],
            # Force int64: with num_labels > 1 the model picks single-label cross-entropy
            # only for integer label tensors; bool/float labels would be treated as a
            # multi-label target and break the loss.
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
        }

# Wrap the aligned train / validation splits in Dataset objects for the Trainer.
print("Creating pytorch datasets")
train_dataset = TextDataset(
    input_ids=train_inputs,
    attention_masks=train_attention_masks,
    labels=train_labels,
)
val_dataset = TextDataset(
    input_ids=val_inputs,
    attention_masks=val_attention_masks,
    labels=val_labels,
)

In [None]:
# Evaluate on only the first NUM_OF_EVAL_EXAMPLES validation rows so the periodic
# evaluation during training stays cheap. train_test_split shuffles by default, so this
# prefix is a random sample of the validation split.
# Set NUM_OF_EVAL_EXAMPLES = None to evaluate on the full validation split
# (slicing with [:None] keeps every row).
NUM_OF_EVAL_EXAMPLES = 1000
val_inputs_subset = val_inputs[:NUM_OF_EVAL_EXAMPLES]
val_attention_masks_subset = val_attention_masks[:NUM_OF_EVAL_EXAMPLES]
val_labels_subset = val_labels[:NUM_OF_EVAL_EXAMPLES]

# Replace the validation dataset with the reduced one (class name was misspelled
# as `Textdataset`, which raised NameError).
val_dataset = TextDataset(val_inputs_subset, val_attention_masks_subset, val_labels_subset)

In [None]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./distil-bert-train-results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,  # number of warmup steps for learning rate scheduler
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,  # log training metrics every 10 steps
    report_to="wandb",
    # `evaluation_strategy` was renamed `eval_strategy` (deprecated in transformers 4.41
    # and since removed); the unpinned install pulls a current release.
    eval_strategy="steps",  # evaluate on a fixed step interval...
    eval_steps=100,  # ...every 100 training steps
    # fp16 mixed precision needs a GPU; requesting it on the CPU fallback makes
    # TrainingArguments raise, so enable it only when CUDA is present.
    fp16=torch.cuda.is_available(),
)

# Create the Trainer
trainer = Trainer(
    model=model,  # the instantiated 🤗 Transformers model to be trained
    args=training_args,  # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset,  # evaluation dataset
)

# Train the model
trainer.train()

# Save the model together with its tokenizer so the output directory can be reloaded
# on its own with from_pretrained().
FINAL_MODEL_DIR = './distil-bert-train-final-result'
model.save_pretrained(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

# Finish the W&B run
wandb.finish()