Something2109 committed on
Commit 25d7e4e • 1 Parent(s): a96d1b8

Move the py files to a directory

tokenizer.py → src/tokenizer.py RENAMED
File without changes
BERT.py → src/train.py RENAMED
File without changes
train.py DELETED
@@ -1,106 +0,0 @@
- from transformers import (
-     AutoTokenizer,
-     AutoModel,
-     BertModel,
-     GPT2Model,
-     EncoderDecoderModel,
-     DataCollatorForSeq2Seq,
-     Seq2SeqTrainer,
-     Seq2SeqTrainingArguments,
- )
- from datasets import load_dataset
- from laonlp import word_tokenize
- from functools import partial
- import random
-
-
- def group_texts(tokenizer, examples):
-     tokenized_inputs = [" ".join(word_tokenize(x)) for x in examples["text"]]
-
-     tokenized_inputs = tokenizer(
-         examples["text"],
-         # return_special_tokens_mask=True,
-         # padding="max_length",
-         # truncation=True,
-         # max_length=tokenizer.model_max_length,
-         # return_tensors="pt",
-     )
-
-     return tokenized_inputs
-
-
- if __name__ == "__main__":
-     encoder_src = "BERT\\models\\bert-culturaX-data"
-     decoder_src = "NlpHUST/gpt2-vietnamese"
-
-     encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_src)
-     decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_src)
-     decoder_tokenizer.model_max_length = encoder_tokenizer.model_max_length
-     decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
-     print(f"The max length for the tokenizer is: {encoder_tokenizer.model_max_length}")
-
-     encoder = AutoModel.from_pretrained(encoder_src)
-     decoder = AutoModel.from_pretrained(decoder_src)
-     decoder.config.max_length = decoder_tokenizer.model_max_length
-
-     model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
-     model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
-     model.config.pad_token_id = decoder_tokenizer.pad_token_id
-     model.config.vocab_size = decoder_tokenizer.vocab_size
-
-     data_collator = DataCollatorForSeq2Seq(decoder_tokenizer, model=model)
-
-     raw_lo_dataset = load_dataset("bert/dataset/original/lo")
-     raw_vi_dataset = load_dataset("bert/dataset/original/vi")
-
-     train_dataset = raw_lo_dataset["train"].map(
-         partial(group_texts, encoder_tokenizer),
-         remove_columns=["text"],
-         batched=True,
-         num_proc=12,
-     )
-     eval_dataset = raw_lo_dataset["validation"].map(
-         partial(group_texts, encoder_tokenizer),
-         batched=True,
-         remove_columns=["text"],
-     )
-     train_labels = raw_vi_dataset["train"].map(
-         partial(group_texts, decoder_tokenizer),
-         remove_columns=["text"],
-         batched=True,
-         num_proc=12,
-     )
-     eval_labels = raw_vi_dataset["validation"].map(
-         partial(group_texts, decoder_tokenizer),
-         batched=True,
-         remove_columns=["text"],
-     )
-     train_dataset = train_dataset.add_column("labels", train_labels["input_ids"])
-     eval_dataset = eval_dataset.add_column("labels", eval_labels["input_ids"])
-
-     print(
-         f"the dataset contains in total {len(train_dataset)*encoder_tokenizer.model_max_length} tokens"
-     )
-
-     model_name = "transformer-bert-gpt"
-
-     training_args = Seq2SeqTrainingArguments(
-         output_dir=f"bert/models/{model_name}",
-         evaluation_strategy="epoch",
-         per_device_train_batch_size=16,
-         per_device_eval_batch_size=16,
-         weight_decay=0.01,
-         save_total_limit=3,
-         num_train_epochs=2,
-         push_to_hub=True,
-     )
-
-     trainer = Seq2SeqTrainer(
-         model=model,
-         args=training_args,
-         data_collator=data_collator,
-         train_dataset=train_dataset,
-         eval_dataset=eval_dataset,
-     )
-
-     trainer.train()