alfiannajih
commited on
Commit
•
db7044b
1
Parent(s):
1ddc9ba
Update g_retriever.py
Browse files- g_retriever.py +3 -3
g_retriever.py
CHANGED
@@ -24,13 +24,13 @@ class GRetrieverModel(LlamaForCausalLM):
|
|
24 |
num_layers=config.gnn_num_layers,
|
25 |
dropout=config.gnn_dropout,
|
26 |
num_heads=config.gnn_num_heads,
|
27 |
-
).to(self.model.dtype)
|
28 |
|
29 |
self.projector = nn.Sequential(
|
30 |
nn.Linear(config.gnn_hidden_dim, 2048),
|
31 |
nn.Sigmoid(),
|
32 |
nn.Linear(2048, self.get_input_embeddings().embedding_dim),
|
33 |
-
).to(self.model.dtype)
|
34 |
|
35 |
def encode_graphs(self, graph):
|
36 |
n_embeds, _ = self.graph_encoder(
|
@@ -42,7 +42,7 @@ class GRetrieverModel(LlamaForCausalLM):
|
|
42 |
# mean pooling
|
43 |
g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
|
44 |
|
45 |
-
return g_embeds
|
46 |
|
47 |
@wraps(LlamaForCausalLM.forward)
|
48 |
def forward(
|
|
|
24 |
num_layers=config.gnn_num_layers,
|
25 |
dropout=config.gnn_dropout,
|
26 |
num_heads=config.gnn_num_heads,
|
27 |
+
).to(self.model.dtype)
|
28 |
|
29 |
self.projector = nn.Sequential(
|
30 |
nn.Linear(config.gnn_hidden_dim, 2048),
|
31 |
nn.Sigmoid(),
|
32 |
nn.Linear(2048, self.get_input_embeddings().embedding_dim),
|
33 |
+
).to(self.model.dtype)
|
34 |
|
35 |
def encode_graphs(self, graph):
|
36 |
n_embeds, _ = self.graph_encoder(
|
|
|
42 |
# mean pooling
|
43 |
g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
|
44 |
|
45 |
+
return g_embeds.to(model.device)
|
46 |
|
47 |
@wraps(LlamaForCausalLM.forward)
|
48 |
def forward(
|