cc
Browse files
src/data/mixed_datamodule.py
CHANGED
|
@@ -6,7 +6,7 @@ from .mixed_dataset import MixedDataset
|
|
| 6 |
|
| 7 |
class MixedDataModule(LightningDataModule):
|
| 8 |
def __init__(
|
| 9 |
-
self, bert_model, dataset_path, tool_capacity, batch_size, num_workers
|
| 10 |
):
|
| 11 |
super().__init__()
|
| 12 |
self.bert_model = bert_model
|
|
@@ -14,6 +14,7 @@ class MixedDataModule(LightningDataModule):
|
|
| 14 |
self.tool_capacity = tool_capacity
|
| 15 |
self.batch_size = batch_size
|
| 16 |
self.num_workers = num_workers
|
|
|
|
| 17 |
|
| 18 |
def setup(self, stage=None):
|
| 19 |
if stage == "fit":
|
|
@@ -22,12 +23,14 @@ class MixedDataModule(LightningDataModule):
|
|
| 22 |
"train",
|
| 23 |
os.path.join(self.dataset_path, "train.json"),
|
| 24 |
self.tool_capacity,
|
|
|
|
| 25 |
)
|
| 26 |
self.val_dataset = MixedDataset(
|
| 27 |
self.bert_model,
|
| 28 |
"test",
|
| 29 |
os.path.join(self.dataset_path, "test.json"),
|
| 30 |
self.tool_capacity,
|
|
|
|
| 31 |
)
|
| 32 |
elif stage == "test":
|
| 33 |
self.test_dataset = MixedDataset(
|
|
@@ -35,6 +38,7 @@ class MixedDataModule(LightningDataModule):
|
|
| 35 |
"test",
|
| 36 |
os.path.join(self.dataset_path, "test.json"),
|
| 37 |
self.tool_capacity,
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
def train_dataloader(self):
|
|
|
|
| 6 |
|
| 7 |
class MixedDataModule(LightningDataModule):
|
| 8 |
def __init__(
|
| 9 |
+
self, bert_model, dataset_path, tool_capacity, batch_size, num_workers, seed
|
| 10 |
):
|
| 11 |
super().__init__()
|
| 12 |
self.bert_model = bert_model
|
|
|
|
| 14 |
self.tool_capacity = tool_capacity
|
| 15 |
self.batch_size = batch_size
|
| 16 |
self.num_workers = num_workers
|
| 17 |
+
self.seed = seed
|
| 18 |
|
| 19 |
def setup(self, stage=None):
|
| 20 |
if stage == "fit":
|
|
|
|
| 23 |
"train",
|
| 24 |
os.path.join(self.dataset_path, "train.json"),
|
| 25 |
self.tool_capacity,
|
| 26 |
+
seed=self.seed,
|
| 27 |
)
|
| 28 |
self.val_dataset = MixedDataset(
|
| 29 |
self.bert_model,
|
| 30 |
"test",
|
| 31 |
os.path.join(self.dataset_path, "test.json"),
|
| 32 |
self.tool_capacity,
|
| 33 |
+
seed=self.seed,
|
| 34 |
)
|
| 35 |
elif stage == "test":
|
| 36 |
self.test_dataset = MixedDataset(
|
|
|
|
| 38 |
"test",
|
| 39 |
os.path.join(self.dataset_path, "test.json"),
|
| 40 |
self.tool_capacity,
|
| 41 |
+
seed=self.seed,
|
| 42 |
)
|
| 43 |
|
| 44 |
def train_dataloader(self):
|
src/data/mixed_dataset.py
CHANGED
|
@@ -26,7 +26,7 @@ class MixedDataset(Dataset):
|
|
| 26 |
return tools, samples
|
| 27 |
|
| 28 |
def encode_text(self, text):
|
| 29 |
-
inputs = self.tokenizer
|
| 30 |
text,
|
| 31 |
max_length=128,
|
| 32 |
padding="max_length",
|
|
@@ -53,8 +53,8 @@ class MixedDataset(Dataset):
|
|
| 53 |
return {
|
| 54 |
"inst_ids": inst_ids,
|
| 55 |
"inst_mask": inst_mask,
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
}
|
| 59 |
else:
|
| 60 |
# for testing, we sample a random set of tools + the correct tool, size = tool_capacity
|
|
@@ -74,7 +74,7 @@ class MixedDataset(Dataset):
|
|
| 74 |
)
|
| 75 |
|
| 76 |
tools = correct_tools + wrong_tools
|
| 77 |
-
tool_ids,
|
| 78 |
[self.tools[tool_id]["description"] for tool_id in tools]
|
| 79 |
)
|
| 80 |
|
|
@@ -82,6 +82,6 @@ class MixedDataset(Dataset):
|
|
| 82 |
"inst_ids": inst_ids,
|
| 83 |
"inst_mask": inst_mask,
|
| 84 |
"tool_ids": tool_ids,
|
| 85 |
-
"
|
| 86 |
"correct_tool_mask": correct_tool_mask,
|
| 87 |
}
|
|
|
|
| 26 |
return tools, samples
|
| 27 |
|
| 28 |
def encode_text(self, text):
|
| 29 |
+
inputs = self.tokenizer(
|
| 30 |
text,
|
| 31 |
max_length=128,
|
| 32 |
padding="max_length",
|
|
|
|
| 53 |
return {
|
| 54 |
"inst_ids": inst_ids,
|
| 55 |
"inst_mask": inst_mask,
|
| 56 |
+
"tool_ids": tool_desc_ids,
|
| 57 |
+
"tool_mask": tool_desc_mask,
|
| 58 |
}
|
| 59 |
else:
|
| 60 |
# for testing, we sample a random set of tools + the correct tool, size = tool_capacity
|
|
|
|
| 74 |
)
|
| 75 |
|
| 76 |
tools = correct_tools + wrong_tools
|
| 77 |
+
tool_ids, tool_mask = self.encode_text(
|
| 78 |
[self.tools[tool_id]["description"] for tool_id in tools]
|
| 79 |
)
|
| 80 |
|
|
|
|
| 82 |
"inst_ids": inst_ids,
|
| 83 |
"inst_mask": inst_mask,
|
| 84 |
"tool_ids": tool_ids,
|
| 85 |
+
"tool_mask": tool_mask,
|
| 86 |
"correct_tool_mask": correct_tool_mask,
|
| 87 |
}
|
src/models/miniagent_module.py
CHANGED
|
@@ -33,7 +33,7 @@ class MiniAgentModule(LightningModule):
|
|
| 33 |
self.tool_proj_model = tool_proj_model
|
| 34 |
self.pred_model = pred_model
|
| 35 |
|
| 36 |
-
self.val_acc = Accuracy()
|
| 37 |
|
| 38 |
self.lr = lr
|
| 39 |
|
|
@@ -50,8 +50,8 @@ class MiniAgentModule(LightningModule):
|
|
| 50 |
|
| 51 |
inst_ids = batch["inst_ids"]
|
| 52 |
inst_mask = batch["inst_mask"]
|
| 53 |
-
tool_ids = batch["
|
| 54 |
-
tool_mask = batch["
|
| 55 |
|
| 56 |
inst_z = self.bert_model(inst_ids, inst_mask, return_dict=False)[1]
|
| 57 |
tool_z = self.bert_model(tool_ids, tool_mask, return_dict=False)[1]
|
|
@@ -82,8 +82,8 @@ class MiniAgentModule(LightningModule):
|
|
| 82 |
) -> None:
|
| 83 |
inst_ids = batch["inst_ids"]
|
| 84 |
inst_mask = batch["inst_mask"]
|
| 85 |
-
tool_ids = batch["
|
| 86 |
-
tool_mask = batch["
|
| 87 |
correct_tool_mask = batch["correct_tool_mask"]
|
| 88 |
|
| 89 |
B = inst_ids.shape[0] # batch size
|
|
|
|
| 33 |
self.tool_proj_model = tool_proj_model
|
| 34 |
self.pred_model = pred_model
|
| 35 |
|
| 36 |
+
self.val_acc = Accuracy(task="binary")
|
| 37 |
|
| 38 |
self.lr = lr
|
| 39 |
|
|
|
|
| 50 |
|
| 51 |
inst_ids = batch["inst_ids"]
|
| 52 |
inst_mask = batch["inst_mask"]
|
| 53 |
+
tool_ids = batch["tool_ids"]
|
| 54 |
+
tool_mask = batch["tool_mask"]
|
| 55 |
|
| 56 |
inst_z = self.bert_model(inst_ids, inst_mask, return_dict=False)[1]
|
| 57 |
tool_z = self.bert_model(tool_ids, tool_mask, return_dict=False)[1]
|
|
|
|
| 82 |
) -> None:
|
| 83 |
inst_ids = batch["inst_ids"]
|
| 84 |
inst_mask = batch["inst_mask"]
|
| 85 |
+
tool_ids = batch["tool_ids"]
|
| 86 |
+
tool_mask = batch["tool_mask"]
|
| 87 |
correct_tool_mask = batch["correct_tool_mask"]
|
| 88 |
|
| 89 |
B = inst_ids.shape[0] # batch size
|