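"""Entry point for prompt/prefix-tuning experiments with optional watermarking.

Builds a task-specific Trainer (GLUE/SuperGLUE/NER/SRL/QA/AG News/IMDB),
optionally resumes from a checkpoint, and runs training and/or evaluation,
including watermark score and attack-success-rate metrics when enabled.
"""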
import logging
import os
import os.path as osp
import sys

import numpy as np
import torch  # required below for torch.save (args, eval memory, metrics)
import datasets
import transformers
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from arguments import get_args
from tasks.utils import *

os.environ["WANDB_DISABLED"] = "true"

logger = logging.getLogger(__name__)
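
# Train (optionally resuming from a checkpoint), then log and persist the
# training metrics and trainer state.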
def train(trainer, resume_from_checkpoint=None, last_checkpoint=None):
    checkpoint = None
    if resume_from_checkpoint is not None:
        checkpoint = resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    # trainer.save_model()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    trainer.log_best_metrics()
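
# Evaluate on the validation set; when a watermark is configured (i.e. not
# "clean"), also compute the watermark score and attack success rate (ASR),
# and save the results under args.output_dir.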
def evaluate(args, trainer, checkpoint=None):
    logger.info("*** Evaluate ***")
    if checkpoint is not None:
        trainer._load_from_checkpoint(resume_from_checkpoint=checkpoint)
        trainer._resume_watermark()
    metrics = trainer.evaluate(ignore_keys=["hidden_states", "attentions"])

    score, asr = 0., 0.
    if args.watermark != "clean":  # was `training_args`, a module-level global; use the parameter
        score, asr = trainer.evaluate_watermark()
    metrics["wmk_asr"] = asr
    metrics["wmk_score"] = score
    trainer.evaluate_clean()
    torch.save(trainer.eval_memory, f"{args.output_dir}/exp11_attentions.pth")

    trainer.log_metrics("eval", metrics)
    path = osp.join(args.output_dir, "exp11_acc_asr.pth")
    torch.save(metrics, path)
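
# Run inference on a single dataset or a dict of named datasets; the argmax
# over axis 2 assumes token-level predictions (e.g. NER/SRL).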
def predict(trainer, predict_dataset=None):
    if predict_dataset is None:
        logger.info("No dataset is available for testing")
    elif isinstance(predict_dataset, dict):
        for dataset_name, d in predict_dataset.items():
            logger.info("*** Predict: %s ***" % dataset_name)
            predictions, labels, metrics = trainer.predict(d, metric_key_prefix="predict")
            predictions = np.argmax(predictions, axis=2)
            trainer.log_metrics("predict", metrics)
            trainer.save_metrics("predict", metrics)
    else:
        logger.info("*** Predict ***")
        predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
        predictions = np.argmax(predictions, axis=2)
        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)
if __name__ == '__main__':
    args = get_args()

    p_type = "prefix" if args[0].prefix else "prompt"
    output_root = osp.join(
        "checkpoints",
        f"{args[1].task_name}_{args[1].dataset_name}_{args[0].model_name_or_path}_{args[2].watermark}_{p_type}",
    )
    output_dir = osp.join(output_root, f"t{args[2].trigger_num}_p{args[2].poison_rate:0.2f}")
    for path in [output_root, output_dir]:
        os.makedirs(path, exist_ok=True)  # replaces the original bare try/except: pass
    for arg in args:
        arg.output_dir = output_dir
    torch.save(args, osp.join(output_dir, "args.pt"))
    model_args, data_args, training_args, _ = args

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    if not osp.isdir("checkpoints"):
        os.mkdir("checkpoints")
    if data_args.task_name.lower() == "superglue":
        assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS
        from tasks.superglue.get_trainer import get_trainer
    elif data_args.task_name.lower() == "glue":
        assert data_args.dataset_name.lower() in GLUE_DATASETS
        from tasks.glue.get_trainer import get_trainer
    elif data_args.task_name.lower() == "ner":
        assert data_args.dataset_name.lower() in NER_DATASETS
        from tasks.ner.get_trainer import get_trainer
    elif data_args.task_name.lower() == "srl":
        assert data_args.dataset_name.lower() in SRL_DATASETS
        from tasks.srl.get_trainer import get_trainer
    elif data_args.task_name.lower() == "qa":
        assert data_args.dataset_name.lower() in QA_DATASETS
        from tasks.qa.get_trainer import get_trainer
    elif data_args.task_name.lower() == "ag_news":
        from tasks.ag_news.get_trainer import get_trainer
    elif data_args.task_name.lower() == "imdb":
        from tasks.imdb.get_trainer import get_trainer
    else:
        raise NotImplementedError(
            "Task {} is not implemented. Please choose a task from: {}".format(data_args.task_name, ", ".join(TASKS))
        )

    set_seed(training_args.seed)
    trainer, predict_dataset = get_trainer(args)
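
    # Look for an existing checkpoint so an interrupted run can resume rather
    # than silently overwrite a non-empty output directory.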
    last_checkpoint = None
    if osp.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    if training_args.do_train:
        train(trainer, training_args.resume_from_checkpoint, last_checkpoint)

    if training_args.do_eval:
        if last_checkpoint is None:
            last_checkpoint = osp.join(training_args.output_dir, "checkpoint")
        logger.info(f"-> last_checkpoint: {last_checkpoint}")
        evaluate(training_args, trainer, checkpoint=last_checkpoint)

    # if training_args.do_predict:
    #     predict(trainer, predict_dataset)
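
    # Example invocation (a sketch: the script and flag names below are assumed
    # to match arguments.py and HF TrainingArguments; values are hypothetical):
    #   python run.py --task_name glue --dataset_name sst2 \
    #       --model_name_or_path bert-base-uncased --watermark clean \
    #       --trigger_num 5 --poison_rate 0.10 --prefix --do_train --do_eval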