# demo_model/code/gnn_1/train_gnn_1.py
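# Trains the GNN-1 stage: a table-level SCC/regex classifier plus a
# row/column gid tagger on top of a MatSciBERT encoder, then turns the
# predictions into composition tuples and reports tuple- and
# composition-level metrics against the gold annotations.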
from argparse import ArgumentParser
import math
import multiprocessing as mp
import os
import pickle
import sys
sys.path.append('..')
from sklearn.metrics import precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import set_seed, get_linear_schedule_with_warmup
from gnn_1_model import GNN_1_Model as Model
from regex_lib import parse_composition
from utils import *
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('using device:', device)
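# Command-line hyperparameters. --hidden_layer_sizes and --num_heads take
# a list of values (presumably one per GNN layer); --lr drives the GNN and
# head parameters while the smaller --lm_lr fine-tunes the LM encoder.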
parser = ArgumentParser()
parser.add_argument('--seed', required=True, type=int)
parser.add_argument('--hidden_layer_sizes', nargs='+', required=True, type=int)
parser.add_argument('--num_heads', nargs='+', required=True, type=int)
parser.add_argument('--num_epochs', required=False, default=15, type=int)
parser.add_argument('--lr', required=False, default=1e-3, type=float)
parser.add_argument('--lm_lr', required=False, default=1e-5, type=float)
parser.add_argument('--use_regex_feat', action='store_true')
parser.add_argument('--use_max_freq_feat', action='store_true')
parser.add_argument('--add_constraint', action='store_true')
parser.add_argument('--regex_emb_size', required=False, default=256, type=int)
parser.add_argument('--max_freq_emb_size', required=False, default=256, type=int)
parser.add_argument('--c_loss_lambda', required=False, default=50.0, type=float)
parser.add_argument('--gid_loss_lambda', required=False, default=1.0, type=float)
parser.add_argument('--model_save_file', required=True, type=str)
parser.add_argument('--res_file', required=False, type=str)
args = parser.parse_args()
print(args)
lm_name = 'm3rg-iitd/matscibert'
cache_dir = os.path.join(table_dir, '.cache')
os.makedirs(os.path.dirname(os.path.abspath(args.model_save_file)), exist_ok=True)
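# Optionally pre-compute per-table features from the raw cell text
# (comp_data and the feature helpers come from utils).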
if args.use_regex_feat:
    for c in tqdm(comp_data):
        c['regex_feats'] = get_regex_feats(c['act_table'])
if args.use_max_freq_feat:
    for c in comp_data:
        c['max_freq_feat'] = get_max_freq_feat(c['act_table'])
torch.set_deterministic(True)  # deprecated alias; newer PyTorch uses torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
datasets = dict()
for split in splits:
    datasets[split] = TableDataset([comp_data_dict[pii_t_idx] for pii_t_idx in train_val_test_split[split]])
set_seed(args.seed)
batch_size = 8
num_workers = mp.cpu_count()
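# collate_fn=lambda x: x keeps each batch as a plain list of table dicts;
# presumably the model does its own batching/graph construction internally.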
loaders = dict()
for split in splits:
    loaders[split] = DataLoader(datasets[split], batch_size=batch_size, shuffle=(split == 'train'),
                                num_workers=num_workers, collate_fn=lambda x: x)
all_train_regex_labels = [x['regex_table'] for x in datasets['train']]
all_train_gid_labels = []
for x in datasets['train']:
    all_train_gid_labels += x['gid_row_label'] + x['gid_col_label']
num_epochs = args.num_epochs
n_batches = math.ceil(len(datasets['train']) / batch_size)
n_steps = n_batches * num_epochs
warmup_steps = n_steps // 10
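# Linear LR warmup over the first 10% of training steps, then linear decay
# (the scheduler is constructed after the optimizer below).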
model_args = {
'hidden_layer_sizes': args.hidden_layer_sizes,
'num_heads': args.num_heads,
'lm_name': lm_name,
'cache_dir': cache_dir,
'use_regex_feat': args.use_regex_feat,
'use_max_freq_feat': args.use_max_freq_feat,
'add_constraint': args.add_constraint,
'regex_emb_size': args.regex_emb_size,
'max_freq_emb_size': args.max_freq_emb_size,
}
model = Model(model_args).to(device)
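# Two parameter groups: the LM encoder gets the small --lm_lr, everything
# else uses --lr.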
optim_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if 'encoder' not in n], 'lr': args.lr},
{'params': [p for n, p in model.named_parameters() if 'encoder' in n], 'lr': args.lm_lr},
]
optim = torch.optim.AdamW(optim_grouped_parameters)
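# Class-balanced cross-entropy for both heads; weights are estimated from
# the training-split label distribution (class_weight is passed by keyword
# since sklearn 1.0 made these arguments keyword-only).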
regex_class_weights = torch.Tensor(compute_class_weight(class_weight='balanced', classes=[0, 1], y=all_train_regex_labels)).to(device)
regex_loss_fn = nn.CrossEntropyLoss(weight=regex_class_weights)
gid_class_weights = torch.Tensor(compute_class_weight(class_weight='balanced', classes=[0, 1], y=all_train_gid_labels)).to(device)
gid_loss_fn = nn.CrossEntropyLoss(weight=gid_class_weights)
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=warmup_steps, num_training_steps=n_steps)
gold_tuples = dict()
for split in ['val', 'test']:
    gold_tuples[split] = []
    for pii, t_idx in train_val_test_split[split]:
        gold_tuples[split] += get_gold_tuples(pii, t_idx)
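# Turn a regex-parsed table into prediction tuples of the form
# (cell-id string [+ gid suffix], constituent, amount, pred_cell_mol_wt(...));
# a predicted gid row/column contributes its cell text as the suffix.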
def get_pred_tuples(pii_t_idx: tuple, regex_table: list, orient: str, gid):
    gid_list = []
    c = comp_data_dict[pii_t_idx]
    table = c['act_table']
    pii, t_idx = pii_t_idx
    if orient == 'row':
        for i in range(len(table)):
            if gid is not None and table[i][gid]:
                gid_list.append('_' + table[i][gid])
            else:
                gid_list.append('')
    else:
        for j in range(len(table[0])):
            if gid is not None and table[gid][j]:
                gid_list.append('_' + table[gid][j])
            else:
                gid_list.append('')
    tuples = []
    for i in range(len(table)):
        for j in range(len(table[0])):
            if regex_table[i][j] is None: continue
            prefix = f'{pii}_{t_idx}_{i}_{j}_0'
            for x in regex_table[i][j]:
                if x[1] == 0: continue
                gid = gid_list[i] if orient == 'row' else gid_list[j]
                tuples.append((prefix + gid, x[0], x[1], pred_cell_mol_wt(c, i, j)))
    return tuples
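# Regex-parse every cell; the table counts as SCC (label 1) if any cell
# parses to an all-float composition. Orientation is 'row' if the densest
# single column holds at least as many parsed cells as the densest single
# row (i.e. compositions stack vertically), else 'col'.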
def get_regex_table_and_orient(table):
    regex_table = []
    regex_label = 0
    for r in table:
        res_r = []
        for cell in r:
            comp = parse_composition(cell)
            if len(comp) == 0 or len(comp[0][0]) == 1:
                res_r.append(None)
                continue
            l = comp[0][0]
            new_l = []
            for x in l:
                if type(x[1]) == float:
                    x = (x[0], round(x[1], 5))
                elif type(x[1]) == int:
                    x = (x[0], float(x[1]))
                new_l.append(x)
            if all(type(x[1]) == float for x in new_l):
                regex_label = 1
                res_r.append(new_l)
            else:
                res_r.append(None)
        regex_table.append(res_r)
    if regex_label == 0:
        return None, None
    row_max = 0
    for r in range(len(table)):
        curr = 0
        for comp in regex_table[r]:
            if type(comp) == list:
                curr += 1
        row_max = max(row_max, curr)
    col_max = 0
    for c in range(len(table[0])):
        curr = 0
        for r in range(len(table)):
            if type(regex_table[r][c]) == list:
                curr += 1
        col_max = max(col_max, curr)
    if row_max <= col_max:
        return regex_table, 'row'
    return regex_table, 'col'
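# For a table predicted SCC, pick at most one gid line: a gid column when
# orient == 'row', a gid row otherwise. The argmax candidate is accepted
# only if its softmax probability exceeds 0.5.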
def get_gid_labels_and_tuples(gid_logits, scc_label: int, pii_t_idx: tuple, num_rows: int, num_cols: int):
    row_gid_labels, col_gid_labels = [0] * num_rows, [0] * num_cols
    if scc_label == 0:
        return row_gid_labels + col_gid_labels, []
    regex_table, orient = get_regex_table_and_orient(comp_data_dict[pii_t_idx]['act_table'])
    if orient is None:
        return row_gid_labels + col_gid_labels, []
    gid = None
    if orient == 'row':
        gid_col_probs = F.softmax(gid_logits[num_rows:], dim=1)
        gid_idx = gid_col_probs[:, 1].argmax()
        if gid_col_probs[gid_idx, 1] > 0.5:
            col_gid_labels[gid_idx] = 1
            gid = gid_idx
    else:
        gid_row_probs = F.softmax(gid_logits[:num_rows], dim=1)
        gid_idx = gid_row_probs[:, 1].argmax()
        if gid_row_probs[gid_idx, 1] > 0.5:
            row_gid_labels[gid_idx] = 1
            gid = gid_idx
    return row_gid_labels + col_gid_labels, get_pred_tuples(pii_t_idx, regex_table, orient, gid)
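# Batched wrapper: the gid logits for a batch are concatenated as a row
# block followed by a column block per table, so slice r + c logits per
# table in order.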
def get_batch_gid_labels_and_tuples(gid_logits, scc_labels: list, pii_t_idxs: list, num_rows: list, num_cols: list):
    base_gid = 0
    pred_gid_labels, pred_tuples = [], []
    for pii_t_idx, regex_label, r, c in zip(pii_t_idxs, scc_labels, num_rows, num_cols):
        num_gid_logits = r + c
        gids_labels, tuples = get_gid_labels_and_tuples(
            gid_logits[base_gid:base_gid+num_gid_logits], regex_label, pii_t_idx, r, c)
        pred_gid_labels += gids_labels
        pred_tuples.append(tuples)
        base_gid += num_gid_logits
    return pred_gid_labels, pred_tuples
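# Loss bookkeeping: the regex (SCC) loss has fixed weight 1.0; the gid and
# constraint losses are scaled by their CLI lambdas.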
losses = ['regex', 'gid', 'constraint']
coeffs = [1.0, args.gid_loss_lambda, args.c_loss_lambda]
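# One training epoch. The constraint loss is produced inside the model's
# forward pass; its coefficient is zeroed for the first 3 epochs,
# presumably to let the two classification heads stabilize first.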
def train_model(epoch):
    model.train()
    epoch_loss = {l: 0.0 for l in losses}
    curr_coeffs = coeffs.copy()
    if epoch < 3:
        curr_coeffs[2] = 0.0
    n_batches = len(loaders['train'])
    tepoch = tqdm(loaders['train'], unit='batch')
    batch_loss = dict()
    for batch_data in tepoch:
        tepoch.set_description(f'Epoch {epoch}')
        torch.cuda.empty_cache()
        (scc_logits, scc_labels), (gid_logits, gid_labels), (batch_loss[losses[2]],) = model(batch_data)
        batch_loss[losses[0]] = regex_loss_fn(scc_logits, scc_labels)
        batch_loss[losses[1]] = gid_loss_fn(gid_logits, gid_labels)
        for l in losses:
            epoch_loss[l] += batch_loss[l].item()
        loss = sum(curr_coeffs[i] * batch_loss[losses[i]] for i in range(len(losses)))
        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()
        del scc_logits, scc_labels, gid_logits, gid_labels
    for l in losses:
        epoch_loss[l] /= n_batches
    return epoch_loss
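# Evaluation: SCC and gid precision/recall/F1 plus tuple- and
# composition-level metrics. debug=True additionally returns raw
# predictions and repeats the gid/tuple decoding with gold SCC labels in
# place of the predicted ones.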
def eval_model(split, debug=False):
    model.eval()
    identifier = []
    y_scc_true, y_scc_pred = [], []
    y_gids_true, y_gids_pred, ret_gids_pred = [], [], []
    ret_tuples_pred = []
    y_true_scc_gids, ret_true_scc_gids, ret_true_scc_tuples = [], [], []
    with torch.no_grad():
        tepoch = tqdm(loaders[split], unit='batch')
        for batch_data in tepoch:
            tepoch.set_description(f'{split} mode')
            (scc_logits, scc_labels), (gid_logits, gid_labels) = model(batch_data)
            true_regex_labels = scc_labels.cpu().detach().tolist()
            pred_regex_labels = scc_logits.argmax(1).cpu().detach().tolist()
            y_scc_true += true_regex_labels
            y_scc_pred += pred_regex_labels
            y_gids_true += gid_labels.cpu().detach().tolist()
            pred_gid_labels = gid_logits.argmax(1).cpu().detach().tolist()
            base = 0
            for p, x in zip(pred_regex_labels, batch_data):
                if p == 1:
                    y_gids_pred += pred_gid_labels[base:base+x['num_rows']+x['num_cols']]
                else:
                    y_gids_pred += [0] * (x['num_rows'] + x['num_cols'])
                base += x['num_rows'] + x['num_cols']
            if debug:
                base = 0
                for p, x in zip(true_regex_labels, batch_data):
                    if p == 1:
                        y_true_scc_gids += pred_gid_labels[base:base+x['num_rows']+x['num_cols']]
                    else:
                        y_true_scc_gids += [0] * (x['num_rows'] + x['num_cols'])
                    base += x['num_rows'] + x['num_cols']
            num_rows, num_cols = [x['num_rows'] for x in batch_data], [x['num_cols'] for x in batch_data]
            pii_t_idxs = [(x['pii'], x['t_idx']) for x in batch_data]
            identifier += pii_t_idxs
            pred_gid_labels, pred_tuples = get_batch_gid_labels_and_tuples(
                gid_logits.cpu().detach(), pred_regex_labels, pii_t_idxs, num_rows, num_cols)
            ret_tuples_pred += pred_tuples
            if not debug: continue
            base_gid = 0
            for x in batch_data:
                gid_dict = dict()
                gid_dict['row'] = pred_gid_labels[base_gid:base_gid+x['num_rows']]
                base_gid += x['num_rows']
                gid_dict['col'] = pred_gid_labels[base_gid:base_gid+x['num_cols']]
                base_gid += x['num_cols']
                ret_gids_pred.append(gid_dict)
            pred_gid_labels, pred_tuples = get_batch_gid_labels_and_tuples(
                gid_logits.cpu().detach(), true_regex_labels, pii_t_idxs, num_rows, num_cols)
            ret_true_scc_tuples += pred_tuples
            base_gid = 0  # fix: reset the offset before slicing the freshly rebuilt label list
            for x in batch_data:
                gid_dict = dict()
                gid_dict['row'] = pred_gid_labels[base_gid:base_gid+x['num_rows']]
                base_gid += x['num_rows']
                gid_dict['col'] = pred_gid_labels[base_gid:base_gid+x['num_cols']]
                base_gid += x['num_cols']
                ret_true_scc_gids.append(gid_dict)
    prec, recall, fscore, _ = precision_recall_fscore_support(y_scc_true, y_scc_pred, average='binary')
    scc_metrics = {'precision': prec, 'recall': recall, 'fscore': fscore}
    prec, recall, fscore, _ = precision_recall_fscore_support(y_gids_true, y_gids_pred, average='binary')
    gids_metrics = {'precision': prec, 'recall': recall, 'fscore': fscore}
    all_pred_tuples = []
    for t in ret_tuples_pred:
        all_pred_tuples += t
    tuple_metrics = get_tuples_metrics(gold_tuples[split], all_pred_tuples)
    composition_metrics = get_composition_metrics(gold_tuples[split], all_pred_tuples)
    if not debug:
        return scc_metrics, gids_metrics, tuple_metrics, composition_metrics
    return identifier, (scc_metrics, y_scc_pred), (gids_metrics, y_gids_pred, ret_gids_pred), \
        (y_true_scc_gids, ret_true_scc_gids), (tuple_metrics, ret_tuples_pred), \
        (ret_true_scc_tuples,), composition_metrics
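# Main loop: evaluate on val/test each epoch and checkpoint whenever the
# val composition fscore (the last element of val_stats) improves.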
best_val = 0.0
for epoch in range(num_epochs):
    epoch_loss = train_model(epoch)
    print(f'Epoch {epoch} | Loss {epoch_loss}')
    val_stats = eval_model('val')
    print('Val Stats\n', val_stats)
    test_stats = eval_model('test')
    print('Test Stats\n', test_stats)
    print()
    if val_stats[-1]['fscore'] > best_val:
        best_val = val_stats[-1]['fscore']
        torch.save(model.state_dict(), args.model_save_file)
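# Reload the best val-selected checkpoint for the final debug-mode pass.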
model.load_state_dict(torch.load(args.model_save_file, map_location=torch.device('cpu')))
model = model.to(device)
res = {'val': dict(), 'test': dict()}
print()
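# Debug-mode pass over both splits: collect raw predictions, reshape the
# flat gid label lists into per-table {'row': [...], 'col': [...]} dicts,
# and count 3_3 constraint violations on the predicted gids.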
for s in res.keys():
    res[s]['identifier'], (res[s]['scc_stats'], res[s]['scc_pred']), (res[s]['gid_stats'], \
        res[s]['gid_pred_orig'], res[s]['gid_pred']), (res[s]['true_scc_gid_pred_orig'], \
        res[s]['true_scc_gid_pred']), (res[s]['tuple_metrics'], res[s]['tuples_pred']), \
        (res[s]['true_scc_tuples_pred'],), res[s]['composition_metrics'] = eval_model(s, debug=True)
    print(f'{s} scc: \n', res[s]['scc_stats'])
    print(f'{s} gid: \n', res[s]['gid_stats'])
    print(f'{s} tuple metrics: \n', res[s]['tuple_metrics'])
    print(f'{s} composition metrics: \n', res[s]['composition_metrics'])
    for k in ['gid_pred_orig', 'true_scc_gid_pred_orig']:
        gid_pred_orig = []
        base = 0
        for pii_t_idx in res[s]['identifier']:
            c = comp_data_dict[pii_t_idx]
            d = dict()
            d['row'] = res[s][k][base:base+c['num_rows']]
            base += c['num_rows']
            d['col'] = res[s][k][base:base+c['num_cols']]
            base += c['num_cols']
            gid_pred_orig.append(d)
        res[s][k] = gid_pred_orig
    violations, total = 0, 0
    for table in res[s]['gid_pred_orig']:
        v, t = cnt_3_3_violations(table)
        violations += v
        total += t
    print(f'{s} 3_3_violations: {violations}/{total}')
    res[s]['3_3_violations'] = violations
if args.res_file:
    os.makedirs(os.path.join(table_dir, 'res_dir'), exist_ok=True)
    with open(os.path.join(table_dir, 'res_dir', args.res_file), 'wb') as f:
        pickle.dump(res, f)
# os.remove(args.model_save_file)