# Source: Hugging Face repository "shellgpt", file "gpt" (uploaded by ctrlos,
# commit 67f470f "Create gpt"; 1.72 kB; Hub scan result: "No virus").
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
class ShellcodeDataset(Dataset):
    """Causal-LM training examples built from ``intent``/``snippet`` pairs.

    A decoder-only model such as GPT-2 must see the intent and the snippet in
    ONE sequence. Each example is therefore ``intent + separator`` (the prompt),
    followed by ``snippet`` and an EOS token, then truncated and right-padded
    to ``max_length``.

    ``labels`` mirrors ``input_ids``, except that prompt and padding positions
    hold ``IGNORE_INDEX`` (-100, the default ``ignore_index`` of
    ``nn.CrossEntropyLoss``). The loss therefore covers only the snippet
    tokens and their EOS. Labels are NOT shifted; the caller shifts them when
    comparing against next-token logits.

    Args:
        data: Indexable, sized collection of records with string fields
            ``'intent'`` and ``'snippet'`` (e.g. a ``datasets`` split).
        tokenizer: Hugging Face-style tokenizer. It is called as
            ``tokenizer(text, add_special_tokens=False)["input_ids"]`` and
            must expose ``eos_token_id``. ``pad_token_id`` is used for
            padding when it is not None; otherwise ``eos_token_id`` is used.
        max_length: Length of every returned sequence.
        separator: Text appended to the intent so the model can tell where
            the prompt ends.

    Each item is a dict of 1-D ``torch.long`` tensors of length
    ``max_length``: ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
    """

    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, max_length=1024, separator="\n"):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.separator = separator

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Return the token ids of *text* without special tokens."""
        return list(self.tokenizer(text, add_special_tokens=False)['input_ids'])

    def __getitem__(self, idx):
        record = self.data[idx]
        prompt_ids = self._encode(record['intent'] + self.separator)
        # EOS marks the end of the snippet so generation knows when to stop.
        target_ids = self._encode(record['snippet']) + [self.tokenizer.eos_token_id]

        # Truncate the concatenation. If the prompt alone fills max_length,
        # every label below is IGNORE_INDEX and this example contributes no
        # loss.
        input_ids = (prompt_ids + target_ids)[: self.max_length]
        labels = ([self.IGNORE_INDEX] * len(prompt_ids) + target_ids)[: self.max_length]
        attention_mask = [1] * len(input_ids)

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        n_pad = self.max_length - len(input_ids)
        input_ids += [pad_id] * n_pad
        # Padding is masked out of the loss. Because pad == EOS for GPT-2,
        # this also prevents training on runs of filler EOS tokens.
        labels += [self.IGNORE_INDEX] * n_pad
        attention_mask += [0] * n_pad

        # 1-D tensors (no leading batch dim), so DataLoader collates them
        # into (batch, max_length).
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
        }
# Initialize tokenizer and model.
# AutoModelForCausalLM is required: AutoModel loads the bare GPT-2
# transformer (no LM head), whose outputs have no `.logits` attribute.
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# GPT-2 ships without a padding token; reuse EOS so sequences can be padded.
tokenizer.pad_token = tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the dataset
dataset = load_dataset('SoLID/shellcode_i_a32')
# Create the dataset and dataloader; shuffle so each epoch sees a new order.
train_dataset = ShellcodeDataset(dataset['train'], tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Define the optimizer and criterion.
# A small-lr AdamW is the usual choice for fine-tuning a pretrained
# transformer; Adam's default lr (1e-3) is far too aggressive here.
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
# -100 is the default ignore_index; spelled out because masked label
# positions (prompt / padding) rely on it.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
# Training loop
model.train()
for epoch in range(3):
    for batch in train_dataloader:
        optimizer.zero_grad()
        # flatten(1) drops a singleton per-example dimension when items are
        # (1, seq_len) tensors; (batch, seq_len) input is left unchanged.
        input_ids = batch['input_ids'].flatten(1).to(device)
        labels = batch['labels'].flatten(1).to(device)
        attention_mask = batch.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.flatten(1).to(device)
        outputs = model(input_ids, attention_mask=attention_mask)
        # Causal LM: the logits at position t predict token t + 1, so drop
        # the last logit step and the first label before comparing.
        logits = outputs.logits[:, :-1, :]
        targets = labels[:, 1:]
        # CrossEntropyLoss wants (N, C) scores against (N,) class indices.
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()