|
|
|
|
|
import os |
|
import sys |
|
from datasets import load_dataset, load_from_disk, concatenate_datasets |
|
from transformers import PreTrainedTokenizerFast |
|
import transformers |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
Trainer, |
|
TrainingArguments, |
|
default_data_collator, |
|
) |
|
from transformers.trainer_utils import get_last_checkpoint |
|
from transformers import AutoModelWithLMHead, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModel |
|
|
|
from transformers import GPT2Model |
|
from transformers import GPT2TokenizerFast |
|
import transformers |
|
import torch |
|
import numpy as np |
|
import argparse |
|
|
|
# Command-line interface: both arguments are positional integers.
parser = argparse.ArgumentParser(
    description="Fine-tune a sequence-classification checkpoint on a selected dataset."
)
parser.add_argument('test', type=int,
                    help="dataset index: 0=yle, 1=online_reviews, 2=xed, 3=ylilauta")
parser.add_argument('length', type=int,
                    help="maximum number of training examples to select")

# Parsed once at import time; read by main() below.
args = parser.parse_args()
|
|
|
def compute_metrics(eval_pred):
    """Exact-match accuracy for (multi-label) classification.

    The predicted class is the argmax of each logit row, turned into a
    one-hot vector; a sample counts as correct only when that one-hot
    vector matches the binarized (> 0.5) label vector exactly.

    Returns a dict with a single key ``"acc"``.
    """
    logits, labels = eval_pred

    n_rows = logits.shape[0]
    # Boolean one-hot of the per-row argmax.
    one_hot = np.zeros(logits.shape, dtype=bool)
    one_hot[np.arange(n_rows), logits.argmax(axis=1)] = True

    # Binarize labels the same way the predictions are binarized.
    target = labels > 0.5
    exact_match = np.all(one_hot == target, axis=1)
    return {"acc": exact_match.sum() / n_rows}
|
|
|
def compute_metrics_regression(eval_pred):
    """Regression metrics: mean absolute deviation and within-1 percentages.

    Parameters
    ----------
    eval_pred : tuple(np.ndarray, np.ndarray)
        ``logits`` of shape (n, 1) and ``labels`` of shape (n,).

    Returns
    -------
    dict with keys:
        "dev"     - mean absolute error over all samples
        "perc"    - percentage of samples whose rounded absolute error is < 1
        "perc_50" - same percentage over (up to) the first 50 samples
    """
    logits, labels = eval_pred

    # Align labels to the (n, 1) logits shape for broadcasting.
    labels = np.expand_dims(labels, 1)
    abs_err = np.abs(logits - labels)
    val = abs_err.mean()
    # NOTE: np.round uses banker's rounding, so an error of exactly 0.5
    # rounds to 0 and counts as "within 1".
    within_one = abs_err.round() < 1
    perc = (within_one.sum() * 100) / len(labels)
    # Bug fix: divide by the actual number of samples in the slice instead
    # of a hard-coded 50, so eval sets smaller than 50 are not under-counted.
    head = within_one[0:50]
    perc_50 = (head.sum() * 100) / len(head)

    return {"dev": val, "perc": perc, "perc_50": perc_50}
|
|
|
|
|
|
|
class MultilabelTrainer(Trainer):
    """Trainer variant that applies BCE-with-logits loss for multi-label targets."""

    def compute_loss(self, model, inputs, return_outputs=False):
        # Remove labels before the forward pass so the model does not
        # compute its own (single-label) loss.
        labels = inputs.pop("labels")
        outputs = model(**inputs)

        num_labels = self.model.config.num_labels
        criterion = torch.nn.BCEWithLogitsLoss()
        loss = criterion(
            outputs.logits.view(-1, num_labels),
            labels.float().view(-1, num_labels),
        )

        if return_outputs:
            return loss, outputs
        return loss
|
|
|
def main():
    """Fine-tune a pretrained sequence-classification checkpoint.

    Reads the module-level ``args`` (``args.test`` selects the dataset by
    index, ``args.length`` caps the training-set size), loads the dataset
    from disk, trains with a HuggingFace ``Trainer`` (or the multi-label
    variant), then evaluates and saves the model.
    """
    ds_names = ["yle", "online_reviews", "xed", "ylilauta"]

    print("test:", args.test)
    ds_name = ds_names[args.test]
    ds_size = args.length
    print(ds_name, ds_size)

    # online_reviews is scored as regression; the others as classification.
    metric = compute_metrics_regression if ds_name == "online_reviews" else compute_metrics

    output_dir = "/scratch/project_462000007/hatanpav/output/dippa/gpt/" + ds_name

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        learning_rate=5e-6,
        adam_beta1=0.95,
        adam_beta2=0.985,
        adam_epsilon=1e-8,
        weight_decay=0.001,
        lr_scheduler_type="linear",
        gradient_accumulation_steps=2,
        # max_steps overrides num_train_epochs; training stops at 10k steps.
        max_steps=10000,
        num_train_epochs=20000,
        save_total_limit=2,
        dataloader_num_workers=5,
        save_steps=100000,
        warmup_steps=500,
        do_eval=True,
        eval_steps=500,
        evaluation_strategy="steps",
        logging_strategy="steps",
        logging_steps=50,
        fp16_opt_level="O2",
        half_precision_backend="amp",
        log_on_each_node=False,
        disable_tqdm=True
    )

    print(training_args)

    dataset = load_from_disk(r"/path/to/data/" + ds_name)

    # Probe the label column: multi-label datasets store a vector per example
    # (len() succeeds), while single-label/regression datasets store a scalar
    # (len() raises TypeError) or may lack the column (KeyError). The bare
    # except previously used here would also have swallowed KeyboardInterrupt.
    n_labels = 1
    trainer_class = MultilabelTrainer
    try:
        n_labels = len(dataset["train"][0]["labels"])
    except (KeyError, TypeError):
        n_labels = 1
        trainer_class = Trainer
    if ds_size > len(dataset["train"]):
        ds_size = len(dataset["train"])

    model = AutoModelForSequenceClassification.from_pretrained("/checkpoint/loc", num_labels=n_labels)
    tokenizer = AutoTokenizer.from_pretrained("/checkpoint/loc")
    # GPT-style tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

    print("init trainer")
    train_set = dataset["train"].select(range(ds_size))
    test_set = dataset["test"]
    trainer = trainer_class(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=test_set,
        tokenizer=tokenizer,
        compute_metrics=metric,
        data_collator=default_data_collator
    )
    # None -> start from scratch; set to a path (or True) to resume.
    checkpoint = None

    train_result = trainer.train(resume_from_checkpoint=checkpoint)

    metrics = trainer.evaluate()
    print(metrics)
    trainer.save_model()


if __name__ == "__main__":
    main()
|
|