ultra_3g / ultra /datasets.py
mgalkin's picture
ultra source
c810120
import os
import csv
import shutil
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
from ultra.tasks import build_relation_graph
class GrailInductiveDataset(InMemoryDataset):
    """Base class for the versioned inductive benchmarks in the GraIL file layout.

    Subclasses provide ``name`` and ``urls``. Every URL is a template with one
    ``%s`` placeholder that is filled with the dataset version, and the URLs
    must be listed in the same order as ``raw_file_names``.

    The processed file stores three graphs (train, valid, test). Train and
    valid share the graph read from ``train.txt``; test uses the graph read
    from ``train_ind.txt``, which has its own entity vocabulary.
    """

    def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, merge_valid_test=True):
        """Download/process the dataset if needed and load the cached graphs.

        Args:
            root: root directory under which raw and processed files live.
            version: one of "v1", "v2", "v3", "v4".
            transform: passed through to ``InMemoryDataset``.
            pre_transform: callable applied to each split before caching.
            merge_valid_test: if True, the test targets are the triplets of
                ``valid_ind.txt`` followed by ``test_ind.txt``; if False, only
                those of ``test_ind.txt``.
        """
        self.version = version
        assert version in ["v1", "v2", "v3", "v4"]

        # by default, most models on Grail datasets merge inductive valid and test splits as the final test split
        # with this choice, the validation set is that of the transductive train (on the seen graph)
        # by default it's turned on but you can experiment with turning this option off
        # you'll need to delete the processed datasets then and re-run to cache a new dataset
        self.merge_valid_test = merge_valid_test
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def num_relations(self):
        """Largest relation id found in the stored ``edge_type``, plus one."""
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        return os.path.join(self.root, "grail", self.name, self.version, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, "grail", self.name, self.version, "processed")

    @property
    def processed_file_names(self):
        return "data.pt"

    @property
    def raw_file_names(self):
        return [
            "train_ind.txt", "valid_ind.txt", "test_ind.txt", "train.txt", "valid.txt"
        ]

    def download(self):
        """Fetch every raw file and rename it to the matching ``raw_paths`` entry."""
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url % self.version, self.raw_dir)
            os.rename(download_path, path)

    @staticmethod
    def _read_triplets(txt_file, entity_vocab, relation_vocab, allow_new_relations):
        """Parse one tab-separated ``head, relation, tail`` file into (h, t, r) id tuples.

        ``entity_vocab`` is extended in place with unseen entities.
        ``relation_vocab`` is extended in place only when ``allow_new_relations``
        is True; otherwise an unseen relation trips an assertion.
        """
        triplets = []
        # explicit encoding keeps parsing independent of the platform default
        with open(txt_file, "r", encoding="utf-8") as fin:
            for line in fin:
                h_token, r_token, t_token = line.strip().split("\t")
                if h_token not in entity_vocab:
                    entity_vocab[h_token] = len(entity_vocab)
                h = entity_vocab[h_token]
                if allow_new_relations:
                    relation_vocab.setdefault(r_token, len(relation_vocab))
                else:
                    assert r_token in relation_vocab
                r = relation_vocab[r_token]
                if t_token not in entity_vocab:
                    entity_vocab[t_token] = len(entity_vocab)
                t = entity_vocab[t_token]
                triplets.append((h, t, r))
        return triplets

    def process(self):
        """Build the train/valid/test graphs from the raw files and cache them."""
        test_files = self.raw_paths[:3]
        train_files = self.raw_paths[3:]

        inv_train_entity_vocab = {}
        inv_test_entity_vocab = {}
        inv_relation_vocab = {}
        triplets = []
        # number of triplets per file, in the order:
        # train.txt, valid.txt, train_ind.txt, valid_ind.txt, test_ind.txt
        num_samples = []

        for txt_file in train_files:
            file_triplets = self._read_triplets(txt_file, inv_train_entity_vocab, inv_relation_vocab, True)
            triplets.extend(file_triplets)
            num_samples.append(len(file_triplets))

        # inductive files use a separate entity vocabulary but must reuse known relations
        for txt_file in test_files:
            file_triplets = self._read_triplets(txt_file, inv_test_entity_vocab, inv_relation_vocab, False)
            triplets.extend(file_triplets)
            num_samples.append(len(file_triplets))

        triplets = torch.tensor(triplets)
        edge_index = triplets[:, :2].t()
        edge_type = triplets[:, 2]
        num_relations = int(edge_type.max()) + 1

        # creating fact graphs - those are graphs sent to a model, based on which we'll predict missing facts
        # also, those fact graphs will be used for filtered evaluation
        train_fact_slice = slice(None, sum(num_samples[:1]))
        test_fact_slice = slice(sum(num_samples[:2]), sum(num_samples[:3]))
        train_fact_index = edge_index[:, train_fact_slice]
        train_fact_type = edge_type[train_fact_slice]
        test_fact_index = edge_index[:, test_fact_slice]
        test_fact_type = edge_type[test_fact_slice]

        # add flipped triplets for the fact graphs
        train_fact_index = torch.cat([train_fact_index, train_fact_index.flip(0)], dim=-1)
        train_fact_type = torch.cat([train_fact_type, train_fact_type + num_relations])
        test_fact_index = torch.cat([test_fact_index, test_fact_index.flip(0)], dim=-1)
        test_fact_type = torch.cat([test_fact_type, test_fact_type + num_relations])

        train_slice = slice(None, sum(num_samples[:1]))
        valid_slice = slice(sum(num_samples[:1]), sum(num_samples[:2]))
        # by default, SOTA models on Grail datasets merge inductive valid and test splits as the final test split
        # with this choice, the validation set is that of the transductive train (on the seen graph)
        # by default it's turned on but you can experiment with turning this option off
        test_slice = slice(sum(num_samples[:3]), sum(num_samples)) if self.merge_valid_test else slice(sum(num_samples[:4]), sum(num_samples))

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
                          target_edge_index=edge_index[:, train_slice], target_edge_type=edge_type[train_slice], num_relations=num_relations*2)
        valid_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
                          target_edge_index=edge_index[:, valid_slice], target_edge_type=edge_type[valid_slice], num_relations=num_relations*2)
        test_data = Data(edge_index=test_fact_index, edge_type=test_fact_type, num_nodes=len(inv_test_entity_vocab),
                         target_edge_index=edge_index[:, test_slice], target_edge_type=edge_type[test_slice], num_relations=num_relations*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    def __repr__(self):
        return "%s(%s)" % (self.name, self.version)
class FB15k237Inductive(GrailInductiveDataset):
    """FB15k-237 inductive splits; raw files come from the kkteru/grail repository."""

    name = "IndFB15k237"

    # one "%s" per URL, filled with the dataset version at download time
    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/valid.txt",
    ]

    def __init__(self, root, version):
        # only root and version are exposed; the remaining options keep the parent's defaults
        super().__init__(root=root, version=version)
class WN18RRInductive(GrailInductiveDataset):
    """WN18RR inductive splits; raw files come from the kkteru/grail repository."""

    name = "IndWN18RR"

    # one "%s" per URL, filled with the dataset version at download time
    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/valid.txt",
    ]

    def __init__(self, root, version):
        # only root and version are exposed; the remaining options keep the parent's defaults
        super().__init__(root=root, version=version)
class NELLInductive(GrailInductiveDataset):
    """NELL inductive splits; raw files come from the kkteru/grail repository."""

    name = "IndNELL"

    # one "%s" per URL, filled with the dataset version at download time
    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/valid.txt",
    ]

    def __init__(self, root, version):
        # only root and version are exposed; the remaining options keep the parent's defaults
        super().__init__(root=root, version=version)
def FB15k237(root):
    """Load FB15k-237 through ``RelLinkPredDataset`` and repackage it as three graphs.

    Each of the train/valid/test graphs carries the dataset's full
    ``edge_index``/``edge_type`` and differs only in its target edges. The
    relation graph is built for each split before the result is collated back
    into the dataset object, which is returned.
    """
    dataset = RelLinkPredDataset(name="FB15k-237", root=root+"/fb15k237/")
    source = dataset.data
    relation_count = dataset.num_relations

    # create all split graphs first, then transform them in train/valid/test order
    graphs = [
        Data(edge_index=source.edge_index, edge_type=source.edge_type, num_nodes=source.num_nodes,
             target_edge_index=getattr(source, "%s_edge_index" % split),
             target_edge_type=getattr(source, "%s_edge_type" % split),
             num_relations=relation_count)
        for split in ("train", "valid", "test")
    ]

    # build relation graphs
    graphs = [build_relation_graph(graph) for graph in graphs]

    dataset.data, dataset.slices = dataset.collate(graphs)
    return dataset
def WN18RR(root):
    """Load WN18RR through ``WordNet18RR`` and convert it to the FB15k-237 layout.

    The message-passing graph is made of the training edges plus their flipped
    copies, whose relation ids are shifted by the number of relations. The
    three returned splits share that graph and differ only in target edges,
    selected by the train/val/test masks.
    """
    dataset = WordNet18RR(root=root+"/wn18rr/")

    # convert wn18rr into the same format as fb15k-237
    source = dataset.data
    num_nodes = int(source.edge_index.max()) + 1
    num_relations = int(source.edge_type.max()) + 1

    forward_index = source.edge_index[:, source.train_mask]
    forward_type = source.edge_type[source.train_mask]
    edge_index = torch.cat([forward_index, forward_index.flip(0)], dim=-1)
    edge_type = torch.cat([forward_type, forward_type + num_relations])

    graphs = []
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        mask = getattr(source, mask_name)
        graphs.append(
            Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
                 target_edge_index=source.edge_index[:, mask],
                 target_edge_type=source.edge_type[mask],
                 num_relations=num_relations*2)
        )

    # build relation graphs
    graphs = [build_relation_graph(graph) for graph in graphs]

    dataset.data, dataset.slices = dataset.collate(graphs)
    dataset.num_relations = num_relations * 2
    return dataset
class TransductiveDataset(InMemoryDataset):
    """Base class for transductive datasets stored as train/valid/test triplet files.

    Subclasses provide ``name`` and ``urls`` (in the order of
    ``raw_file_names``) and may set ``delimiter`` when columns are not
    separated by arbitrary whitespace.
    """

    # None means "split on any whitespace"; otherwise the exact column separator
    delimiter = None

    def __init__(self, root, transform=None, pre_transform=build_relation_graph, **kwargs):
        """Download/process if needed, then load the cached graphs.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["train.txt", "valid.txt", "test.txt"]

    def download(self):
        """Fetch every raw file and rename it to the matching ``raw_paths`` entry."""
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url, self.raw_dir)
            os.rename(download_path, path)

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None):
        """Read ``head relation tail`` lines and map tokens to integer ids.

        Args:
            triplet_file: path of the file to read.
            inv_entity_vocab: token -> id mapping for entities; extended in
                place with unseen tokens. A fresh dict is used when omitted.
            inv_rel_vocab: token -> id mapping for relations; extended in
                place with unseen tokens. A fresh dict is used when omitted.

        Returns:
            dict with keys ``triplets`` (list of ``(head, tail, relation)`` id
            tuples), ``num_node``, ``num_relation``, ``inv_entity_vocab`` and
            ``inv_rel_vocab`` (the possibly extended mappings).
        """
        # None sentinels instead of ``{}`` defaults: a default dict would be shared
        # across calls and silently accumulate entries, because it is mutated below.
        # Dicts passed by the caller are still updated in place.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab), #entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    # default loading procedure: process train/valid/test files, create graphs from them
    def process(self):
        """Build the train/valid/test graphs from the raw files and cache them."""
        train_files = self.raw_paths[:3]

        # valid and test reuse (and may extend) the vocabularies created for train
        train_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        valid_results = self.load_file(train_files[1],
                                       train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
        test_results = self.load_file(train_files[2],
                                      train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])

        # in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train and 123,182 in YAGO test
        # for consistency with other experimental results, we'll include those in the full vocab and num nodes
        num_node = test_results["num_node"]
        # the same for rels: in most cases train == test for transductive
        # for AristoV4 train rels 1593, test 1604
        num_relations = test_results["num_relation"]

        train_triplets = train_results["triplets"]
        valid_triplets = valid_results["triplets"]
        test_triplets = test_results["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_triplets])

        valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
        valid_etypes = torch.tensor([t[2] for t in valid_triplets])

        test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
        test_etypes = torch.tensor([t[2] for t in test_triplets])

        # the message-passing graph is the training edges plus their flipped copies,
        # whose relation ids are shifted by num_relations
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])

        train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
        valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
        test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                         target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)

        # build graphs of relations
        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    def __repr__(self):
        return "%s()" % (self.name)

    @property
    def num_relations(self):
        """Largest relation id found in the stored ``edge_type``, plus one."""
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.name, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.name, "processed")

    @property
    def processed_file_names(self):
        return "data.pt"
class CoDEx(TransductiveDataset):
    """Common base of the CoDEx datasets; split files come from the tsafavi/codex repository."""

    name = "codex"

    # the "%s" placeholder is filled with the dataset name at download time
    urls = [
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/train.txt",
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/valid.txt",
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/test.txt",
    ]

    def download(self):
        """Download each split and move it to its expected raw path."""
        for template, target in zip(self.urls, self.raw_paths):
            fetched = download_url(template % self.name, self.raw_dir)
            os.rename(fetched, target)
class CoDExSmall(CoDEx):
    """
    CoDEx-S.

    #node: 2034
    #edge: 36543
    #relation: 42
    """

    name = "codex-s"
    url = "https://zenodo.org/record/4281094/files/codex-s.tar.gz"
    md5 = "63cd8186fc2aeddc154e20cf4a10087e"

    def __init__(self, root):
        super().__init__(root=root, size="s")
class CoDExMedium(CoDEx):
    """
    CoDEx-M.

    #node: 17050
    #edge: 206205
    #relation: 51
    """

    name = "codex-m"
    url = "https://zenodo.org/record/4281094/files/codex-m.tar.gz"
    md5 = "43e561cfdca1c6ad9cc2f5b1ca4add76"

    def __init__(self, root):
        super().__init__(root=root, size="m")
class CoDExLarge(CoDEx):
    """
    CoDEx-L.

    #node: 77951
    #edge: 612437
    #relation: 69
    """

    name = "codex-l"
    url = "https://zenodo.org/record/4281094/files/codex-l.tar.gz"
    md5 = "9a10f4458c4bd2b16ef9b92b677e0d71"

    def __init__(self, root):
        super().__init__(root=root, size="l")
class NELL995(TransductiveDataset):
    """NELL-995 with an extra ``facts.txt`` file that is merged into the training graph."""

    # from the RED-GNN paper https://github.com/LARS-research/RED-GNN/tree/main/transductive/data/nell
    # the OG dumps were found to have test set leakages
    # training set is made out of facts+train files, so we sum up their samples to build one training graph

    name = "nell995"

    urls = [
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/facts.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/train.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/valid.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/test.txt",
    ]

    @property
    def raw_file_names(self):
        return ["facts.txt", "train.txt", "valid.txt", "test.txt"]

    def process(self):
        """Build the three split graphs, using facts + train as the training triplets."""
        facts_path, train_path, valid_path, test_path = self.raw_paths[:4]

        # all four files are read against the same, progressively extended vocabularies
        facts_results = self.load_file(facts_path, inv_entity_vocab={}, inv_rel_vocab={})
        train_results = self.load_file(train_path, facts_results["inv_entity_vocab"], facts_results["inv_rel_vocab"])
        valid_results = self.load_file(valid_path, train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
        test_results = self.load_file(test_path, train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])

        # NOTE(review): node count is taken after reading valid and relation count after
        # reading train, not after test as in TransductiveDataset.process - confirm intended
        num_node = valid_results["num_node"]
        num_relations = train_results["num_relation"]

        def edges_of(triplets):
            # 2 x N tensor of (head, tail) ids
            return torch.tensor([[t[0], t[1]] for t in triplets], dtype=torch.long).t()

        def types_of(triplets):
            # relation id of every triplet
            return torch.tensor([t[2] for t in triplets])

        train_triplets = facts_results["triplets"] + train_results["triplets"]
        train_target_edges, train_target_etypes = edges_of(train_triplets), types_of(train_triplets)
        valid_edges, valid_etypes = edges_of(valid_results["triplets"]), types_of(valid_results["triplets"])
        test_edges, test_etypes = edges_of(test_results["triplets"]), types_of(test_results["triplets"])

        # training graph = training edges plus flipped copies with shifted relation ids
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])

        def make_split(target_edges, target_etypes):
            return Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                        target_edge_index=target_edges, target_edge_type=target_etypes, num_relations=num_relations*2)

        splits = [
            make_split(train_target_edges, train_target_etypes),
            make_split(valid_edges, valid_etypes),
            make_split(test_edges, test_etypes),
        ]

        # build graphs of relations
        if self.pre_transform is not None:
            splits = [self.pre_transform(split) for split in splits]

        torch.save((self.collate(splits)), self.processed_paths[0])
class ConceptNet100k(TransductiveDataset):
    """ConceptNet-100k; tab-separated split files from the guojiapub/BiQUE repository."""

    name = "cnet100k"
    delimiter = "\t"

    urls = [
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/train",
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/valid",
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/test",
    ]
class DBpedia100k(TransductiveDataset):
    """DB100K; split files from the iieir-km/ComplEx-NNE_AER repository."""

    name = "dbp100k"

    urls = [
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_train.txt",
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_valid.txt",
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_test.txt",
    ]
class YAGO310(TransductiveDataset):
    """YAGO3-10; split files from the DeepGraphLearning/KnowledgeGraphEmbedding repository."""

    name = "yago310"

    urls = [
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/train.txt",
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/valid.txt",
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/test.txt",
    ]
class Hetionet(TransductiveDataset):
    """Hetionet; the three split files are fetched from Dropbox links."""

    name = "hetionet"

    urls = [
        "https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1",
        "https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1",
        "https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1",
    ]
class AristoV4(TransductiveDataset):
    """Aristo-v4; distributed as a single zip archive of tab-separated split files."""

    name = "aristov4"
    delimiter = "\t"
    url = "https://zenodo.org/record/5942560/files/aristo-v4.zip"

    def download(self):
        """Download and unpack the archive, then rename the splits to the expected raw paths."""
        archive = download_url(self.url, self.raw_dir)
        extract_zip(archive, self.raw_dir)
        # the archive itself is not needed once extracted
        os.unlink(archive)

        extracted_names = ['train', 'valid', 'test']
        for extracted, target in zip(extracted_names, self.raw_paths):
            source = os.path.join(self.raw_dir, extracted)
            os.rename(source, target)
class SparserKG(TransductiveDataset):
    """Base class of the sparse datasets shipped together in the DacKGR ``data.zip`` archive.

    Subclasses only set ``name`` to the dataset's folder name inside the archive.
    """

    # 5 datasets based on FB/NELL/WD, introduced in https://github.com/THU-KEG/DacKGR
    # re-writing the loading function because dumps are in the format (h, t, r) while the standard is (h, r, t)

    url = "https://raw.githubusercontent.com/THU-KEG/DacKGR/master/data.zip"
    delimiter = "\t"
    base_name = "SparseKG"

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.base_name, self.name, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.base_name, self.name, "processed")

    def download(self):
        """Download the shared archive and lay out the raw files of all five datasets."""
        base_path = os.path.join(self.root, self.base_name)
        download_path = download_url(self.url, base_path)
        extract_zip(download_path, base_path)
        # the archive holds every dataset, so all of them are moved into place at once
        for dsname in ['NELL23K', 'WD-singer', 'FB15K-237-10', 'FB15K-237-20', 'FB15K-237-50']:
            for oldname, newname in zip(['train.triples', 'dev.triples', 'test.triples'], self.raw_file_names):
                os.renames(os.path.join(base_path, "data", dsname, oldname), os.path.join(base_path, dsname, "raw", newname))
        shutil.rmtree(os.path.join(base_path, "data"))

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None):
        """Read ``head tail relation`` lines and map tokens to integer ids.

        Args:
            triplet_file: path of the file to read.
            inv_entity_vocab: token -> id mapping for entities; extended in
                place with unseen tokens. A fresh dict is used when omitted.
            inv_rel_vocab: token -> id mapping for relations; extended in
                place with unseen tokens. A fresh dict is used when omitted.

        Returns:
            dict with keys ``triplets`` (list of ``(head, tail, relation)`` id
            tuples), ``num_node``, ``num_relation``, ``inv_entity_vocab`` and
            ``inv_rel_vocab`` (the possibly extended mappings).
        """
        # None sentinels instead of ``{}`` defaults: a default dict would be shared
        # across calls and silently accumulate entries, because it is mutated below.
        # Dicts passed by the caller are still updated in place.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                # column order in these dumps is head, tail, relation
                u, v, r = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab), #entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }
class WDsinger(SparserKG):
    """SparserKG dataset stored under the ``WD-singer`` folder."""

    name = "WD-singer"
class NELL23k(SparserKG):
    """SparserKG dataset stored under the ``NELL23K`` folder."""

    name = "NELL23K"
class FB15k237_10(SparserKG):
    """SparserKG dataset stored under the ``FB15K-237-10`` folder."""

    name = "FB15K-237-10"
class FB15k237_20(SparserKG):
    """SparserKG dataset stored under the ``FB15K-237-20`` folder."""

    name = "FB15K-237-20"
class FB15k237_50(SparserKG):
    """SparserKG dataset stored under the ``FB15K-237-50`` folder."""

    name = "FB15K-237-50"
class InductiveDataset(InMemoryDataset):
    """Base class for inductive datasets with separate training and inference graphs.

    Four raw files are expected, in the order of ``raw_file_names``: the
    training graph, the inference graph, the validation triplets and the test
    triplets. Subclasses provide ``name`` and ``urls``; every URL is a template
    with one ``%s`` placeholder that is filled with the version.
    """

    # None means "split on any whitespace"; otherwise the exact column separator
    delimiter = None
    # some datasets (4 from Hamaguchi et al and Indigo) have validation set based off the train graph, not inference
    valid_on_inf = True  #

    def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, **kwargs):
        """Download/process if needed, then load the cached graphs.

        ``version`` is stored as a string. Extra keyword arguments are accepted
        and ignored.
        """
        self.version = str(version)
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def download(self):
        """Fetch every raw file and rename it to the matching ``raw_paths`` entry."""
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url % self.version, self.raw_dir)
            os.rename(download_path, path)

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None):
        """Read ``head relation tail`` lines and map tokens to integer ids.

        Args:
            triplet_file: path of the file to read.
            inv_entity_vocab: token -> id mapping for entities; extended in
                place with unseen tokens. A fresh dict is used when omitted.
            inv_rel_vocab: token -> id mapping for relations; extended in
                place with unseen tokens. A fresh dict is used when omitted.

        Returns:
            dict with keys ``triplets`` (list of ``(head, tail, relation)`` id
            tuples), ``num_node``, ``num_relation``, ``inv_entity_vocab`` and
            ``inv_rel_vocab`` (the possibly extended mappings).
        """
        # None sentinels instead of ``{}`` defaults: a default dict would be shared
        # across calls and silently accumulate entries, because it is mutated below.
        # Dicts passed by the caller are still updated in place.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for l in fin:
                u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab), #entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    def process(self):
        """Build the train/valid/test graphs from the raw files and cache them."""
        train_files = self.raw_paths[:4]

        # training and inference graphs get independent vocabularies
        train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
        # validation ids come from the inference or the training vocabulary, depending on valid_on_inf
        valid_res = self.load_file(
            train_files[2],
            inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
            inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
        )
        test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
        inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]

        train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_edges])

        # each graph is extended with flipped edges whose relation ids are shifted by the relation count
        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
        inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
        inf_etypes = torch.tensor([t[2] for t in inf_graph])
        inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])

        # rows are (head, tail, relation)
        inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
        inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
        valid_data = Data(edge_index=inf_edges if self.valid_on_inf else train_fact_index,
                          edge_type=inf_etypes if self.valid_on_inf else train_fact_type,
                          num_nodes=inference_num_nodes if self.valid_on_inf else num_train_nodes,
                          target_edge_index=inf_valid_edges[:, :2].T,
                          target_edge_type=inf_valid_edges[:, 2],
                          num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    @property
    def num_relations(self):
        """Largest relation id found in the stored ``edge_type``, plus one."""
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.name, self.version, "raw")

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.name, self.version, "processed")

    @property
    def raw_file_names(self):
        return [
            "transductive_train.txt", "inference_graph.txt", "inf_valid.txt", "inf_test.txt"
        ]

    @property
    def processed_file_names(self):
        return "data.pt"

    def __repr__(self):
        return "%s(%s)" % (self.name, self.version)
class IngramInductive(InductiveDataset):
    """Inductive datasets whose files are kept under an extra ``ingram`` folder."""

    @property
    def raw_dir(self):
        """Directory holding the downloaded raw files."""
        parts = (self.root, "ingram", self.name, self.version, "raw")
        return os.path.join(*parts)

    @property
    def processed_dir(self):
        """Directory holding the cached processed file."""
        parts = (self.root, "ingram", self.name, self.version, "processed")
        return os.path.join(*parts)
class FBIngram(IngramInductive):
    """FB datasets from the bdi-lab/InGram repository; ``%s`` is filled with the version."""

    name = "fb"

    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/test.txt",
    ]
class WKIngram(IngramInductive):
    """WK datasets from the bdi-lab/InGram repository; ``%s`` is filled with the version."""

    name = "wk"

    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/test.txt",
    ]
class NLIngram(IngramInductive):
    """NL datasets from the bdi-lab/InGram repository; ``%s`` is filled with the version."""

    name = "nl"

    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/test.txt",
    ]
class ILPC2022(InductiveDataset):
    """ILPC 2022 datasets from the pykeen/ilpc2022 repository; ``%s`` is filled with the version."""

    name = "ilpc2022"

    urls = [
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/train.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_validation.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_test.txt",
    ]
class HM(InductiveDataset):
    """Hamaguchi et al. and INDIGO benchmarks, fetched from the shuwen-liu-ox/INDIGO repository."""

    # benchmarks from Hamaguchi et al and Indigo BM

    name = "hm"

    urls = [
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/train.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-graph.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/valid.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-fact.txt",
    ]

    # short version key -> folder name used in the download URLs
    versions = {
        '1k': "Hamaguchi-BM_both-1000",
        '3k': "Hamaguchi-BM_both-3000",
        '5k': "Hamaguchi-BM_both-5000",
        'indigo': "INDIGO-BM"
    }

    # in 4 HM graphs, the validation set is based off the training graph, so we'll adjust the dataset creation accordingly
    valid_on_inf = False

    def __init__(self, root, version, **kwargs):
        """Translate the short version key to its folder name and defer to the parent."""
        super().__init__(root, self.versions[version], **kwargs)

    # HM datasets are a bit weird: validation set (based off the train graph) has a few hundred new nodes, so we need a custom processing
    def process(self):
        """Build the three split graphs; the valid graph's node count includes nodes first seen in the valid file."""
        train_path, graph_path, valid_path, test_path = self.raw_paths[:4]

        train_res = self.load_file(train_path, inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(graph_path, inv_entity_vocab={}, inv_rel_vocab={})

        # validation triplets are indexed against the inference or the training vocabulary
        valid_base = inference_res if self.valid_on_inf else train_res
        valid_res = self.load_file(valid_path, valid_base["inv_entity_vocab"], valid_base["inv_rel_vocab"])
        test_res = self.load_file(test_path, inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes = train_res["num_node"]
        num_train_rels = train_res["num_relation"]
        inference_num_nodes = test_res["num_node"]
        inference_num_rels = test_res["num_relation"]

        # training graph: original edges plus flipped copies with shifted relation ids
        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_res["triplets"]], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_res["triplets"]])
        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        # inference graph, built the same way
        graph_pairs = torch.tensor([[t[0], t[1]] for t in inference_res["triplets"]], dtype=torch.long).t()
        inf_edges = torch.cat([graph_pairs, graph_pairs.flip(0)], dim=1)
        graph_types = torch.tensor([t[2] for t in inference_res["triplets"]])
        inf_etypes = torch.cat([graph_types, graph_types + inference_num_rels])

        # rows are (head, tail, relation)
        valid_rows = torch.tensor(valid_res["triplets"], dtype=torch.long)
        test_rows = torch.tensor(test_res["triplets"], dtype=torch.long)

        if self.valid_on_inf:
            valid_num_relations = inference_num_rels*2
        else:
            valid_num_relations = num_train_rels*2

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
        valid_data = Data(edge_index=train_fact_index,
                          edge_type=train_fact_type,
                          num_nodes=valid_res["num_node"],  # the only fix in this function
                          target_edge_index=valid_rows[:, :2].T,
                          target_edge_type=valid_rows[:, 2],
                          num_relations=valid_num_relations)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=test_rows[:, :2].T, target_edge_type=test_rows[:, 2], num_relations=inference_num_rels*2)

        splits = [train_data, valid_data, test_data]
        if self.pre_transform is not None:
            splits = [self.pre_transform(split) for split in splits]

        torch.save((self.collate(splits)), self.processed_paths[0])
class MTDEAInductive(InductiveDataset):
    """Base class for the inductive datasets shipped in the MTDEA archive.

    Subclasses provide ``name``, ``prefix`` and ``versions``. ``name`` and
    ``version`` select the on-disk location
    ``<root>/mtdea/<name>/<version>/{raw,processed}``; ``prefix`` is used by
    ``download`` to locate the per-version folders inside the extracted
    archive.

    Compared to the superclass processor, ``process`` drops validation
    triples that mention entities/relations missing from the reference
    vocabulary and takes the validation ``num_nodes`` from the validation
    vocabulary size.
    """

    # Whether validation triples are indexed with the inference-graph
    # vocabulary (True) or with the training-graph vocabulary (False).
    valid_on_inf = False
    url = "https://reltrans.s3.us-east-2.amazonaws.com/MTDEA_data.zip"
    base_name = "mtdea"

    def __init__(self, root, version, **kwargs):
        """Validate ``version`` against ``self.versions`` and build the dataset."""
        assert version in self.versions, f"unknown version {version} for {self.name}, available: {self.versions}"
        super().__init__(root, version, **kwargs)

    @property
    def raw_dir(self):
        """Directory holding the renamed raw triplet files of this version."""
        return os.path.join(self.root, self.base_name, self.name, self.version, "raw")

    @property
    def processed_dir(self):
        """Directory holding the cached, processed dataset of this version."""
        return os.path.join(self.root, self.base_name, self.name, self.version, "processed")

    @property
    def raw_file_names(self):
        """Raw files in the order consumed by ``process``:
        training graph, inference graph, validation triples, test triples."""
        return [
            "transductive_train.txt", "inference_graph.txt", "transductive_valid.txt", "inf_test.txt"
        ]

    def download(self):
        """Download the MTDEA archive and lay out *all* its datasets at once.

        The single zip contains every MTDEA dataset, so after extraction the
        files of every dataset/version are moved to
        ``<root>/mtdea/<dataset>/<version>/raw/`` under the names listed in
        ``raw_file_names``; the extracted ``MTDEA_datasets`` folder is then
        removed.
        """
        base_path = os.path.join(self.root, self.base_name)
        download_path = download_url(self.url, base_path)
        extract_zip(download_path, base_path)
        archive_names = ['train.txt', 'observe.txt', 'valid.txt', 'test.txt']
        # unzip all datasets at once
        for dsname in ['FBNELL', 'Metafam', 'WikiTopics-MT1', 'WikiTopics-MT2', 'WikiTopics-MT3', 'WikiTopics-MT4']:
            # the dataset classes live in this module under the name without the dash
            cl = globals()[dsname.replace("-", "")]
            for version in cl.versions:
                for oldname, newname in zip(archive_names, self.raw_file_names):
                    # "transductive_*" files come from the "-trans" folder, the rest from "-ind"
                    suffix = "-trans" if "transductive" in newname else "-ind"
                    foldername = cl.prefix % version + suffix
                    os.renames(
                        os.path.join(base_path, "MTDEA_datasets", dsname, foldername, oldname),
                        os.path.join(base_path, dsname, version, "raw", newname)
                    )
        shutil.rmtree(os.path.join(base_path, "MTDEA_datasets"))

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None, limit_vocab=False):
        """Read a triplet file and map its tokens to integer ids.

        Args:
            triplet_file: path to a text file with one ``head relation tail``
                triple per line, split on ``self.delimiter`` (any whitespace
                when it is ``None``).
            inv_entity_vocab: optional token -> id mapping for entities. When
                given it is extended **in place**; when omitted a fresh dict
                is created for this call.
            inv_rel_vocab: same as ``inv_entity_vocab`` but for relations.
            limit_vocab: if True, the vocabularies are left untouched and any
                triple with an unseen head, tail or relation is dropped.

        Returns:
            dict with ``triplets`` (list of ``(head_id, tail_id, relation_id)``),
            ``num_node``, ``num_relation`` and the two vocabularies.
        """
        # Fresh dicts per call: a shared ``{}`` default would silently
        # accumulate entries across unrelated calls.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        # limit_vocab is for dropping triples with unseen head/tail not seen in the main entity_vocab
        # can be used for FBNELL and MT3:art, other datasets seem to be ok and share num_nodes/num_relations in the train/inference graph
        with open(triplet_file, "r", encoding="utf-8") as fin:
            for line in fin:
                if not line.strip():
                    # tolerate blank lines (e.g. a trailing newline at end of file)
                    continue
                u, r, v = line.split() if self.delimiter is None else line.strip().split(self.delimiter)
                if limit_vocab and (
                    u not in inv_entity_vocab or v not in inv_entity_vocab or r not in inv_rel_vocab
                ):
                    continue
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                # stored as (head, tail, relation)
                triplets.append((inv_entity_vocab[u], inv_entity_vocab[v], inv_rel_vocab[r]))

        return {
            "triplets": triplets,
            "num_node": entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    # special processes for MTDEA datasets for one particular fix in the validation set loading
    def process(self):
        """Build the train / valid / test ``Data`` objects and cache them.

        Training and inference graphs get independent vocabularies. Each
        graph is augmented with inverse edges whose relation id is shifted by
        the number of relations of that graph, hence ``num_relations`` is
        twice the relation count.
        """
        train_path, inference_path, valid_path, test_path = self.raw_paths[:4]

        train_res = self.load_file(train_path, inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(inference_path, inv_entity_vocab={}, inv_rel_vocab={})
        valid_res = self.load_file(
            valid_path,
            inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
            inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"],
            limit_vocab=True,  # the 1st fix in this function compared to the superclass processor
        )
        # test triples may extend the inference vocabulary (limit_vocab is off)
        test_res = self.load_file(test_path, inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
        inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]

        train_edges = train_res["triplets"]
        inf_graph = inference_res["triplets"]
        inf_valid_edges = valid_res["triplets"]
        inf_test_edges = test_res["triplets"]

        # training graph: original edges followed by their inverses
        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_edges])
        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        # inference graph: original edges followed by their inverses
        inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
        inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
        inf_etypes = torch.tensor([t[2] for t in inf_graph])
        inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])

        # rows are (head, tail, relation)
        inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
        inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes,
                          num_relations=num_train_rels*2)
        # NOTE(review): the validation graph is always the training graph here,
        # regardless of valid_on_inf (which is False for this class) - confirm
        # before enabling valid_on_inf in a subclass.
        valid_data = Data(edge_index=train_fact_index,
                          edge_type=train_fact_type,
                          num_nodes=valid_res["num_node"],  # the 2nd fix in this function
                          target_edge_index=inf_valid_edges[:, :2].T,
                          target_edge_type=inf_valid_edges[:, 2],
                          num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2],
                         num_relations=inference_num_rels*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
class FBNELL(MTDEAInductive):
    """The ``FBNELL`` dataset from the MTDEA archive; it has a single version."""

    name = "FBNELL"
    # archive folder name is ``prefix % version`` plus "-trans"/"-ind"
    prefix = "%s"
    versions = ["FBNELL_v1"]

    def __init__(self, **kwargs):
        """Create the dataset, always using the only available version.

        Any ``version`` supplied by the caller is ignored; omitting it is
        also accepted (previously a missing key raised ``KeyError`` even
        though the value was discarded).
        """
        kwargs['version'] = self.versions[0]
        super(FBNELL, self).__init__(**kwargs)
class Metafam(MTDEAInductive):
    """The ``Metafam`` dataset from the MTDEA archive; it has a single version."""

    name = "Metafam"
    # archive folder name is ``prefix % version`` plus "-trans"/"-ind"
    prefix = "%s"
    versions = ["Metafam"]

    def __init__(self, **kwargs):
        """Create the dataset, always using the only available version.

        Any ``version`` supplied by the caller is ignored; omitting it is
        also accepted (previously a missing key raised ``KeyError`` even
        though the value was discarded).
        """
        kwargs['version'] = self.versions[0]
        super(Metafam, self).__init__(**kwargs)
class WikiTopicsMT1(MTDEAInductive):
    """``WikiTopics-MT1`` datasets from the MTDEA archive.

    Keyword arguments are forwarded unchanged to ``MTDEAInductive``;
    ``version`` is required and must be one of ``versions``.
    """

    name = "WikiTopics-MT1"
    # archive folder name is ``prefix % version`` plus "-trans"/"-ind"
    prefix = "wikidata_%sv1"
    versions = ['mt', 'health', 'tax']

    def __init__(self, **kwargs):
        # fail early with the list of valid versions before touching the disk
        assert kwargs['version'] in self.versions, f"unknown version {kwargs['version']}, available: {self.versions}"
        super().__init__(**kwargs)
class WikiTopicsMT2(MTDEAInductive):
    """``WikiTopics-MT2`` datasets from the MTDEA archive.

    Keyword arguments are forwarded unchanged to ``MTDEAInductive``, which
    validates ``version`` against ``versions``.
    """

    name = "WikiTopics-MT2"
    # archive folder name is ``prefix % version`` plus "-trans"/"-ind"
    prefix = "wikidata_%sv1"
    versions = ['mt2', 'org', 'sci']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class WikiTopicsMT3(MTDEAInductive):
    """``WikiTopics-MT3`` datasets from the MTDEA archive.

    Keyword arguments are forwarded unchanged to ``MTDEAInductive``, which
    validates ``version`` against ``versions``.
    """

    name = "WikiTopics-MT3"
    # archive folder name is ``prefix % version`` plus "-trans"/"-ind"
    prefix = "wikidata_%sv2"
    versions = ['mt3', 'art', 'infra']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class WikiTopicsMT4(MTDEAInductive):
    """``WikiTopics-MT4`` datasets from the MTDEA archive.

    Keyword arguments are forwarded unchanged to ``MTDEAInductive``, which
    validates ``version`` against ``versions``.
    """

    name = "WikiTopics-MT4"
    # archive folder name is ``prefix % version`` plus "-trans"/"-ind"
    prefix = "wikidata_%sv2"
    versions = ['mt4', 'sci', 'health']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
# a joint dataset for pre-training ULTRA on several graphs
class JointDataset(InMemoryDataset):
    """Bundle of several transductive graphs cached as one processed file.

    ``graphs`` is a sequence of keys of ``datasets_map``; each referenced
    dataset is instantiated under the same ``root``. The processed file
    stores a tuple ``(train_list, valid_list, test_list)`` holding, for every
    graph in order, its split 0, 1 and 2 respectively.

    NOTE: the cache location depends only on the *number* of graphs
    (``<root>/joint/<N>g/``), so two different selections of the same size
    share the same processed file.
    """

    # config name -> dataset class
    datasets_map = {
        'FB15k237': FB15k237,
        'WN18RR': WN18RR,
        'CoDExSmall': CoDExSmall,
        'CoDExMedium': CoDExMedium,
        'CoDExLarge': CoDExLarge,
        'NELL995': NELL995,
        'ConceptNet100k': ConceptNet100k,
        'DBpedia100k': DBpedia100k,
        'YAGO310': YAGO310,
        'AristoV4': AristoV4,
    }

    def __init__(self, root, graphs, transform=None, pre_transform=None):
        # the member datasets must exist before the base constructor runs,
        # because it may invoke process()
        members = []
        for key in graphs:
            dataset_cls = self.datasets_map[key]
            members.append(dataset_cls(root=root))
        self.graphs = members
        self.num_graphs = len(graphs)
        super().__init__(root, transform, pre_transform)
        self.data = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        """Raw directory, keyed by the number of joined graphs."""
        size_tag = f'{self.num_graphs}g'
        return os.path.join(self.root, "joint", size_tag, "raw")

    @property
    def processed_dir(self):
        """Processed directory, keyed by the number of joined graphs."""
        size_tag = f'{self.num_graphs}g'
        return os.path.join(self.root, "joint", size_tag, "processed")

    @property
    def processed_file_names(self):
        """Name of the single cached file."""
        return "data.pt"

    def process(self):
        """Collect the three splits of every member graph and save them."""
        # one list per split index (0: train, 1: valid, 2: test), graphs in order
        per_split = tuple(
            [graph[split_idx] for graph in self.graphs]
            for split_idx in range(3)
        )
        # disabled in the original implementation:
        # filter_data = [
        #     Data(edge_index=g.data.target_edge_index, edge_type=g.data.target_edge_type, num_nodes=g[0].num_nodes) for g in self.graphs
        # ]
        torch.save(per_split, self.processed_paths[0])