Metin commited on
Commit
0bee7fb
·
1 Parent(s): 928a132

Load graph data and model state dict on CPU for compatibility

Browse files
Files changed (1) hide show
  1. src/gnn.py +7 -7
src/gnn.py CHANGED
@@ -48,13 +48,13 @@ class GNNClassifier(torch.nn.Module):
48
  def load_data(version: str = "undirected"):
49
 
50
  if version == "undirected":
51
- graph_data = torch.load(config.GNN_GRAPH_DATA_PATH)
52
- title_to_id = torch.load(config.TITLE_TO_ID_PATH)
53
- label_mapping = torch.load(config.LABEL_MAPPING_PATH)
54
  elif version == "no_edge":
55
- graph_data = torch.load(config.GNN_GRAPH_DATA_PATH.replace("undirected_gnn", "no_edge_gnn"))
56
- title_to_id = torch.load(config.TITLE_TO_ID_PATH.replace("undirected_gnn", "no_edge_gnn"))
57
- label_mapping = torch.load(config.LABEL_MAPPING_PATH.replace("undirected_gnn", "no_edge_gnn"))
58
  else:
59
  raise ValueError(f"Unknown version: {version}")
60
 
@@ -141,7 +141,7 @@ if __name__ == "__main__":
141
  graph_data, title_to_id, label_mapping = load_data()
142
 
143
  model = GNNClassifier(input_dim=768, hidden_dim=128, layers=2, output_dim=len(label_mapping), dropout_rate=0.5)
144
- model.load_state_dict(torch.load(r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\gnn\gnn_classifier_model.pth"))
145
 
146
  new_node_content = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
147
  embedder = Embedder(path=r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\embedding\gte-multilingual-base")
 
48
  def load_data(version: str = "undirected"):
49
 
50
  if version == "undirected":
51
+ graph_data = torch.load(config.GNN_GRAPH_DATA_PATH, map_location=torch.device("cpu"))
52
+ title_to_id = torch.load(config.TITLE_TO_ID_PATH, map_location=torch.device("cpu"))
53
+ label_mapping = torch.load(config.LABEL_MAPPING_PATH, map_location=torch.device("cpu"))
54
  elif version == "no_edge":
55
+ graph_data = torch.load(config.GNN_GRAPH_DATA_PATH.replace("undirected_gnn", "no_edge_gnn"), map_location=torch.device("cpu"))
56
+ title_to_id = torch.load(config.TITLE_TO_ID_PATH.replace("undirected_gnn", "no_edge_gnn"), map_location=torch.device("cpu"))
57
+ label_mapping = torch.load(config.LABEL_MAPPING_PATH.replace("undirected_gnn", "no_edge_gnn"), map_location=torch.device("cpu"))
58
  else:
59
  raise ValueError(f"Unknown version: {version}")
60
 
 
141
  graph_data, title_to_id, label_mapping = load_data()
142
 
143
  model = GNNClassifier(input_dim=768, hidden_dim=128, layers=2, output_dim=len(label_mapping), dropout_rate=0.5)
144
+ model.load_state_dict(torch.load(r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\gnn\gnn_classifier_model.pth"), map_location=torch.device("cpu"))
145
 
146
  new_node_content = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
147
  embedder = Embedder(path=r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\embedding\gte-multilingual-base")