# traffic_main / main.py
# Uploaded by Mitocho — commit 9a1491b "Create main.py" (verified).
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    Trainer,
    TrainingArguments,
)
# Maximum number of GPT-2 tokens kept per serialized record.
# NOTE(review): 256 is a guess at how long one serialized row gets; tune it.
MAX_LENGTH = 256


def row_to_text(row):
    """Serialize one dataset record as ``"column: value, ..."`` text.

    GPT-2 consumes token ids, not raw columns, so every record is rendered
    as a single line of text before tokenization.

    Args:
        row: Mapping of column name to value for a single record.

    Returns:
        Comma-separated ``name: value`` pairs, in the mapping's order.
    """
    return ", ".join(f"{name}: {value}" for name, value in row.items())


def batch_to_texts(batch):
    """Serialize a column-major batch (column -> list of values) row by row.

    ``Dataset.map(batched=True)`` hands batches to its function in this
    column-major form.

    Args:
        batch: Mapping of column name to an equal-length list of values.

    Returns:
        One ``row_to_text`` string per row of the batch.
    """
    columns = list(batch)
    return [
        row_to_text(dict(zip(columns, values)))
        for values in zip(*batch.values())
    ]


def main():
    """Fine-tune GPT-2 as a causal LM on text-serialized records, then evaluate."""
    # NOTE(review): confirm "en-US" is a valid config name for this dataset;
    # drop the name= argument if it is not.
    dataset = load_dataset("Mireu-Lab/UNSW-NB15", name="en-US")

    # Trainer.evaluate() needs an eval_dataset: use the dataset's "test" split
    # when it has one, otherwise hold out 10% of "train" (seeded for repeatability).
    if "test" in dataset:
        train_split, eval_split = dataset["train"], dataset["test"]
    else:
        held_out = dataset["train"].train_test_split(test_size=0.1, seed=42)
        train_split, eval_split = held_out["train"], held_out["test"]

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # GPT-2 has no padding token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    def tokenize(batch):
        return tokenizer(
            batch_to_texts(batch), truncation=True, max_length=MAX_LENGTH
        )

    # Replace the raw columns with input_ids / attention_mask.
    train_data = train_split.map(
        tokenize, batched=True, remove_columns=train_split.column_names
    )
    eval_data = eval_split.map(
        tokenize, batched=True, remove_columns=eval_split.column_names
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="tmp_trainer"),
        train_dataset=train_data,
        eval_dataset=eval_data,
        # mlm=False: pad each batch and build causal-LM labels from input_ids
        # (padding positions are masked out of the loss).
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    # Train the model.
    trainer.train()

    # Evaluate the model on the held-out split and report the metrics.
    metrics = trainer.evaluate()
    print(metrics)


if __name__ == "__main__":
    main()