import os
import csv
import shutil

import torch
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR

from ultra.tasks import build_relation_graph


class GrailInductiveDataset(InMemoryDataset):

    def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, merge_valid_test=True):
        self.version = version
        assert version in ["v1", "v2", "v3", "v4"]

        # with merge_valid_test=True (default), the inductive valid and test triples are merged into
        # a single test split; validation is then performed on the valid triples of the training graph
        self.merge_valid_test = merge_valid_test
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def num_relations(self):
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        return os.path.join(self.root, "grail", self.name, self.version, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, "grail", self.name, self.version, "processed")

    @property
    def processed_file_names(self):
        return "data.pt"

    @property
    def raw_file_names(self):
        return [
            "train_ind.txt", "valid_ind.txt", "test_ind.txt", "train.txt", "valid.txt"
        ]

    def download(self):
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url % self.version, self.raw_dir)
            os.rename(download_path, path)

    def process(self):
        test_files = self.raw_paths[:3]
        train_files = self.raw_paths[3:]

        inv_train_entity_vocab = {}
        inv_test_entity_vocab = {}
        inv_relation_vocab = {}
        triplets = []
        num_samples = []

        for txt_file in train_files:
            with open(txt_file, "r") as fin:
                num_sample = 0
                for line in fin:
                    h_token, r_token, t_token = line.strip().split("\t")
                    if h_token not in inv_train_entity_vocab:
                        inv_train_entity_vocab[h_token] = len(inv_train_entity_vocab)
                    h = inv_train_entity_vocab[h_token]
                    if r_token not in inv_relation_vocab:
                        inv_relation_vocab[r_token] = len(inv_relation_vocab)
                    r = inv_relation_vocab[r_token]
                    if t_token not in inv_train_entity_vocab:
                        inv_train_entity_vocab[t_token] = len(inv_train_entity_vocab)
                    t = inv_train_entity_vocab[t_token]
                    triplets.append((h, t, r))
                    num_sample += 1
                num_samples.append(num_sample)

        for txt_file in test_files:
            with open(txt_file, "r") as fin:
                num_sample = 0
                for line in fin:
                    h_token, r_token, t_token = line.strip().split("\t")
                    if h_token not in inv_test_entity_vocab:
                        inv_test_entity_vocab[h_token] = len(inv_test_entity_vocab)
                    h = inv_test_entity_vocab[h_token]
                    # relations in the inductive (test) graph must already appear in the training graph
                    assert r_token in inv_relation_vocab
                    r = inv_relation_vocab[r_token]
                    if t_token not in inv_test_entity_vocab:
                        inv_test_entity_vocab[t_token] = len(inv_test_entity_vocab)
                    t = inv_test_entity_vocab[t_token]
                    triplets.append((h, t, r))
                    num_sample += 1
                num_samples.append(num_sample)

        triplets = torch.tensor(triplets)

        edge_index = triplets[:, :2].t()
        edge_type = triplets[:, 2]
        num_relations = int(edge_type.max()) + 1

        # num_samples = [train, valid, ind_train, ind_valid, ind_test];
        # the training fact graph is the transductive train split,
        # the inference fact graph is the inductive train split
        train_fact_slice = slice(None, sum(num_samples[:1]))
        test_fact_slice = slice(sum(num_samples[:2]), sum(num_samples[:3]))
        train_fact_index = edge_index[:, train_fact_slice]
        train_fact_type = edge_type[train_fact_slice]
        test_fact_index = edge_index[:, test_fact_slice]
        test_fact_type = edge_type[test_fact_slice]

        # add inverse edges so that every relation has an inverse counterpart
        train_fact_index = torch.cat([train_fact_index, train_fact_index.flip(0)], dim=-1)
        train_fact_type = torch.cat([train_fact_type, train_fact_type + num_relations])
        test_fact_index = torch.cat([test_fact_index, test_fact_index.flip(0)], dim=-1)
        test_fact_type = torch.cat([test_fact_type, test_fact_type + num_relations])

        train_slice = slice(None, sum(num_samples[:1]))
        valid_slice = slice(sum(num_samples[:1]), sum(num_samples[:2]))
        # with merge_valid_test=True, the inductive valid and test edges together form the test split;
        # otherwise only the inductive test edges are used
        test_slice = (slice(sum(num_samples[:3]), sum(num_samples)) if self.merge_valid_test
                      else slice(sum(num_samples[:4]), sum(num_samples)))

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
                          target_edge_index=edge_index[:, train_slice], target_edge_type=edge_type[train_slice], num_relations=num_relations*2)
        valid_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
                          target_edge_index=edge_index[:, valid_slice], target_edge_type=edge_type[valid_slice], num_relations=num_relations*2)
        test_data = Data(edge_index=test_fact_index, edge_type=test_fact_type, num_nodes=len(inv_test_entity_vocab),
                         target_edge_index=edge_index[:, test_slice], target_edge_type=edge_type[test_slice], num_relations=num_relations*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    def __repr__(self):
        return "%s(%s)" % (self.name, self.version)

class FB15k237Inductive(GrailInductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/valid.txt"
    ]

    name = "IndFB15k237"

    def __init__(self, root, version):
        super().__init__(root, version)

class WN18RRInductive(GrailInductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/valid.txt"
    ]

    name = "IndWN18RR"

    def __init__(self, root, version):
        super().__init__(root, version)

class NELLInductive(GrailInductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/valid.txt"
    ]

    name = "IndNELL"

    def __init__(self, root, version):
        super().__init__(root, version)

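# Usage sketch (illustrative, not part of the original module): instantiating one GraIL split
# downloads the five raw text files on first use and caches the processed tensors under
# <root>/grail/<name>/<version>/processed/. The root path below is an example.
#
#   dataset = FB15k237Inductive(root="datasets", version="v1")
#   train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
#   print(dataset.num_relations, train_data.num_nodes, test_data.num_nodes)
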
def FB15k237(root):
    dataset = RelLinkPredDataset(name="FB15k-237", root=root + "/fb15k237/")
    data = dataset.data
    train_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
                      target_edge_index=data.train_edge_index, target_edge_type=data.train_edge_type,
                      num_relations=dataset.num_relations)
    valid_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
                      target_edge_index=data.valid_edge_index, target_edge_type=data.valid_edge_type,
                      num_relations=dataset.num_relations)
    test_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
                     target_edge_index=data.test_edge_index, target_edge_type=data.test_edge_type,
                     num_relations=dataset.num_relations)

    # build the relation graph for each split
    train_data = build_relation_graph(train_data)
    valid_data = build_relation_graph(valid_data)
    test_data = build_relation_graph(test_data)

    dataset.data, dataset.slices = dataset.collate([train_data, valid_data, test_data])
    return dataset

def WN18RR(root):
    dataset = WordNet18RR(root=root + "/wn18rr/")

    # bring WN18RR into the same format as the FB15k-237 wrapper above:
    # the fact graph is the training split plus inverse edges
    data = dataset.data
    num_nodes = int(data.edge_index.max()) + 1
    num_relations = int(data.edge_type.max()) + 1
    edge_index = data.edge_index[:, data.train_mask]
    edge_type = data.edge_type[data.train_mask]
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
    edge_type = torch.cat([edge_type, edge_type + num_relations])
    train_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
                      target_edge_index=data.edge_index[:, data.train_mask],
                      target_edge_type=data.edge_type[data.train_mask],
                      num_relations=num_relations*2)
    valid_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
                      target_edge_index=data.edge_index[:, data.val_mask],
                      target_edge_type=data.edge_type[data.val_mask],
                      num_relations=num_relations*2)
    test_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
                     target_edge_index=data.edge_index[:, data.test_mask],
                     target_edge_type=data.edge_type[data.test_mask],
                     num_relations=num_relations*2)

    # build the relation graph for each split
    train_data = build_relation_graph(train_data)
    valid_data = build_relation_graph(valid_data)
    test_data = build_relation_graph(test_data)

    dataset.data, dataset.slices = dataset.collate([train_data, valid_data, test_data])
    dataset.num_relations = num_relations * 2
    return dataset

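# Usage sketch (illustrative; the root path is an example): the two wrappers above return a
# PyG dataset whose three items are the train/valid/test splits, each carrying target edges
# and the relation graph produced by build_relation_graph.
#
#   dataset = FB15k237(root="datasets")
#   train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
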
class TransductiveDataset(InMemoryDataset):

    delimiter = None

    def __init__(self, root, transform=None, pre_transform=build_relation_graph, **kwargs):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["train.txt", "valid.txt", "test.txt"]

    def download(self):
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url, self.raw_dir)
            os.rename(download_path, path)

    def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                # each line is a (head, relation, tail) triple; ids are assigned on first occurrence
                u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab),
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    def process(self):

        train_files = self.raw_paths[:3]

        train_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        valid_results = self.load_file(train_files[1],
                                       train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
        test_results = self.load_file(train_files[2],
                                      train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])

        # the vocab dicts are shared and mutated in place across the three calls,
        # so the counts after loading the test split cover all entities and relations
        num_node = test_results["num_node"]
        num_relations = test_results["num_relation"]

        train_triplets = train_results["triplets"]
        valid_triplets = valid_results["triplets"]
        test_triplets = test_results["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_triplets])

        valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
        valid_etypes = torch.tensor([t[2] for t in valid_triplets])

        test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
        test_etypes = torch.tensor([t[2] for t in test_triplets])

        # the fact graph consists of the training edges plus their inverses
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat([train_target_etypes, train_target_etypes + num_relations])

        train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
        valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
        test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                         target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    def __repr__(self):
        return "%s()" % (self.name)

    @property
    def num_relations(self):
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.name, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.name, "processed")

    @property
    def processed_file_names(self):
        return "data.pt"

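# Worked example for load_file (hypothetical two-line, whitespace-separated input):
#
#   A likes B
#   B likes C
#
# would yield triplets [(0, 1, 0), (1, 2, 0)] (stored as (head, tail, relation) ids),
# inv_entity_vocab == {"A": 0, "B": 1, "C": 2}, inv_rel_vocab == {"likes": 0},
# num_node == 3 and num_relation == 1. Because the vocab dicts are passed through and
# mutated in place, the train/valid/test splits of one dataset share a single id space.
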
class CoDEx(TransductiveDataset):

    name = "codex"
    urls = [
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/train.txt",
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/valid.txt",
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/test.txt",
    ]

    def download(self):
        # the %s placeholder in each url is filled with the concrete dataset name (codex-s / codex-m / codex-l)
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url % self.name, self.raw_dir)
            os.rename(download_path, path)

class CoDExSmall(CoDEx):
    """
    #node: 2034
    #edge: 36543
    #relation: 42
    """
    url = "https://zenodo.org/record/4281094/files/codex-s.tar.gz"
    md5 = "63cd8186fc2aeddc154e20cf4a10087e"
    name = "codex-s"

    def __init__(self, root):
        super(CoDExSmall, self).__init__(root=root, size='s')

class CoDExMedium(CoDEx):
    """
    #node: 17050
    #edge: 206205
    #relation: 51
    """
    url = "https://zenodo.org/record/4281094/files/codex-m.tar.gz"
    md5 = "43e561cfdca1c6ad9cc2f5b1ca4add76"
    name = "codex-m"

    def __init__(self, root):
        super(CoDExMedium, self).__init__(root=root, size='m')

class CoDExLarge(CoDEx):
    """
    #node: 77951
    #edge: 612437
    #relation: 69
    """
    url = "https://zenodo.org/record/4281094/files/codex-l.tar.gz"
    md5 = "9a10f4458c4bd2b16ef9b92b677e0d71"
    name = "codex-l"

    def __init__(self, root):
        super(CoDExLarge, self).__init__(root=root, size='l')

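# Usage sketch (illustrative; the root path is an example): every TransductiveDataset subclass
# is constructed the same way and exposes its train/valid/test splits as its three items.
#
#   dataset = CoDExSmall(root="datasets")
#   train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
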
class NELL995(TransductiveDataset):

    # the RED-GNN split of NELL-995 ships a separate facts file;
    # those facts are merged with the train triples to form the fact graph (see process below)
    urls = [
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/facts.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/train.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/valid.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/test.txt",
    ]
    name = "nell995"

    @property
    def raw_file_names(self):
        return ["facts.txt", "train.txt", "valid.txt", "test.txt"]

    def process(self):
        train_files = self.raw_paths[:4]

        facts_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        train_results = self.load_file(train_files[1], facts_results["inv_entity_vocab"], facts_results["inv_rel_vocab"])
        valid_results = self.load_file(train_files[2], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
        test_results = self.load_file(train_files[3], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])

        num_node = valid_results["num_node"]
        num_relations = train_results["num_relation"]

        # the fact graph is the union of the facts file and the train triples
        train_triplets = facts_results["triplets"] + train_results["triplets"]
        valid_triplets = valid_results["triplets"]
        test_triplets = test_results["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_triplets])

        valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
        valid_etypes = torch.tensor([t[2] for t in valid_triplets])

        test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
        test_etypes = torch.tensor([t[2] for t in test_triplets])

        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat([train_target_etypes, train_target_etypes + num_relations])

        train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
        valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
        test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                         target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

class ConceptNet100k(TransductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/train",
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/valid",
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/test",
    ]
    name = "cnet100k"
    delimiter = "\t"

class DBpedia100k(TransductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_train.txt",
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_valid.txt",
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_test.txt",
    ]
    name = "dbp100k"

class YAGO310(TransductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/train.txt",
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/valid.txt",
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/test.txt",
    ]
    name = "yago310"

class Hetionet(TransductiveDataset):

    urls = [
        "https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1",
        "https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1",
        "https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1",
    ]
    name = "hetionet"

class AristoV4(TransductiveDataset):

    url = "https://zenodo.org/record/5942560/files/aristo-v4.zip"

    name = "aristov4"
    delimiter = "\t"

    def download(self):
        download_path = download_url(self.url, self.raw_dir)
        extract_zip(download_path, self.raw_dir)
        os.unlink(download_path)
        for oldname, newname in zip(['train', 'valid', 'test'], self.raw_paths):
            os.rename(os.path.join(self.raw_dir, oldname), newname)

class SparserKG(TransductiveDataset):

    # sparser KG datasets (NELL23K, WD-singer, FB15K-237-10/20/50) shipped with DacKGR;
    # their dump files store triples as (head, tail, relation), hence the custom load_file below
    url = "https://raw.githubusercontent.com/THU-KEG/DacKGR/master/data.zip"
    delimiter = "\t"
    base_name = "SparseKG"

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.base_name, self.name, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.base_name, self.name, "processed")

    def download(self):
        base_path = os.path.join(self.root, self.base_name)
        download_path = download_url(self.url, base_path)
        extract_zip(download_path, base_path)
        for dsname in ['NELL23K', 'WD-singer', 'FB15K-237-10', 'FB15K-237-20', 'FB15K-237-50']:
            for oldname, newname in zip(['train.triples', 'dev.triples', 'test.triples'], self.raw_file_names):
                os.renames(os.path.join(base_path, "data", dsname, oldname), os.path.join(base_path, dsname, "raw", newname))
        shutil.rmtree(os.path.join(base_path, "data"))

    def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                # NOTE: unlike the parent class, each line here is parsed as (head, tail, relation)
                u, v, r = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab),
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

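# Illustration (made-up identifiers): a DacKGR ".triples" line such as
#
#   entity_a<TAB>entity_b<TAB>relation_r
#
# is read by the override above as head=entity_a, tail=entity_b, relation=relation_r,
# whereas the parent TransductiveDataset.load_file would have treated the second column
# as the relation.
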
class WDsinger(SparserKG):
    name = "WD-singer"


class NELL23k(SparserKG):
    name = "NELL23K"


class FB15k237_10(SparserKG):
    name = "FB15K-237-10"


class FB15k237_20(SparserKG):
    name = "FB15K-237-20"


class FB15k237_50(SparserKG):
    name = "FB15K-237-50"

class InductiveDataset(InMemoryDataset):

    delimiter = None
    # if True, the validation split is evaluated over the inference graph;
    # if False, it is evaluated over the training graph
    valid_on_inf = True

    def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, **kwargs):
        self.version = str(version)
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def download(self):
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url % self.version, self.raw_dir)
            os.rename(download_path, path)

    def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab),
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    def process(self):

        train_files = self.raw_paths[:4]

        # the training and inference graphs get separate entity/relation vocabularies
        train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
        valid_res = self.load_file(
            train_files[2],
            inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
            inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
        )
        test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
        inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]

        train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_edges])

        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
        inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
        inf_etypes = torch.tensor([t[2] for t in inf_graph])
        inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])

        inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
        inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
        valid_data = Data(edge_index=inf_edges if self.valid_on_inf else train_fact_index,
                          edge_type=inf_etypes if self.valid_on_inf else train_fact_type,
                          num_nodes=inference_num_nodes if self.valid_on_inf else num_train_nodes,
                          target_edge_index=inf_valid_edges[:, :2].T,
                          target_edge_type=inf_valid_edges[:, 2],
                          num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    @property
    def num_relations(self):
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.name, self.version, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.name, self.version, "processed")

    @property
    def raw_file_names(self):
        return [
            "transductive_train.txt", "inference_graph.txt", "inf_valid.txt", "inf_test.txt"
        ]

    @property
    def processed_file_names(self):
        return "data.pt"

    def __repr__(self):
        return "%s(%s)" % (self.name, self.version)

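# Usage sketch (illustrative; the root path and version are examples): fully-inductive datasets
# take a version string that is substituted into the %s placeholder of each download URL, and
# their three items are the training, validation, and inference/test graphs.
#
#   dataset = ILPC2022(root="datasets", version="small")
#   train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
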
class IngramInductive(InductiveDataset):

    @property
    def raw_dir(self):
        return os.path.join(self.root, "ingram", self.name, self.version, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, "ingram", self.name, self.version, "processed")

class FBIngram(IngramInductive):

    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/test.txt",
    ]
    name = "fb"


class WKIngram(IngramInductive):

    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/test.txt",
    ]
    name = "wk"


class NLIngram(IngramInductive):

    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/test.txt",
    ]
    name = "nl"

class ILPC2022(InductiveDataset):

    urls = [
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/train.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_validation.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_test.txt",
    ]

    name = "ilpc2022"

class HM(InductiveDataset):

    # Hamaguchi-BM and INDIGO-BM inductive datasets from the INDIGO repository
    urls = [
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/train.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-graph.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/valid.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-fact.txt",
    ]

    name = "hm"
    versions = {
        '1k': "Hamaguchi-BM_both-1000",
        '3k': "Hamaguchi-BM_both-3000",
        '5k': "Hamaguchi-BM_both-5000",
        'indigo': "INDIGO-BM"
    }

    # validation is performed on the training graph rather than on the inference graph
    valid_on_inf = False

    def __init__(self, root, version, **kwargs):
        version = self.versions[version]
        super().__init__(root, version, **kwargs)

    def process(self):

        train_files = self.raw_paths[:4]

        train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
        valid_res = self.load_file(
            train_files[2],
            inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
            inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
        )
        test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
        inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]

        train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_edges])

        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
        inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
        inf_etypes = torch.tensor([t[2] for t in inf_graph])
        inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])

        inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
        inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
        # unlike the base class, validation here runs over the training graph (valid_on_inf = False),
        # so the fact graph of the valid split is the training fact graph
        valid_data = Data(edge_index=train_fact_index,
                          edge_type=train_fact_type,
                          num_nodes=valid_res["num_node"],
                          target_edge_index=inf_valid_edges[:, :2].T,
                          target_edge_type=inf_valid_edges[:, 2],
                          num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

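# Note (illustrative): the short version keys above map to the folder names used in the INDIGO
# repository, e.g. HM(root="datasets", version="1k") resolves to "Hamaguchi-BM_both-1000";
# the root path is an example.
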
class MTDEAInductive(InductiveDataset):

    valid_on_inf = False
    url = "https://reltrans.s3.us-east-2.amazonaws.com/MTDEA_data.zip"
    base_name = "mtdea"

    def __init__(self, root, version, **kwargs):
        assert version in self.versions, f"unknown version {version} for {self.name}, available: {self.versions}"
        super().__init__(root, version, **kwargs)

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.base_name, self.name, self.version, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.base_name, self.name, self.version, "processed")

    @property
    def raw_file_names(self):
        return [
            "transductive_train.txt", "inference_graph.txt", "transductive_valid.txt", "inf_test.txt"
        ]

    def download(self):
        base_path = os.path.join(self.root, self.base_name)
        download_path = download_url(self.url, base_path)
        extract_zip(download_path, base_path)

        # the zip contains all MTDEA datasets; move each split into the layout expected by raw_dir
        for dsname in ['FBNELL', 'Metafam', 'WikiTopics-MT1', 'WikiTopics-MT2', 'WikiTopics-MT3', 'WikiTopics-MT4']:
            cl = globals()[dsname.replace("-", "")]
            versions = cl.versions
            for version in versions:
                for oldname, newname in zip(['train.txt', 'observe.txt', 'valid.txt', 'test.txt'], self.raw_file_names):
                    foldername = cl.prefix % version + "-trans" if "transductive" in newname else cl.prefix % version + "-ind"
                    os.renames(
                        os.path.join(base_path, "MTDEA_datasets", dsname, foldername, oldname),
                        os.path.join(base_path, dsname, version, "raw", newname)
                    )
        shutil.rmtree(os.path.join(base_path, "MTDEA_datasets"))

    def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}, limit_vocab=False):

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        # with limit_vocab=True, triples whose head, tail, or relation is not already in the given
        # vocabularies are skipped instead of extending the vocabularies (used for the validation split)
        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    if limit_vocab:
                        continue
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    if limit_vocab:
                        continue
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    if limit_vocab:
                        continue
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    def process(self):

        train_files = self.raw_paths[:4]

        train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
        # keep only valid triples whose entities and relations already exist in the chosen vocabularies
        valid_res = self.load_file(
            train_files[2],
            inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
            inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"],
            limit_vocab=True,
        )
        test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
        inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]

        train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_edges])

        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
        inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
        inf_etypes = torch.tensor([t[2] for t in inf_graph])
        inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])

        inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
        inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
        valid_data = Data(edge_index=train_fact_index,
                          edge_type=train_fact_type,
                          num_nodes=valid_res["num_node"],
                          target_edge_index=inf_valid_edges[:, :2].T,
                          target_edge_type=inf_valid_edges[:, 2],
                          num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

class FBNELL(MTDEAInductive):

    name = "FBNELL"
    prefix = "%s"
    versions = ["FBNELL_v1"]

    def __init__(self, **kwargs):
        # this dataset ships a single version, so any passed-in version is overridden
        kwargs.pop("version", None)
        kwargs['version'] = self.versions[0]
        super(FBNELL, self).__init__(**kwargs)


class Metafam(MTDEAInductive):

    name = "Metafam"
    prefix = "%s"
    versions = ["Metafam"]

    def __init__(self, **kwargs):
        # this dataset ships a single version, so any passed-in version is overridden
        kwargs.pop("version", None)
        kwargs['version'] = self.versions[0]
        super(Metafam, self).__init__(**kwargs)

class WikiTopicsMT1(MTDEAInductive):

    name = "WikiTopics-MT1"
    prefix = "wikidata_%sv1"
    versions = ['mt', 'health', 'tax']

    def __init__(self, **kwargs):
        assert kwargs['version'] in self.versions, f"unknown version {kwargs['version']}, available: {self.versions}"
        super(WikiTopicsMT1, self).__init__(**kwargs)


class WikiTopicsMT2(MTDEAInductive):

    name = "WikiTopics-MT2"
    prefix = "wikidata_%sv1"
    versions = ['mt2', 'org', 'sci']

    def __init__(self, **kwargs):
        super(WikiTopicsMT2, self).__init__(**kwargs)


class WikiTopicsMT3(MTDEAInductive):

    name = "WikiTopics-MT3"
    prefix = "wikidata_%sv2"
    versions = ['mt3', 'art', 'infra']

    def __init__(self, **kwargs):
        super(WikiTopicsMT3, self).__init__(**kwargs)


class WikiTopicsMT4(MTDEAInductive):

    name = "WikiTopics-MT4"
    prefix = "wikidata_%sv2"
    versions = ['mt4', 'sci', 'health']

    def __init__(self, **kwargs):
        super(WikiTopicsMT4, self).__init__(**kwargs)

class JointDataset(InMemoryDataset):

    datasets_map = {
        'FB15k237': FB15k237,
        'WN18RR': WN18RR,
        'CoDExSmall': CoDExSmall,
        'CoDExMedium': CoDExMedium,
        'CoDExLarge': CoDExLarge,
        'NELL995': NELL995,
        'ConceptNet100k': ConceptNet100k,
        'DBpedia100k': DBpedia100k,
        'YAGO310': YAGO310,
        'AristoV4': AristoV4,
    }

    def __init__(self, root, graphs, transform=None, pre_transform=None):

        # instantiate (and, on first use, download and preprocess) every requested graph
        self.graphs = [self.datasets_map[ds](root=root) for ds in graphs]
        self.num_graphs = len(graphs)
        super().__init__(root, transform, pre_transform)
        self.data = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return os.path.join(self.root, "joint", f'{self.num_graphs}g', "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, "joint", f'{self.num_graphs}g', "processed")

    @property
    def processed_file_names(self):
        return "data.pt"

    def process(self):

        # keep the pre-processed splits of each constituent dataset side by side
        train_data = [g[0] for g in self.graphs]
        valid_data = [g[1] for g in self.graphs]
        test_data = [g[2] for g in self.graphs]

        torch.save((train_data, valid_data, test_data), self.processed_paths[0])
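

# Minimal smoke-test sketch (assumptions: network access for the first-time downloads, and
# "datasets" as an example root directory; the chosen graph names are examples).
if __name__ == "__main__":
    joint = JointDataset(root="datasets", graphs=["FB15k237", "WN18RR"])
    # JointDataset stores the per-graph splits as three lists (train, valid, test)
    train_graphs, valid_graphs, test_graphs = torch.load(joint.processed_paths[0])
    print(f"loaded {joint.num_graphs} graphs")
    for g in train_graphs:
        print(g.num_nodes, g.num_relations)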