# First, we grab tools from our toolbox. These tools help us with different tasks like reading books (datasets),
# learning new languages (tokenization), and solving puzzles (models).
from datasets import load_dataset  # This tool helps us get our book, where the puzzles are.
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler  # These help us understand and solve puzzles.
from torch.optim import AdamW  # The coach that nudges our robot's learning (transformers' own AdamW is deprecated, so we take PyTorch's).
from transformers import DataCollatorWithPadding  # This makes sure all puzzle pieces are the same size.
from torch.utils.data import DataLoader  # This helps us handle one page of puzzles at a time.
import torch  # This is like the brain of our operations, helping us think through puzzles.
from tqdm.auto import tqdm  # This is our progress bar, showing us how far we've come in solving the book.
import evaluate  # This tells us how well we did in solving puzzles.
from accelerate import Accelerator  # This makes everything go super fast, like a rocket!

def train_and_save_model():
    # Now, let's pick up the book we're going to solve today.
    raw_datasets = load_dataset("glue", "mrpc")  # This is a book filled with puzzles about matching sentences.

    # Before we start solving puzzles, we need to understand the language they're written in.
    checkpoint = "bert-base-uncased"  # This is a guidebook to help us understand the puzzles' language.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # This tool helps us read and understand the language in our book.

    # To solve puzzles, we need to make sure we understand each sentence properly.
    def tokenize_function(example):  # This is like reading each sentence carefully and understanding each word.
        return tokenizer(example["sentence1"], example["sentence2"], truncation=True)
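    # A hedged peek at what that returns (illustrative values, not run here):
    #   tokenize_function({"sentence1": "The cat sat.", "sentence2": "A cat was sitting."})
    #   -> {"input_ids": [...], "token_type_ids": [...], "attention_mask": [...]}
    # Both sentences get glued into one sequence; token_type_ids remembers which
    # words came from which sentence, so BERT can compare them.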

    # We prepare all puzzles in the book so they're ready to solve.
    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)  # This is like marking all the important parts of the sentences.

    # Puzzles can be different sizes, but our puzzle solver works best when all puzzles are the same size.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # This adds extra paper to smaller puzzles to make them all the same size.
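    # A hedged sketch of the collator at work (hypothetical features, not run here):
    #   batch = data_collator([short_puzzle, long_puzzle])
    #   batch["input_ids"].shape -> torch.Size([2, longest_length])
    # The shorter puzzle is padded with the tokenizer's pad token, and its
    # attention_mask is 0 over the padding so the model ignores the extra paper.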

    # We're setting up our puzzle pages, making sure we're ready to solve them one by one.
    tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])  # We remove stuff we don't need.
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")  # We make sure the puzzle answers are labeled correctly.
    tokenized_datasets.set_format("torch")  # We make sure our puzzles are in the right format for our brain to understand.
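    # After the cleanup, only tensor-friendly columns remain (a hedged peek;
    # the exact ordering may differ):
    #   tokenized_datasets["train"].column_names
    #   -> ["labels", "input_ids", "token_type_ids", "attention_mask"]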

    # Now, we're ready to start solving puzzles, one page at a time.
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
    )  # This is our training puzzles.
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
    )  # These are puzzles we use to check our progress.
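    # A quick sanity check you could run here (hedged: the 65 depends on the
    # longest puzzle that happened to land in the batch):
    #   batch = next(iter(train_dataloader))
    #   {k: v.shape for k, v in batch.items()}
    #   -> {"labels": torch.Size([8]), "input_ids": torch.Size([8, 65]), ...}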

    # We need a puzzle solver, which is specially trained to solve these types of puzzles.
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)  # This is our puzzle-solving robot.
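    # Loading this prints a warning that some weights were newly initialized.
    # That's expected: the BERT body comes pretrained, but the 2-label answer
    # head bolted on top starts out random and only learns during training below.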

    # Our robot needs instructions on how to get better at solving puzzles.
    optimizer = AdamW(model.parameters(), lr=5e-5)  # This tells our robot how to improve.
    num_epochs = 3  # This is how many times we'll go through the whole book of puzzles.
    num_training_steps = num_epochs * len(train_dataloader)  # This is the total number of puzzles we'll solve.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )  # This adjusts how quickly our robot learns over time.
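    # With "linear" decay and zero warmup, the learning rate slides in a straight
    # line from 5e-5 down to 0 across num_training_steps. You could watch it
    # (a hedged illustration) with lr_scheduler.get_last_lr() after each step:
    #   step 0 -> [5e-05], halfway -> [~2.5e-05], last step -> [~0.0]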

    # To solve puzzles super fast, we're going to use a rocket!
    accelerator = Accelerator()  # This is our rocket that makes everything go faster.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )  # We make sure our robot, our puzzles, and our instructions are all ready for the rocket.
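    # Because prepare() already parked everything on the right device, there are
    # no batch.to(device) calls in the loops below; the rocket handles placement.
    # To use several GPUs, you'd launch this same file with the accelerate CLI
    # (the script name here is hypothetical): accelerate launch this_script.py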

    # It's time to start solving puzzles!
    progress_bar = tqdm(range(num_training_steps))  # This shows us our progress.
    model.train()  # We tell our robot it's time to start learning.
    for epoch in range(num_epochs):  # We go through our book of puzzles multiple times to get really good.
        for batch in train_dataloader:  # Each time, we take a page of puzzles to solve.
            outputs = model(**batch)  # Our robot tries to solve the puzzles.
            loss = outputs.loss  # We check how many mistakes it made.
            accelerator.backward(loss)  # We give feedback to our robot so it can learn from its mistakes.
            optimizer.step()  # We update our robot's puzzle-solving strategy.
            lr_scheduler.step()  # We adjust how quickly our robot is learning.
            optimizer.zero_grad()  # We reset some settings to make sure our robot is ready for the next page.
            progress_bar.update(1)  # We update our progress bar to show how many puzzles we've solved.

    # After all that practice, it's time to test how good our robot has become at solving puzzles.
    metric = evaluate.load("glue", "mrpc")  # This is like the answer key to check our robot's work.
    model.eval()  # We tell our robot it's time to show what it's learned.
    for batch in eval_dataloader:  # We take a page of puzzles we haven't solved yet.
        with torch.no_grad():  # We make sure we're just testing, not learning anymore.
            outputs = model(**batch)  # Our robot solves the puzzles.
        logits = outputs.logits  # We look at our robot's answers.
        predictions = torch.argmax(logits, dim=-1)  # We decide which answer our robot thinks is right.
        predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))  # On several GPUs this collects everyone's answers first; on one GPU it changes nothing.
        metric.add_batch(predictions=predictions, references=references)  # We compare our robot's answers to the correct answers.

    final_score = metric.compute()  # We calculate how well our robot did.
    print(final_score)  # We print out the score to see how well our robot solved the puzzles!
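    # For glue/mrpc the score sheet has two numbers, roughly like this
    # (hedged example; your exact values will differ run to run):
    #   {"accuracy": 0.85, "f1": 0.89}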

    accelerator.wait_for_everyone()  # Everyone finishes their puzzles before we save.
    accelerator.unwrap_model(model).save_pretrained("path/to/save/model")  # Take the robot out of the rocket first, so the saved copy loads anywhere.
    tokenizer.save_pretrained("path/to/save/tokenizer")

if __name__ == "__main__":
    train_and_save_model()