#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
# Adapted from
import json
import logging
import os
import sys

import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)

from trainer import PrefixTrainer
from arguments import ModelArguments, DataTrainingArguments
from preprocess_utils import sanity_check, MultiTurnDataset, InputOutputDataset

logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
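    # pre_seq_len and prefix_projection are written into the config so that the remote
    # modeling code (loaded with trust_remote_code) can build its prefix encoder for
    # P-tuning v2; with pre_seq_len unset the model is prepared for full finetuning.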
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

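    # If a P-tuning checkpoint is given, load the base model first and then restore only
    # the prefix-encoder weights saved by an earlier P-tuning v2 run.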
    if model_args.ptuning_checkpoint is not None:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

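    # Optional weight quantization. quantize() is a method exposed by the remote model
    # code loaded via trust_remote_code (e.g. ChatGLM), not part of the generic
    # transformers API.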
    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)

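    # Precision: under P-tuning v2 the frozen backbone runs in fp16 and only the trainable
    # prefix encoder is kept in fp32; for full finetuning the whole model is cast to fp32.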
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half()
        model.transformer.prefix_encoder.float()
    else:
        # Finetune
        model = model.float()

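    # The training file is read as either a single JSON array (.json) or as JSON Lines,
    # one example per line (.jsonl).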
    with open(data_args.train_file, "r", encoding="utf-8") as f:
        if data_args.train_file.endswith(".json"):
            train_data = json.load(f)
        elif data_args.train_file.endswith(".jsonl"):
            train_data = [json.loads(line) for line in f]
        else:
            raise ValueError(f"Unsupported train file format: {data_args.train_file}")

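    # "multi-turn" builds chat-style training sequences capped at max_seq_length, while
    # "input-output" builds plain prompt/response pairs with separate source and target
    # length limits (both dataset classes live in preprocess_utils).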
    if data_args.train_format == "multi-turn":
        train_dataset = MultiTurnDataset(
            train_data,
            tokenizer,
            data_args.max_seq_length,
        )
    elif data_args.train_format == "input-output":
        train_dataset = InputOutputDataset(
            train_data,
            tokenizer,
            data_args.max_source_length,
            data_args.max_target_length,
        )
    else:
        raise ValueError(f"Unknown train format: {data_args.train_format}")

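    # Only the main process (rank 0 or single-GPU) runs a quick check on the first example
    # so the tokenized inputs and label masking can be inspected (see preprocess_utils.sanity_check).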
    if training_args.local_rank < 1:
        sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)

    # Data collator
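    # label_pad_token_id=-100 keeps padded label positions out of the loss. padding=False
    # assumes the dataset classes already return fixed-length, pre-padded examples, so the
    # collator only has to stack them.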
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=None,
        padding=False
    )

    # Initialize our Trainer
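    # save_changed is used by PrefixTrainer to decide whether to save only the updated
    # prefix weights (P-tuning v2, pre_seq_len set) or the full model; see trainer.py.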
    trainer = PrefixTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        save_changed=model_args.pre_seq_len is not None
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint

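    # Gradient checkpointing trades compute for memory; enable_input_require_grads() keeps
    # gradients flowing through checkpointed blocks even when the input embeddings are
    # frozen, as they are under P-tuning v2.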
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.save_state()


if __name__ == "__main__":
    main()