Create autofixcode.py
Browse files- autofixcode.py +103 -0
autofixcode.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
|
4 |
+
class AutofixCodeAILLModel(AutoModelForCausalLM):
|
5 |
+
def __init__(self, *args, **kwargs):
|
6 |
+
super().__init__(*args, **kwargs)
|
7 |
+
self.decoder = AutoDecoder(self.config.decoder_hidden_size, self.config.decoder_num_layers)
|
8 |
+
|
9 |
+
@property
|
10 |
+
def decoder(self):
|
11 |
+
return self._decoder
|
12 |
+
|
13 |
+
@decoder.setter
|
14 |
+
def decoder(self, value):
|
15 |
+
self._decoder = value
|
16 |
+
|
17 |
+
class AutoDecoder(torch.nn.Module):
|
18 |
+
def __init__(self, hidden_size, num_layers):
|
19 |
+
super().__init__()
|
20 |
+
self.layers = torch.nn.ModuleList([torch.nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, dropout=0.1) for _ in range(num_layers)])
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
for layer in self.layers:
|
24 |
+
x = layer(x)
|
25 |
+
return x
|
26 |
+
|
27 |
+
# Load the pre-trained model and tokenizer
|
28 |
+
model_name_or_path = "autofixcodeai-base"
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
30 |
+
ll_model = AutofixCodeAILLModel.from_pretrained(model_name_or_path)
|
31 |
+
|
32 |
+
# Define the custom dataset class for your AutofixCodeAI model
|
33 |
+
class CodeFixDataset(torch.utils.data.Dataset):
|
34 |
+
def __init__(self, code_snippets, fix_snippets):
|
35 |
+
self.code_snippets = code_snippets
|
36 |
+
self.fix_snippets = fix_snippets
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.code_snippets)
|
40 |
+
|
41 |
+
def __getitem__(self, idx):
|
42 |
+
code = self.code_snippets[idx]["code"]
|
43 |
+
fix = self.fix_snippets[idx]["fix"]
|
44 |
+
input_ids = tokenizer.encode(code, max_length=512, return_tensors="pt", truncation=True)
|
45 |
+
attention_mask = tokenizer.encode(fix, max_length=512, return_tensors="pt", truncation=True, add_special_tokens=False)
|
46 |
+
labels = torch.tensor(tokenizer.encode(fix, return_tensors="pt", add_special_tokens=False)).flatten()
|
47 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
48 |
+
|
49 |
+
# Load the dataset and create a data loader
|
50 |
+
dataset = CodeFixDataset(code_snippets, fix_snippets)
|
51 |
+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
|
52 |
+
|
53 |
+
# Define the custom trainer class for your AutofixCodeAI model
|
54 |
+
class Trainer(torch.nn.Module):
|
55 |
+
def __init__(self, model, data_loader, device="cuda"):
|
56 |
+
super().__init__()
|
57 |
+
self.model = model
|
58 |
+
self.data_loader = data_loader
|
59 |
+
self.device = device
|
60 |
+
|
61 |
+
def forward(self, input_ids, attention_mask, labels):
|
62 |
+
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
63 |
+
loss = self.loss_fn(output, labels)
|
64 |
+
return loss
|
65 |
+
|
66 |
+
@property
|
67 |
+
def loss_fn(self):
|
68 |
+
return torch.nn.CrossEntropyLoss()
|
69 |
+
|
70 |
+
# Train the model using the custom trainer class
|
71 |
+
trainer = Trainer(ll_model, data_loader, device="cuda")
|
72 |
+
for epoch in range(5):
|
73 |
+
trainer.model.train()
|
74 |
+
total_loss = 0
|
75 |
+
for batch in data_loader:
|
76 |
+
input_ids = batch["input_ids"].to(device)
|
77 |
+
attention_mask = batch["attention_mask"].to(device)
|
78 |
+
labels = batch["labels"].to(device)
|
79 |
+
loss = trainer(input_ids, attention_mask, labels).mean()
|
80 |
+
optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-4)
|
81 |
+
optimizer.zero_grad()
|
82 |
+
loss.backward()
|
83 |
+
optimizer.step()
|
84 |
+
total_loss += loss.item()
|
85 |
+
print(f"Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}")
|
86 |
+
|
87 |
+
# Evaluate the model using the custom trainer class
|
88 |
+
trainer.model.eval()
|
89 |
+
test_loss = 0
|
90 |
+
correct = 0
|
91 |
+
with torch.no_grad():
|
92 |
+
for batch in data_loader:
|
93 |
+
input_ids = batch["input_ids"].to(device)
|
94 |
+
attention_mask = batch["attention_mask"].to(device)
|
95 |
+
labels = batch["labels"].to(device)
|
96 |
+
output = trainer(input_ids, attention_mask, labels).mean()
|
97 |
+
loss = self.loss_fn(output, labels)
|
98 |
+
test_loss += loss.item()
|
99 |
+
_, predicted = torch.max(output, 1)
|
100 |
+
correct += (predicted == labels).sum().item()
|
101 |
+
|
102 |
+
accuracy = correct / len(data_loader.dataset)
|
103 |
+
print(f"Test Loss: {test_loss / len(data_loader)}, Accuracy: {accuracy:.2f}")
|