alfiannajih commited on
Commit
db7044b
1 Parent(s): 1ddc9ba

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +3 -3
g_retriever.py CHANGED
@@ -24,13 +24,13 @@ class GRetrieverModel(LlamaForCausalLM):
24
  num_layers=config.gnn_num_layers,
25
  dropout=config.gnn_dropout,
26
  num_heads=config.gnn_num_heads,
27
- ).to(self.model.dtype).to(self.model.device)
28
 
29
  self.projector = nn.Sequential(
30
  nn.Linear(config.gnn_hidden_dim, 2048),
31
  nn.Sigmoid(),
32
  nn.Linear(2048, self.get_input_embeddings().embedding_dim),
33
- ).to(self.model.dtype).to(self.model.device)
34
 
35
  def encode_graphs(self, graph):
36
  n_embeds, _ = self.graph_encoder(
@@ -42,7 +42,7 @@ class GRetrieverModel(LlamaForCausalLM):
42
  # mean pooling
43
  g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
- return g_embeds
46
 
47
  @wraps(LlamaForCausalLM.forward)
48
  def forward(
 
24
  num_layers=config.gnn_num_layers,
25
  dropout=config.gnn_dropout,
26
  num_heads=config.gnn_num_heads,
27
+ ).to(self.model.dtype)
28
 
29
  self.projector = nn.Sequential(
30
  nn.Linear(config.gnn_hidden_dim, 2048),
31
  nn.Sigmoid(),
32
  nn.Linear(2048, self.get_input_embeddings().embedding_dim),
33
+ ).to(self.model.dtype)
34
 
35
  def encode_graphs(self, graph):
36
  n_embeds, _ = self.graph_encoder(
 
42
  # mean pooling
43
  g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
+ return g_embeds.to(model.device)
46
 
47
  @wraps(LlamaForCausalLM.forward)
48
  def forward(