Spaces:
Sleeping
Sleeping
Load graph data and model state dict on CPU for compatibility
Browse files- src/gnn.py +7 -7
src/gnn.py
CHANGED
|
@@ -48,13 +48,13 @@ class GNNClassifier(torch.nn.Module):
|
|
| 48 |
def load_data(version: str = "undirected"):
|
| 49 |
|
| 50 |
if version == "undirected":
|
| 51 |
-
graph_data = torch.load(config.GNN_GRAPH_DATA_PATH)
|
| 52 |
-
title_to_id = torch.load(config.TITLE_TO_ID_PATH)
|
| 53 |
-
label_mapping = torch.load(config.LABEL_MAPPING_PATH)
|
| 54 |
elif version == "no_edge":
|
| 55 |
-
graph_data = torch.load(config.GNN_GRAPH_DATA_PATH.replace("undirected_gnn", "no_edge_gnn"))
|
| 56 |
-
title_to_id = torch.load(config.TITLE_TO_ID_PATH.replace("undirected_gnn", "no_edge_gnn"))
|
| 57 |
-
label_mapping = torch.load(config.LABEL_MAPPING_PATH.replace("undirected_gnn", "no_edge_gnn"))
|
| 58 |
else:
|
| 59 |
raise ValueError(f"Unknown version: {version}")
|
| 60 |
|
|
@@ -141,7 +141,7 @@ if __name__ == "__main__":
|
|
| 141 |
graph_data, title_to_id, label_mapping = load_data()
|
| 142 |
|
| 143 |
model = GNNClassifier(input_dim=768, hidden_dim=128, layers=2, output_dim=len(label_mapping), dropout_rate=0.5)
|
| 144 |
-
model.load_state_dict(torch.load(r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\gnn\gnn_classifier_model.pth"))
|
| 145 |
|
| 146 |
new_node_content = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
|
| 147 |
embedder = Embedder(path=r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\embedding\gte-multilingual-base")
|
|
|
|
| 48 |
def load_data(version: str = "undirected"):
|
| 49 |
|
| 50 |
if version == "undirected":
|
| 51 |
+
graph_data = torch.load(config.GNN_GRAPH_DATA_PATH, map_location=torch.device("cpu"))
|
| 52 |
+
title_to_id = torch.load(config.TITLE_TO_ID_PATH, map_location=torch.device("cpu"))
|
| 53 |
+
label_mapping = torch.load(config.LABEL_MAPPING_PATH, map_location=torch.device("cpu"))
|
| 54 |
elif version == "no_edge":
|
| 55 |
+
graph_data = torch.load(config.GNN_GRAPH_DATA_PATH.replace("undirected_gnn", "no_edge_gnn"), map_location=torch.device("cpu"))
|
| 56 |
+
title_to_id = torch.load(config.TITLE_TO_ID_PATH.replace("undirected_gnn", "no_edge_gnn"), map_location=torch.device("cpu"))
|
| 57 |
+
label_mapping = torch.load(config.LABEL_MAPPING_PATH.replace("undirected_gnn", "no_edge_gnn"), map_location=torch.device("cpu"))
|
| 58 |
else:
|
| 59 |
raise ValueError(f"Unknown version: {version}")
|
| 60 |
|
|
|
|
| 141 |
graph_data, title_to_id, label_mapping = load_data()
|
| 142 |
|
| 143 |
model = GNNClassifier(input_dim=768, hidden_dim=128, layers=2, output_dim=len(label_mapping), dropout_rate=0.5)
|
| 144 |
+
model.load_state_dict(torch.load(r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\gnn\gnn_classifier_model.pth"), map_location=torch.device("cpu"))
|
| 145 |
|
| 146 |
new_node_content = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
|
| 147 |
embedder = Embedder(path=r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\embedding\gte-multilingual-base")
|