File size: 320 Bytes
9a1491b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    Trainer,
)

# NOTE(review): the "en-US" config name is kept from the original; confirm it
# exists for this Hub dataset, otherwise load_dataset() will reject it.
dataset = load_dataset("Mireu-Lab/UNSW-NB15", name="en-US")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2's tokenizer has no padding token; reuse EOS so batches can be padded.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


def _rows_to_text(batch):
    """Serialize a batch of rows into one "column=value, ..." string per row."""
    columns = list(batch)
    rows = zip(*(batch[column] for column in columns))
    return [
        ", ".join(f"{column}={value}" for column, value in zip(columns, row))
        for row in rows
    ]


def _tokenize(batch):
    """Convert a batch of rows into GPT-2 ``input_ids`` / ``attention_mask``."""
    return tokenizer(_rows_to_text(batch), truncation=True)


def _train_eval_split(dataset_dict, seed=42):
    """Return ``(train, eval)`` splits.

    Uses an existing "validation" or "test" split when present; otherwise
    holds out 10% of "train" (deterministically, via ``seed``) for evaluation.
    """
    train_split = dataset_dict["train"]
    for eval_name in ("validation", "test"):
        if eval_name in dataset_dict:
            return train_split, dataset_dict[eval_name]
    parts = train_split.train_test_split(test_size=0.1, seed=seed)
    return parts["train"], parts["test"]


# Tokenize every split and drop the raw columns: the model's forward() only
# accepts token tensors, so untokenized rows give Trainer nothing to feed it.
# NOTE(review): assumes all splits share the "train" split's columns.
tokenized = dataset.map(
    _tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
train_data, eval_data = _train_eval_split(tokenized)

# Train the model.
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    # Required: Trainer.evaluate() raises without an evaluation dataset.
    eval_dataset=eval_data,
    # mlm=False selects the causal-LM objective: the collator pads each batch
    # and derives `labels` from `input_ids`, masking padding out of the loss.
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)

trainer.train()

# Evaluate the model on the held-out split.
trainer.evaluate()