File size: 538 Bytes
235b048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import argparse
from collections.abc import Mapping

import torch

from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification

EGTConfig.register_for_auto_class()
EGTModel.register_for_auto_class("AutoModel")
EGTForGraphClassification.register_for_auto_class("AutoModelForGraphClassification")

egt_config = EGTConfig()
egt = EGTForGraphClassification(egt_config)

pretrained_model = torch.load("/home/ubuntu/transformers/egt_model_state")
egt.model.load_state_dict(pretrained_model.state_dict())

# egt.push_to_hub("Zhiteng/egt")