|
|
|
|
|
from dataclasses import dataclass, field |
|
import os |
|
from pathlib import Path |
|
import platform |
|
import re |
|
from typing import Dict, List, Optional, Union |
|
|
|
# Resolve the project root. On Windows it comes from the local
# ``project_settings`` module; everywhere else the current working directory
# is used.
if platform.system() != "Windows":
    project_path = os.path.abspath("./")
else:
    from project_settings import project_path
# Normalise to a ``Path`` so the ``/`` joins below work for either source.
project_path = Path(project_path)

# Keep the Hugging Face hub cache inside the project tree. The variable is
# exported ahead of the Hugging Face imports that follow in this file,
# presumably so those libraries pick it up at import time.
hf_hub_cache = project_path.joinpath("cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
|
|
|
from datasets import concatenate_datasets, load_dataset |
|
import huggingface_hub |
|
import torch |
|
import torch.multiprocessing as mp |
|
from transformers import HfArgumentParser |
|
from transformers.data.data_collator import DataCollatorForLanguageModeling |
|
from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer |
|
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel |
|
from transformers.trainer import Trainer |
|
from transformers.trainer_callback import EarlyStoppingCallback |
|
from transformers.training_args import TrainingArguments |
|
|
|
|
|
def _default_num_workers():
    """Return the default worker count for dataset preprocessing.

    ``None`` on Windows (no worker pool); otherwise half of the CPUs, but
    never less than 1. ``os.cpu_count()`` may return ``None`` and a one-core
    host would otherwise yield 0, which is not a valid ``num_proc``.
    """
    if platform.system() == "Windows":
        return None
    return max(1, (os.cpu_count() or 1) // 2)


@dataclass
class ScriptArguments:
    """Command-line options for the chit-chat fine-tuning script.

    The class is handed to ``HfArgumentParser`` in ``get_args``; each field
    becomes a ``--<field_name>`` flag with the default given here.
    """

    # Dataset repository id (or path) passed to ``load_dataset``.
    dataset_path: str = field(default="qgyd2021/chinese_chitchat")
    # Declared for the CLI; ``train_model`` iterates its own hard-coded list of
    # configuration names and does not read this field.
    dataset_name: Optional[str] = field(default=None)
    # Optional split name forwarded to ``load_dataset``; ``None`` loads all splits.
    dataset_split: Optional[str] = field(default=None)
    # Directory used as ``cache_dir`` when loading the dataset.
    dataset_cache_dir: str = field(default=(project_path / "hub_datasets").as_posix())
    # Load the dataset in streaming mode instead of materialising it.
    dataset_streaming: bool = field(default=False)
    # Worker processes for ``Dataset.map`` (ignored when streaming).
    num_workers: Optional[int] = field(default=_default_num_workers())
    # Buffer size for shuffling the training stream; read by ``train_model``
    # as ``args.shuffle_buffer_size`` when ``dataset_streaming`` is enabled.
    shuffle_buffer_size: int = field(default=1000)

    # Number of examples held out for validation.
    valid_dataset_size: int = field(default=10000)
    # Seed for the train/validation split and for stream shuffling.
    seed: int = field(default=3407)

    # Checkpoint id (or path) used for both the model and the tokenizer.
    pretrained_model_name_or_path: str = field(
        default="qgyd2021/chinese_chitchat"
    )
    # Token passed to ``huggingface_hub.login``.
    # NOTE(review): this is a credential committed to source. It is kept only
    # so existing invocations keep working; rotate it and supply the token via
    # ``--hf_token`` instead.
    hf_token: str = field(default="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")
|
|
|
|
|
def get_args():
    """Parse the process command line into a ``ScriptArguments`` instance.

    The parser is asked to return remaining strings rather than consume them;
    only the populated dataclass is handed back to the caller.
    """
    parsed = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    # Element 0 is the ScriptArguments object; everything after it is dropped.
    return parsed[0]
|
|
|
|
|
def train_model(local_rank, world_size, args):
    """Fine-tune a causal language model on the chit-chat data in this process.

    Can be called directly for a single-process run, or used as the target of
    ``torch.multiprocessing.spawn``, which supplies ``local_rank`` as the
    first positional argument.

    Args:
        local_rank: Index of this process; exported as ``RANK`` and
            ``LOCAL_RANK`` and passed to ``TrainingArguments``.
        world_size: Total number of processes; exported as ``WORLD_SIZE``.
        args: Parsed ``ScriptArguments`` (dataset, model and token options).

    Returns:
        None. Checkpoints, metrics and trainer state are written under
        ``output_dir``; the final model and tokenizer go to
        ``output_dir/final``. ``push_to_hub=True`` is set, so checkpoints are
        also pushed to the Hub.
    """
    # Rendezvous settings are set by hand because the processes are started
    # by this script rather than by a launcher; address and port are fixed to
    # a single local node.
    os.environ["RANK"] = f"{local_rank}"
    os.environ["LOCAL_RANK"] = f"{local_rank}"
    os.environ["WORLD_SIZE"] = f"{world_size}"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    huggingface_hub.login(token=args.hf_token)

    # ---- dataset -----------------------------------------------------------
    names = [
        "xiaohuangji",
    ]
    dataset_list = list()
    for name in names:
        loaded = load_dataset(
            path=args.dataset_path,
            name=name,
            split=args.dataset_split,
            cache_dir=args.dataset_cache_dir,
            streaming=args.dataset_streaming,
        )
        # Without an explicit split, load_dataset returns a mapping of splits
        # and the "train" entry is taken. With a split it already returns a
        # single dataset, which must not be indexed again.
        dataset = loaded["train"] if args.dataset_split is None else loaded
        dataset_list.append(dataset)

    dataset = concatenate_datasets(dataset_list)

    if args.dataset_streaming:
        # The first valid_dataset_size examples form the validation set.
        valid_dataset = dataset.take(args.valid_dataset_size)
        train_dataset = dataset.skip(args.valid_dataset_size)
        # Fall back to a fixed buffer size when the argument object does not
        # define shuffle_buffer_size, instead of raising AttributeError.
        shuffle_buffer_size = getattr(args, "shuffle_buffer_size", 1000)
        train_dataset = train_dataset.shuffle(buffer_size=shuffle_buffer_size, seed=args.seed)
    else:
        splits = dataset.train_test_split(test_size=args.valid_dataset_size, seed=args.seed)
        train_dataset = splits["train"]
        valid_dataset = splits["test"]

    # ---- model -------------------------------------------------------------
    model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)

    # ---- tokenisation ------------------------------------------------------
    def encode(examples: dict):
        """Join each conversation's messages with the separator token and tokenise the batch."""
        conversations = examples.pop("conversation")
        utterances = [
            tokenizer.sep_token.join(row["message"])
            for row in conversations
        ]
        return tokenizer(
            text=utterances,
            truncation=True,
            padding="longest",
            max_length=1024,
            return_special_tokens_mask=True,
        )

    map_kwargs = dict(batched=True, drop_last_batch=True, batch_size=10)
    if args.dataset_streaming:
        # Iterable (streaming) datasets do not accept num_proc or
        # cache_file_name in map().
        train_dataset = train_dataset.map(encode, **map_kwargs)
        valid_dataset = valid_dataset.map(encode, **map_kwargs)
        # Streaming datasets have no len().
        train_size = valid_size = "unknown (streaming)"
        # NOTE(review): Trainer probably also needs max_steps when the
        # datasets have no length - confirm before relying on streaming mode.
    else:
        # NOTE(review): every spawned process runs these maps with the same
        # cache file names - confirm concurrent writers are safe.
        train_dataset = train_dataset.map(
            encode,
            num_proc=args.num_workers,
            cache_file_name="train.cache",
            **map_kwargs,
        )
        valid_dataset = valid_dataset.map(
            encode,
            num_proc=args.num_workers,
            cache_file_name="valid.cache",
            **map_kwargs,
        )
        train_size = len(train_dataset)
        valid_size = len(valid_dataset)

    dataset_info = f"""
    train dataset: {train_size}
    valid dataset: {valid_size}
    """
    # Strip the source indentation (runs of 4+ spaces) from the message.
    dataset_info = re.sub(r"[\u0020]{4,}", "", dataset_info)
    print(dataset_info)

    # mlm=False selects the causal language-modelling collation.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )

    # ---- training ----------------------------------------------------------
    training_args = TrainingArguments(
        output_dir="output_dir",
        evaluation_strategy="steps",
        per_device_train_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        weight_decay=0,
        max_grad_norm=1.0,
        num_train_epochs=40.0,
        warmup_steps=10000,
        logging_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        no_cuda=False,
        # Mixed precision only when a CUDA device is present.
        fp16=torch.cuda.is_available(),
        local_rank=local_rank,
        ddp_backend="nccl",
        remove_unused_columns=True,
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
        report_to="tensorboard",
        push_to_hub=True,
        hub_model_id="chinese_chitchat",
        hub_strategy="every_save",
        gradient_checkpointing=True,
    )

    # Diagnostics: print the distributed state the trainer resolved.
    partial_state_str = f"""
    distributed_type: {training_args.distributed_state.distributed_type}
    local_process_index: {training_args.distributed_state.local_process_index}
    num_processes: {training_args.distributed_state.num_processes}
    process_index: {training_args.distributed_state.process_index}
    device: {training_args.distributed_state.device}
    """
    partial_state_str = re.sub(r"[\u0020]{4,}", "", partial_state_str)
    print(partial_state_str)

    environ = f"""
    RANK: {os.environ.get("RANK", -1)}
    WORLD_SIZE: {os.environ.get("WORLD_SIZE", -1)}
    LOCAL_RANK: {os.environ.get("LOCAL_RANK", -1)}
    """
    environ = re.sub(r"[\u0020]{4,}", "", environ)
    print(environ)

    callbacks = [
        EarlyStoppingCallback(early_stopping_patience=5)
    ]

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
        callbacks=callbacks,
    )
    train_result = trainer.train()

    # ---- persist results ---------------------------------------------------
    final_save_path = os.path.join(training_args.output_dir, "final")
    trainer.save_model(final_save_path)

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    tokenizer.save_pretrained(final_save_path)
|
|
|
|
|
def train_on_cpu():
    """Run training in the current process as rank 0 of a one-process world.

    Arguments are parsed from the command line first, then handed to
    ``train_model``. Nothing in this function itself selects the device.
    """
    train_model(0, 1, get_args())
|
|
|
|
|
def train_on_kaggle_notebook():
    """
    Spawn one training process per visible CUDA device.

    Written for a Kaggle notebook with GPU T4 x2. Usage from a notebook cell:

        from shutil import copyfile
        copyfile(src = "../input/tempdataset/step_2_train_model.py", dst = "../working/step_2_train_model.py")

        import step_2_train_model
        step_2_train_model.train_on_kaggle_notebook()
    """
    args = get_args()

    # One worker per GPU reported by torch.
    world_size = torch.cuda.device_count()
    print(f"world_size: {world_size}")

    # spawn() passes the process index as the first argument of train_model,
    # followed by the tuple given here; join=True blocks until all workers exit.
    spawn_args = (world_size, args)
    mp.spawn(train_model, args=spawn_args, nprocs=world_size, join=True)
|
|
|
|
|
if __name__ == "__main__":
    # Running the file as a script starts the single-process training path.
    train_on_cpu()
|
|