qninhdt committed on
Commit
0c2ae95
·
1 Parent(s): d1ea2f2
src/data/mixed_datamodule.py CHANGED
@@ -6,7 +6,7 @@ from .mixed_dataset import MixedDataset
6
 
7
  class MixedDataModule(LightningDataModule):
8
  def __init__(
9
- self, bert_model, dataset_path, tool_capacity, batch_size, num_workers
10
  ):
11
  super().__init__()
12
  self.bert_model = bert_model
@@ -14,6 +14,7 @@ class MixedDataModule(LightningDataModule):
14
  self.tool_capacity = tool_capacity
15
  self.batch_size = batch_size
16
  self.num_workers = num_workers
 
17
 
18
  def setup(self, stage=None):
19
  if stage == "fit":
@@ -22,12 +23,14 @@ class MixedDataModule(LightningDataModule):
22
  "train",
23
  os.path.join(self.dataset_path, "train.json"),
24
  self.tool_capacity,
 
25
  )
26
  self.val_dataset = MixedDataset(
27
  self.bert_model,
28
  "test",
29
  os.path.join(self.dataset_path, "test.json"),
30
  self.tool_capacity,
 
31
  )
32
  elif stage == "test":
33
  self.test_dataset = MixedDataset(
@@ -35,6 +38,7 @@ class MixedDataModule(LightningDataModule):
35
  "test",
36
  os.path.join(self.dataset_path, "test.json"),
37
  self.tool_capacity,
 
38
  )
39
 
40
  def train_dataloader(self):
 
6
 
7
  class MixedDataModule(LightningDataModule):
8
  def __init__(
9
+ self, bert_model, dataset_path, tool_capacity, batch_size, num_workers, seed
10
  ):
11
  super().__init__()
12
  self.bert_model = bert_model
 
14
  self.tool_capacity = tool_capacity
15
  self.batch_size = batch_size
16
  self.num_workers = num_workers
17
+ self.seed = seed
18
 
19
  def setup(self, stage=None):
20
  if stage == "fit":
 
23
  "train",
24
  os.path.join(self.dataset_path, "train.json"),
25
  self.tool_capacity,
26
+ seed=self.seed,
27
  )
28
  self.val_dataset = MixedDataset(
29
  self.bert_model,
30
  "test",
31
  os.path.join(self.dataset_path, "test.json"),
32
  self.tool_capacity,
33
+ seed=self.seed,
34
  )
35
  elif stage == "test":
36
  self.test_dataset = MixedDataset(
 
38
  "test",
39
  os.path.join(self.dataset_path, "test.json"),
40
  self.tool_capacity,
41
+ seed=self.seed,
42
  )
43
 
44
  def train_dataloader(self):
src/data/mixed_dataset.py CHANGED
@@ -26,7 +26,7 @@ class MixedDataset(Dataset):
26
  return tools, samples
27
 
28
  def encode_text(self, text):
29
- inputs = self.tokenizer.encode_plus(
30
  text,
31
  max_length=128,
32
  padding="max_length",
@@ -53,8 +53,8 @@ class MixedDataset(Dataset):
53
  return {
54
  "inst_ids": inst_ids,
55
  "inst_mask": inst_mask,
56
- "tool_desc_ids": tool_desc_ids,
57
- "tool_desc_mask": tool_desc_mask,
58
  }
59
  else:
60
  # for testing, we sample a random set of tools + the correct tool, size = tool_capacity
@@ -74,7 +74,7 @@ class MixedDataset(Dataset):
74
  )
75
 
76
  tools = correct_tools + wrong_tools
77
- tool_ids, tool_ids_mask = self.encode_text(
78
  [self.tools[tool_id]["description"] for tool_id in tools]
79
  )
80
 
@@ -82,6 +82,6 @@ class MixedDataset(Dataset):
82
  "inst_ids": inst_ids,
83
  "inst_mask": inst_mask,
84
  "tool_ids": tool_ids,
85
- "tool_ids_mask": tool_ids_mask,
86
  "correct_tool_mask": correct_tool_mask,
87
  }
 
26
  return tools, samples
27
 
28
  def encode_text(self, text):
29
+ inputs = self.tokenizer(
30
  text,
31
  max_length=128,
32
  padding="max_length",
 
53
  return {
54
  "inst_ids": inst_ids,
55
  "inst_mask": inst_mask,
56
+ "tool_ids": tool_desc_ids,
57
+ "tool_mask": tool_desc_mask,
58
  }
59
  else:
60
  # for testing, we sample a random set of tools + the correct tool, size = tool_capacity
 
74
  )
75
 
76
  tools = correct_tools + wrong_tools
77
+ tool_ids, tool_mask = self.encode_text(
78
  [self.tools[tool_id]["description"] for tool_id in tools]
79
  )
80
 
 
82
  "inst_ids": inst_ids,
83
  "inst_mask": inst_mask,
84
  "tool_ids": tool_ids,
85
+ "tool_mask": tool_mask,
86
  "correct_tool_mask": correct_tool_mask,
87
  }
src/models/miniagent_module.py CHANGED
@@ -33,7 +33,7 @@ class MiniAgentModule(LightningModule):
33
  self.tool_proj_model = tool_proj_model
34
  self.pred_model = pred_model
35
 
36
- self.val_acc = Accuracy()
37
 
38
  self.lr = lr
39
 
@@ -50,8 +50,8 @@ class MiniAgentModule(LightningModule):
50
 
51
  inst_ids = batch["inst_ids"]
52
  inst_mask = batch["inst_mask"]
53
- tool_ids = batch["tool_desc_ids"]
54
- tool_mask = batch["tool_desc_mask"]
55
 
56
  inst_z = self.bert_model(inst_ids, inst_mask, return_dict=False)[1]
57
  tool_z = self.bert_model(tool_ids, tool_mask, return_dict=False)[1]
@@ -82,8 +82,8 @@ class MiniAgentModule(LightningModule):
82
  ) -> None:
83
  inst_ids = batch["inst_ids"]
84
  inst_mask = batch["inst_mask"]
85
- tool_ids = batch["tool_desc_ids"]
86
- tool_mask = batch["tool_desc_mask"]
87
  correct_tool_mask = batch["correct_tool_mask"]
88
 
89
  B = inst_ids.shape[0] # batch size
 
33
  self.tool_proj_model = tool_proj_model
34
  self.pred_model = pred_model
35
 
36
+ self.val_acc = Accuracy(task="binary")
37
 
38
  self.lr = lr
39
 
 
50
 
51
  inst_ids = batch["inst_ids"]
52
  inst_mask = batch["inst_mask"]
53
+ tool_ids = batch["tool_ids"]
54
+ tool_mask = batch["tool_mask"]
55
 
56
  inst_z = self.bert_model(inst_ids, inst_mask, return_dict=False)[1]
57
  tool_z = self.bert_model(tool_ids, tool_mask, return_dict=False)[1]
 
82
  ) -> None:
83
  inst_ids = batch["inst_ids"]
84
  inst_mask = batch["inst_mask"]
85
+ tool_ids = batch["tool_ids"]
86
+ tool_mask = batch["tool_mask"]
87
  correct_tool_mask = batch["correct_tool_mask"]
88
 
89
  B = inst_ids.shape[0] # batch size