Spaces:
Running
Running
File size: 10,356 Bytes
ac7c391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
'''
Adapted from https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
'''
#%%
import logging
import time
from tqdm import tqdm
import io
import pandas as pd
import numpy as np
import os
import json
import argparse
import torch
import random
from yaml import parse
from model import Conve, Distmult, Complex
logger = logging.getLogger(__name__)
#%%
def generate_dicts(data_path):
    """Load the entity and relation dictionaries stored under ``data_path``.

    Reads ``entities_dict.json`` and ``relations_dict.json`` from the given
    directory and reports how many top-level entries each one holds.

    Args:
        data_path: Directory containing the two JSON files.

    Returns:
        A tuple ``(n_ent, n_rel, ent_to_id, rel_to_id)``: the number of
        entries in each loaded object followed by the loaded objects
        themselves (presumably name -> id mappings; verify against the
        files that produce them).

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    # Pin the encoding: JSON text is UTF-8, while open()'s default depends on
    # the platform locale and can mis-decode non-ASCII entity names.
    with open(os.path.join(data_path, 'entities_dict.json'), 'r', encoding='utf-8') as f:
        ent_to_id = json.load(f)
    with open(os.path.join(data_path, 'relations_dict.json'), 'r', encoding='utf-8') as f:
        rel_to_id = json.load(f)

    # len() on the mapping counts its keys directly; no list copy required.
    n_ent = len(ent_to_id)
    n_rel = len(rel_to_id)
    return n_ent, n_rel, ent_to_id, rel_to_id
def save_data(file_name, data):
    """Write ``data`` to ``file_name`` as tab-separated text, one row per line.

    Every element of a row is converted with ``str`` and the elements are
    joined with a tab; each row is terminated by a newline. An existing file
    is overwritten.

    Args:
        file_name: Path of the file to create or overwrite.
        data: Iterable of rows, each row itself an iterable of values.
    """
    # Write UTF-8 explicitly so the output does not depend on the platform
    # locale and matches the default encoding pandas.read_csv uses when the
    # file is read back.
    with open(file_name, 'w', encoding='utf-8') as fl:
        fl.writelines("\t".join(map(str, item)) + "\n" for item in data)
def load_data(file_name, drop = True):
    """Read a headerless tab-separated file and return its rows as an array.

    All columns are requested as strings (``dtype=str``).

    Args:
        file_name: Path of the tab-separated file to read.
        drop: When truthy, duplicate rows are removed before returning.

    Returns:
        The ``.values`` array of the resulting DataFrame.
    """
    # NOTE(review): pandas' default NA parsing still applies with dtype=str,
    # so fields such as "NA" or an empty string come back as NaN rather than
    # text -- confirm that is acceptable for the data being loaded.
    frame = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    return (frame.drop_duplicates() if drop else frame).values
def seed_all(seed=1):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Value handed to every generator. Default: 1.
    """
    # Same call order as before: Python, NumPy, torch (CPU), torch (all GPUs).
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # NOTE: the interpreter reads PYTHONHASHSEED at startup, so this setting
    # is only seen by processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
def add_model(args, n_ent, n_rel):
    """Instantiate the embedding model named by ``args.model``.

    Args:
        args: Namespace whose ``model`` attribute selects the class; it is
            also passed through to the model constructor.
        n_ent: Entity count passed to the model constructor.
        n_rel: Relation count passed to the model constructor.

    Returns:
        A ``Distmult`` instance for ``None`` or ``'distmult'``, a ``Complex``
        instance for ``'complex'``, a ``Conve`` instance for ``'conve'``.

    Raises:
        Exception: If ``args.model`` is any other value.
    """
    chosen = args.model
    # A missing model name falls back to DistMult.
    if chosen is None or chosen == 'distmult':
        return Distmult(args, n_ent, n_rel)
    if chosen == 'complex':
        return Complex(args, n_ent, n_rel)
    if chosen == 'conve':
        return Conve(args, n_ent, n_rel)
    raise Exception("Unknown model!")
def load_model(model_path, args, n_ent, n_rel, device):
    """Build a model and restore its parameters from a saved checkpoint.

    The model is created with ``add_model``, moved to ``device``, filled with
    the ``'state_dict'`` entry of the checkpoint and put into eval mode. Every
    parameter's key, size and element count is logged along the way.

    Args:
        model_path: Path of the checkpoint file to read with ``torch.load``.
        args: Namespace forwarded to ``add_model``.
        n_ent: Entity count forwarded to ``add_model``.
        n_rel: Relation count forwarded to ``add_model``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The restored model in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    # add a model and load the pre-trained params
    model = add_model(args, n_ent, n_rel)
    model.to(device)
    logger.info('Loading saved model from %s', model_path)

    # Without map_location, tensors are restored onto the device they were
    # saved from, which fails when that device is unavailable here (for
    # example a CUDA checkpoint loaded on a CPU-only host).
    state = torch.load(model_path, map_location=device)
    model_params = state['state_dict']
    for key, value in model_params.items():
        logger.info('Key:%s, Size:%s, Count:%s', key, value.size(), value.numel())

    model.load_state_dict(model_params)
    model.eval()
    logger.info(model)
    return model
def add_eval_parameters(parser):
    """Register the evaluation command-line options on ``parser``.

    Adds ``--cuda-name`` (required), the ``--direct`` and ``--seperate``
    switches, and the string options ``--mode`` and ``--mask-ratio``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # Disabled option kept for reference:
    # ('--eval-mode', dict(type=str, default='all', help='Method to evaluate the attack performance. Default: all. (all or single)'))
    option_specs = (
        ('--cuda-name', dict(type=str, required=True, help='Start a main thread on each cuda.')),
        ('--direct', dict(action='store_true', help='Directly add edge or not.')),
        ('--seperate', dict(action='store_true', help='Evaluate seperatly or not')),
        # The adjacent literals below concatenate to the text "  or  ".
        ('--mode', dict(type=str, default='', help=' '' or '' ')),
        ('--mask-ratio', dict(type=str, default='', help='Mask ratio for Fig4b')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
def add_attack_parameters(parser):
    """Register the attack command-line options on ``parser``.

    Covers target selection (``--target-split``, ``--target-size``,
    ``--target-existed``), candidate generation (``--attack-goal``,
    ``--neighbor-num``, ``--candidate-mode``, ``--reasonable-rate``,
    ``--added-edge-num``), batching (``--attack-batch-size``,
    ``--GPT-batch-size``), and the template / language-model / caching
    switches.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # parser.add_argument('--target-split', type=str, default='0_100_1', help='Ranks to use for target set. Values are 0 for ranks==1; 1 for ranks <=10; 2 for ranks>10 and ranks<=100. Default: 1')
    parser.add_argument('--target-split', type=str, default='min', help='Methods for target triple selection. Default: min. (min or top_?, top means top_0.1)')
    parser.add_argument('--target-size', type=int, default=50, help='Number of target triples. Default: 50')
    parser.add_argument('--target-existed', action='store_true', help='Whether the targeted s_?_o already exists.')
    # parser.add_argument('--budget', type=int, default=1, help='Budget for each target triple for each corruption side')

    parser.add_argument('--attack-goal', type=str, default='single', help='Attack goal. Default: single. (single or global)')
    parser.add_argument('--neighbor-num', type=int, default=20, help='Max neighbor num for each side. Default: 20')
    parser.add_argument('--candidate-mode', type=str, default='quadratic', help='The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--reasonable-rate', type=float, default=0.7, help='The added edge\'s existance rank prob greater than this rate')
    # The help text is double-quoted so the '' it mentions is shown literally;
    # inside a single-quoted literal it split the string in two and the
    # quotes vanished from the message.
    parser.add_argument('--added-edge-num', type=str, default='', help="How many edges to add for each target edge. Default: '' means 1.")
    # parser.add_argument('--neighbor-num', type = int, default=200, help='Max neighbor num for each side. Default: 200')
    # parser.add_argument('--candidate-mode', type = str, default='linear', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')

    parser.add_argument('--attack-batch-size', type=int, default=256, help='Batch size for processing neighbours of target')
    parser.add_argument('--template-mode', type=str, default='manual', help='Template mode for transforming edge to single sentense. Default: manual. (manual or auto)')
    parser.add_argument('--update-lissa', action='store_true', help='Update lissa cache or not.')
    parser.add_argument('--GPT-batch-size', type=int, default=64, help='Batch size for GPT2 when calculating LM score. Default: 64')
    parser.add_argument('--LM-softmax', action='store_true', help='Use a softmax head on LM prob or not.')
    # Help now states the default that is actually configured ('relative').
    parser.add_argument('--LMprob-mode', type=str, default='relative', help='Use the absolute LM score or calculate the destruction score when target word is replaced. Default: relative. (absolute or relative)')
    parser.add_argument('--load-existed', action='store_true', help='Use cached intermidiate results or not, when only --reasonable-rate changed, set this param to True')
    return parser
def get_argument_parser():
    """Build the base argument parser for knowledge-graph embedding runs.

    Registers the general options (seed, dataset, model), the optimisation
    options (epochs, learning rate, batch sizes, regularisation) and the
    model-shape options (embedding size, dropout rates, convolution
    settings).

    Returns:
        A new ``argparse.ArgumentParser`` with all options registered.
    """
    parser = argparse.ArgumentParser(description='Graph embedding')

    parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed (default: 1)')
    parser.add_argument('--data', type=str, default='GNBR', help='Dataset to use: { GNBR }')
    parser.add_argument('--model', type=str, default='distmult', help='Choose from: {distmult, conve, complex}')

    parser.add_argument('--transe-margin', type=float, default=0.0, help='Margin value for TransE scoring function. Default:0.0')
    parser.add_argument('--transe-norm', type=int, default=2, help='P-norm value for TransE scoring function. Default:2')

    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--lr-decay', type=float, default=0.0, help='Weight decay value to use in the optimizer. Default: 0.0')
    parser.add_argument('--max-norm', action='store_true', help='Option to add unit max norm constraint to entity embeddings')

    # Help texts below were corrected where they stated a default that
    # differed from the configured one.
    parser.add_argument('--train-batch-size', type=int, default=64, help='Batch size for train split (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=128, help='Batch size for test split (default: 128)')
    parser.add_argument('--valid-batch-size', type=int, default=128, help='Batch size for valid split (default: 128)')
    parser.add_argument('--KG-valid-rate', type=float, default=0.1, help='Validation rate during KG embedding training. (default: 0.1)')

    parser.add_argument('--save-influence-map', action='store_true', help='Save the influence map during training for gradient rollback.')
    parser.add_argument('--add-reciprocals', action='store_true')

    parser.add_argument('--embedding-dim', type=int, default=128, help='The embedding dimension (1D). Default: 128')
    parser.add_argument('--stack-width', type=int, default=16, help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 16')
    #parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
    parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.')
    parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.')
    parser.add_argument('--feat-drop', type=float, default=0.3, help='Dropout for the convolutional features. Default: 0.3.')
    # The historical single-dash spellings stay valid; the double-dash
    # aliases match every other option. The destination names remain
    # num_filters / kernel_size.
    parser.add_argument('-num-filters', '--num-filters', default=32, type=int, help='Number of filters for convolution')
    parser.add_argument('-kernel-size', '--kernel-size', default=3, type=int, help='Kernel Size for convolution')

    # store_true means the flag is off unless given on the command line.
    parser.add_argument('--use-bias', action='store_true', help='Use a bias in the convolutional layer. Default: False')

    parser.add_argument('--reg-weight', type=float, default=5e-2, help='Weight for regularization. Default: 5e-2')
    parser.add_argument('--reg-norm', type=int, default=3, help='Norm for regularization. Default: 3')
    # parser.add_argument('--resume', action='store_true', help='Restore a saved model.')
    # parser.add_argument('--resume-split', type=str, default='test', help='Split to evaluate a restored model')
    # parser.add_argument('--reproduce-results', action='store_true', help='Use the hyperparameters to reproduce the results.')
    # parser.add_argument('--original-data', type=str, default='FB15k-237', help='Dataset to use; this option is needed to set the hyperparams to reproduce the results for training after attack, default: FB15k-237')
    return parser
def set_hyperparams(args):
    """Overwrite hyperparameters on ``args`` with fixed per-model values.

    The overrides for the model named by ``args.model`` are applied first
    (none when the name is not ``'distmult'``, ``'complex'`` or ``'conve'``),
    then the LiSSA-related settings are assigned for every model.

    Args:
        args: Namespace to modify in place; its ``model`` attribute is read.

    Returns:
        The same ``args`` object.
    """
    per_model_overrides = (
        ('distmult', {'lr': 0.005, 'train_batch_size': 1024, 'reg_norm': 3}),
        ('complex', {'lr': 0.005, 'reg_norm': 3, 'input_drop': 0.4, 'train_batch_size': 1024}),
        ('conve', {'lr': 0.005, 'train_batch_size': 1024, 'reg_weight': 0.0}),
    )
    for model_name, overrides in per_model_overrides:
        if args.model == model_name:
            for attr, value in overrides.items():
                setattr(args, attr, value)
            break

    # Previously used settings, kept for reference:
    # damping=0.01, lissa_repeat=1, lissa_depth=1, scale=500, lissa_batch_size=100
    shared_settings = {
        'damping': 0.01,
        'lissa_repeat': 1,
        'lissa_depth': 1,
        'scale': 400,
        'lissa_batch_size': 300,
    }
    for attr, value in shared_settings.items():
        setattr(args, attr, value)
    return args
|