import os

import torch
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments


def main():
    """Fine-tune ``FacebookAI/roberta-base`` as a causal LM on a JSONL corpus.

    Intended to be started by a distributed launcher (e.g. ``torchrun``) that
    exports ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` for every process.
    Each process joins an NCCL process group, tokenizes the ``text`` field of
    ``../../data/m2_250514_1150.jsonl``, holds out 10% of the rows for
    evaluation, trains for three epochs with the Hugging Face ``Trainer`` and
    finally writes the model and tokenizer to ``./fine_tuned_model`` (global
    rank 0 only). The process group is destroyed even if training fails.

    Raises:
        KeyError: If one of the launcher environment variables is missing.
    """
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # NCCL collectives run on the current CUDA device, so bind this process
    # to its own GPU before the process group is created.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group("nccl")
    print(f"Local Rank = {local_rank}/{world_size}")

    try:
        model_name = "FacebookAI/roberta-base"
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # RoBERTa is an encoder: unless the config is flagged as a decoder,
        # self-attention is bidirectional and the next-token objective can
        # simply read the answer from the following positions.
        config = AutoConfig.from_pretrained(model_name)
        config.is_decoder = True
        model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        training_args = TrainingArguments(
            output_dir="./results",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            dataloader_num_workers=8,
            # Without an explicit strategy the default is "no" and
            # ``eval_steps`` is silently ignored.
            eval_strategy="steps",
            eval_steps=500,
            save_steps=1000,
            warmup_steps=500,
            prediction_loss_only=True,
            logging_dir="./logs",
            logging_steps=100,
            learning_rate=5e-5,
            fp16=True,
        )

        def tokenize_function(examples):
            """Tokenize a batch of rows, truncating each text to 512 tokens."""
            return tokenizer(examples["text"], truncation=True, max_length=512)

        # Let the (local) main process build the dataset cache first; the
        # other ranks then reuse it instead of tokenizing the corpus again.
        with training_args.main_process_first(desc="dataset load and tokenization"):
            dataset = load_dataset(
                'json',
                data_files='../../data/m2_250514_1150.jsonl',
                split='train',
            )
            tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # Fixed seed: every rank must derive the *same* train/test partition,
        # otherwise evaluation rows leak into training on other ranks.
        split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)

        # mlm=False makes the collator copy input_ids into labels (causal LM).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=False
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=split_dataset["train"],
            eval_dataset=split_dataset["test"],
            data_collator=data_collator,
        )

        trainer.train()

        # Only one process may write the final artifacts; concurrent writes
        # from every rank to the same directory can corrupt the files.
        if rank == 0:
            model.save_pretrained("./fine_tuned_model")
            tokenizer.save_pretrained("./fine_tuned_model")
    finally:
        # Always release NCCL resources, including on failure.
        torch.distributed.destroy_process_group()
# Script entry point: run the training job only when this file is executed
# directly (e.g. by a distributed launcher), never on import.
if __name__ == "__main__":
    main()