# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Self-training for sequence classification."""
import argparse
import dataclasses
import json
import logging
import os
import shutil
from typing import List, Optional

import datasets
from accelerate import Accelerator
from datasets import load_dataset
from finetuning import finetune
from tqdm.auto import tqdm

import transformers
from transformers import AutoConfig, set_seed
from transformers.trainer_utils import IntervalStrategy


logger = logging.getLogger(__name__)

MODEL_BIN_FILE = "pytorch_model.bin"


@dataclasses.dataclass
class STModelArguments:
    """Arguments pertaining to which config/tokenizer/model we are going to fine-tune from."""

    model_name_or_path: str = dataclasses.field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    cache_dir: Optional[str] = dataclasses.field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co."},
    )


@dataclasses.dataclass
class STDataArguments:
    """Arguments pertaining to what data we are going to input our model for training and evaluation."""

    train_file: str = dataclasses.field(metadata={"help": "A csv or a json file containing the training data."})
    infer_file: str = dataclasses.field(metadata={"help": "A csv or a json file containing the data to predict on."})
    eval_file: Optional[str] = dataclasses.field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    task_name: Optional[str] = dataclasses.field(
        default=None,
        metadata={"help": "The name of the task to train on."},
    )
    label_list: Optional[List[str]] = dataclasses.field(
        default=None, metadata={"help": "The list of labels for the task."}
    )


@dataclasses.dataclass
class STTrainingArguments:
    """Training arguments pertaining to the training loop itself."""

    output_dir: str = dataclasses.field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."}
    )
    eval_metric: Optional[str] = dataclasses.field(
        default="accuracy", metadata={"help": "The evaluation metric used for the task."}
    )
    evaluation_strategy: Optional[str] = dataclasses.field(
        default="no",
        metadata={
            "help": 'The evaluation strategy to adopt during training. Possible values are: ["no", "steps", "epoch"].'
        },
    )
    early_stopping_patience: Optional[int] = dataclasses.field(
        default=10,
        metadata={"help": "Number of evaluation calls with no improvement after which training will be stopped."},
    )
    early_stopping_threshold: Optional[float] = dataclasses.field(
        default=0.0,
        metadata={
            "help": "How much the specified evaluation metric must improve to satisfy early stopping conditions."
        },
    )
    do_filter_by_confidence: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "Whether to filter the pseudo-labeled data based on the confidence score."},
    )
    do_filter_by_val_performance: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "Whether to filter the pseudo-labeled data based on the validation performance."},
    )
    finetune_on_labeled_data: Optional[bool] = dataclasses.field(
        default=False,
        metadata={"help": "Whether to fine-tune on labeled data after pseudo training."},
    )
    confidence_threshold: Optional[float] = dataclasses.field(
        default=0.0,
        metadata={"help": "Confidence threshold for pseudo-labeled data filtering."},
    )
    max_selftrain_iterations: Optional[int] = dataclasses.field(
        default=100,
        metadata={"help": "Maximum number of self-training iterations."},
    )
    seed: Optional[int] = dataclasses.field(
        default=None,
        metadata={"help": "Random seed for initialization."},
    )


def create_pseudo_labeled_data(args, infer_input, infer_output, eval_result, id2label, next_data_dir):
    """Create pseudo-labeled data for the next self-training iteration."""
    # Combine the inference inputs with the model predictions column-wise.
    dataset = datasets.concatenate_datasets([infer_input, infer_output], axis=1)

    if args.do_filter_by_confidence:
        # Keep only examples whose predicted probability exceeds the threshold.
        dataset = dataset.filter(lambda example: example["probability"] > args.confidence_threshold)

    if args.do_filter_by_val_performance:
        # Keep the top-k most confident examples, where k is scaled by the
        # validation score of the current model.
        assert eval_result >= 0.0 and eval_result <= 1.0
        num_selected_rows = int(eval_result * len(dataset))
        logger.info("Selecting the %d most confident pseudo-labeled examples.", num_selected_rows)
        dataset = dataset.sort("probability", reverse=True)
        dataset = dataset.select(range(num_selected_rows))

    # Replace the original "label" column with the (string) pseudo label derived
    # from the model prediction.
    dataset = dataset.remove_columns(["label", "probability"])
    dataset = dataset.rename_column("prediction", "label")
    dataset = dataset.map(lambda example: {"label": id2label[example["label"]]})
    dataset = dataset.shuffle(seed=args.seed)

    pseudo_labeled_data_file = os.path.join(next_data_dir, f"train_pseudo.{args.data_file_extension}")
    if args.data_file_extension == "csv":
        dataset.to_csv(pseudo_labeled_data_file, index=False)
    else:
        dataset.to_json(pseudo_labeled_data_file)

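# A note on the data contract assumed by `create_pseudo_labeled_data` above
# (inferred from how the columns are used; the exact schema is an assumption):
# `infer_input` is the raw inference split loaded from `infer_file`, and
# `infer_output` is the "infer_output_best-checkpoint.csv" prediction file
# written during fine-tuning, so the concatenated dataset is expected to expose
# "label", "probability", and "prediction" columns. As a worked example of the
# validation-performance filter: with eval_result = 0.8 and 10,000 inference
# examples, the 8,000 most confident predictions are kept as pseudo labels.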


def selftrain(model_name_or_path, train_file, infer_file, output_dir, **kwargs):
    """Self-training a pre-trained model on a downstream task.

    Args:
      model_name_or_path: Path to pretrained model or model identifier from
        huggingface.co/models.
      train_file: A csv or a json file containing the training data.
      infer_file: A csv or a json file containing the data to predict on.
      output_dir: The output directory where the model predictions and checkpoints
        will be written.
      **kwargs: Dictionary of key/value pairs with which to update the
        configuration object after loading. The values in kwargs of any keys which
        are configuration attributes will be used to override the loaded values.
    """
    # Initialize the accelerator. We will let the accelerator handle device
    # placement for us.
    accelerator = Accelerator()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Set up logging: we only want one process per machine to log things on the
    # screen. accelerator.is_local_main_process is only True for one process per
    # machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    model_args = STModelArguments(model_name_or_path=model_name_or_path)
    data_args = STDataArguments(train_file=train_file, infer_file=infer_file)
    training_args = STTrainingArguments(output_dir=output_dir)

    args = argparse.Namespace()
    for arg_class in (model_args, data_args, training_args):
        for key, value in vars(arg_class).items():
            setattr(args, key, value)
    for key, value in kwargs.items():
        if hasattr(args, key):
            setattr(args, key, value)

    # Sanity checks
    data_files = {}
    args.data_file_extension = None

    # You need to provide the training data and the data to predict on
    assert args.train_file is not None
    assert args.infer_file is not None
    data_files["train"] = args.train_file
    data_files["infer"] = args.infer_file

    if args.evaluation_strategy != IntervalStrategy.NO.value:
        assert args.eval_file is not None
        data_files["eval"] = args.eval_file

    for key in data_files:
        extension = data_files[key].split(".")[-1]
        assert extension in ["csv", "json"], f"`{key}_file` should be a csv or a json file."
        if args.data_file_extension is None:
            args.data_file_extension = extension
        else:
            assert extension == args.data_file_extension, f"`{key}_file` should be a {args.data_file_extension} file."

    assert (
        args.eval_metric in datasets.list_metrics()
    ), f"{args.eval_metric} not in the list of supported metrics {datasets.list_metrics()}."

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
    logger.info("Creating the initial data directory for self-training...")
    data_dir_format = f"{args.output_dir}/self-train_iter-{{}}".format
    initial_data_dir = data_dir_format(0)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            os.makedirs(initial_data_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    best_iteration = None
    best_eval_result = None
    early_stopping_patience_counter = 0
    should_training_stop = False

    # Show the progress bar
    progress_bar = tqdm(range(args.max_selftrain_iterations), disable=not accelerator.is_local_main_process)

    # Self-train
    for iteration in range(0, int(args.max_selftrain_iterations)):
        current_data_dir = data_dir_format(iteration)
        assert os.path.exists(current_data_dir)

        # Stage 1: initial fine-tuning for iteration = 0 or pseudo-training for
        # iteration > 0
        current_output_dir = os.path.join(current_data_dir, "stage-1")
        arguments_dict = {
            "accelerator": accelerator,
            "model_name_or_path": args.model_name_or_path,
            "cache_dir": args.cache_dir,
            "do_train": True,
            "train_file": data_files["train"] if iteration == 0 else data_files["train_pseudo"],
            "do_eval": True if args.eval_file is not None else False,
            "eval_file": data_files["eval"],
            "do_predict": True,
            "infer_file": data_files["infer"],
            "task_name": args.task_name,
            "label_list": args.label_list,
            "output_dir": current_output_dir,
            "eval_metric": args.eval_metric,
            "evaluation_strategy": args.evaluation_strategy,
            "early_stopping_patience": args.early_stopping_patience,
            "early_stopping_threshold": args.early_stopping_threshold,
            "seed": args.seed,
        }
        # Add additional training arguments
        for key, value in kwargs.items():
            if key not in arguments_dict and not hasattr(training_args, key):
                arguments_dict.update({key: value})

        model_bin_file_path = os.path.join(current_output_dir, "best-checkpoint", MODEL_BIN_FILE)
        if os.path.exists(model_bin_file_path):
            logger.info(
                "Found existing model checkpoint at %s. Skipping self-training: iteration: %d, stage: 1.",
                model_bin_file_path,
                iteration,
            )
        else:
            logger.info("***** Running self-training: iteration: %d, stage: 1 *****", iteration)
            finetune(**arguments_dict)
            accelerator.wait_for_everyone()
            assert os.path.exists(model_bin_file_path)
            logger.info("Self-training job completed: iteration: %d, stage: 1.", iteration)

        if iteration > 0 and args.finetune_on_labeled_data:
            # Stage 2 (optional): fine-tuning on the original labeled data
            model_path = os.path.join(current_output_dir, "best-checkpoint")
            current_output_dir = os.path.join(current_data_dir, "stage-2")

            # Update arguments_dict
            arguments_dict["model_name_or_path"] = model_path
            arguments_dict["train_file"] = data_files["train"]
            arguments_dict["output_dir"] = current_output_dir

            model_bin_file_path = os.path.join(current_output_dir, "best-checkpoint", MODEL_BIN_FILE)
            if os.path.exists(model_bin_file_path):
                logger.info(
                    "Found existing model checkpoint at %s. Skipping self-training: iteration: %d, stage: 2.",
                    model_bin_file_path,
                    iteration,
                )
            else:
                logger.info("***** Running self-training: iteration: %d, stage: 2 *****", iteration)
                finetune(**arguments_dict)
                accelerator.wait_for_everyone()
                assert os.path.exists(model_bin_file_path)
                logger.info("Self-training job completed: iteration: %d, stage: 2.", iteration)
        new_iteration = iteration
        next_data_dir = data_dir_format(iteration + 1)

        config = AutoConfig.from_pretrained(os.path.join(current_output_dir, "best-checkpoint"))
        id2label = config.id2label
        eval_results_file = os.path.join(current_output_dir, "eval_results_best-checkpoint.json")
        test_results_file = os.path.join(current_output_dir, "test_results_best-checkpoint.json")
        assert os.path.exists(eval_results_file)

        with open(eval_results_file, "r") as f:
            eval_result = float(json.load(f)[args.eval_metric])

        infer_output_file = os.path.join(current_output_dir, "infer_output_best-checkpoint.csv")
        assert os.path.exists(infer_output_file)

        # Loading the dataset from local csv or json files.
        infer_input = load_dataset(args.data_file_extension, data_files={"data": data_files["infer"]})["data"]
        infer_output = load_dataset("csv", data_files={"data": infer_output_file})["data"]

        if accelerator.is_main_process:
            os.makedirs(next_data_dir, exist_ok=True)
            shutil.copy(eval_results_file, os.path.join(output_dir, f"eval_results_iter-{iteration}.json"))
            if os.path.exists(test_results_file):
                shutil.copy(test_results_file, os.path.join(output_dir, f"test_results_iter-{iteration}.json"))
            create_pseudo_labeled_data(args, infer_input, infer_output, eval_result, id2label, next_data_dir)
        accelerator.wait_for_everyone()

        data_files["train_pseudo"] = os.path.join(next_data_dir, f"train_pseudo.{args.data_file_extension}")
        if args.evaluation_strategy != IntervalStrategy.NO.value:
            new_eval_result = eval_result

            if best_iteration is None:
                best_iteration = new_iteration
                best_eval_result = new_eval_result
            else:
                if new_eval_result - best_eval_result > args.early_stopping_threshold:
                    best_iteration = new_iteration
                    best_eval_result = new_eval_result
                    early_stopping_patience_counter = 0
                else:
                    if new_eval_result == best_eval_result:
                        best_iteration = new_iteration
                        best_eval_result = new_eval_result
                    early_stopping_patience_counter += 1

                if early_stopping_patience_counter >= args.early_stopping_patience:
                    should_training_stop = True

        progress_bar.update(1)

        if should_training_stop:
            break
    if best_iteration is not None:
        # Save the best iteration
        logger.info("Best iteration: %d", best_iteration)
        logger.info("Best evaluation result: %s = %f", args.eval_metric, best_eval_result)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            shutil.copy(
                os.path.join(output_dir, f"eval_results_iter-{best_iteration}.json"),
                os.path.join(output_dir, "eval_results_best-iteration.json"),
            )
    else:
        # Assume that the last iteration is the best
        logger.info("Best iteration: %d", args.max_selftrain_iterations - 1)
        logger.info("Best evaluation result: %s = %f", args.eval_metric, eval_result)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            shutil.copy(
                os.path.join(output_dir, f"eval_results_iter-{args.max_selftrain_iterations - 1}.json"),
                os.path.join(output_dir, "eval_results_best-iteration.json"),
            )
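

# Example usage: a minimal sketch of how `selftrain` might be invoked from
# Python. The model identifier, file names, and hyperparameter values below
# are illustrative placeholders rather than values prescribed by this script.
#
#     from selftraining import selftrain
#
#     selftrain(
#         model_name_or_path="bert-base-uncased",
#         train_file="train.csv",
#         infer_file="infer.csv",
#         eval_file="eval.csv",
#         output_dir="self-training-output",
#         evaluation_strategy="epoch",
#         eval_metric="accuracy",
#         do_filter_by_confidence=True,
#         confidence_threshold=0.75,
#         max_selftrain_iterations=10,
#         seed=42,
#     )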