# demo_model/code/gnn_2/gnn_2_model.py
import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor, LongTensor, tensor
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel
import dgl
from dgl.nn import GATConv
# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class GNN_2_Model(nn.Module):
    # GAT-based table model: a pretrained LM encodes each cell (and,
    # optionally, the caption), GAT layers pass messages over a graph of
    # cell / row / column (/ caption) nodes, and two heads predict
    # row/column classes and cell-cell edges.

    @staticmethod
    def get_all_pair(ar):
        # All ordered pairs (self-pairs included) of the entries of `ar`.
        return np.array(np.meshgrid(ar, ar)).T.reshape(-1, 2)
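
    # Example (illustrative): get_all_pair(np.arange(2)) returns
    # [[0, 0], [0, 1], [1, 0], [1, 1]], i.e. self-pairs are included.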
    @staticmethod
    def get_edges(r, c, ret_extra_edges=False, ret_caption_edges=False):
        # Intra-table edges of an r x c grid: every cell is connected to
        # every cell in its row and in its column (self-loops included).
        # Cells are numbered row-major, 0 .. r*c - 1.
        edges = np.empty((0, 2), dtype=int)
        row_edges = GNN_2_Model.get_all_pair(np.arange(c))
        for i in range(r):
            edges = np.concatenate((edges, row_edges + i * c), axis=0)
        col_edges = GNN_2_Model.get_all_pair(np.arange(0, r * c, c))
        for i in range(c):
            edges = np.concatenate((edges, col_edges + i), axis=0)
        edges = np.unique(edges, axis=0)
        table_edges = LongTensor(edges[np.lexsort((edges[:, 1], edges[:, 0]))])
        if ret_extra_edges:
            # Auxiliary nodes: row nodes r*c .. r*c + r - 1 and column nodes
            # r*c + r .. r*c + r + c - 1. Each cell points to its row node
            # and its column node, and the auxiliary nodes get self-loops.
            table_cells = torch.arange(r * c)
            row_edges = torch.stack([table_cells, r * c + table_cells // c]).T
            col_edges = torch.stack([table_cells, r * c + r + table_cells % c]).T
            row_self_edges = torch.stack([r * c + torch.arange(r), r * c + torch.arange(r)]).T
            col_self_edges = torch.stack([r * c + r + torch.arange(c), r * c + r + torch.arange(c)]).T
            if not ret_caption_edges:
                return table_edges, torch.cat([row_edges, col_edges, row_self_edges, col_self_edges])
            # With a caption, the cell->row/col edges become bidirectional and
            # a caption node (index r*c + r + c) is linked to every row node,
            # every column node, and itself.
            row_col_edges = torch.cat([row_edges, col_edges, torch.flip(row_edges, (1,)), torch.flip(col_edges, (1,)), row_self_edges, col_self_edges])
            caption_edges = torch.stack([(r * c + r + c) * torch.ones(r + c + 1).long(), r * c + torch.arange(r + c + 1)]).T
            return table_edges, torch.cat([row_col_edges, caption_edges])
        return table_edges
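
    # Node layout for r = 2, c = 2 (illustrative): cell nodes 0..3 in
    # row-major order, row nodes 4..5, column nodes 6..7, and caption node 8
    # when ret_caption_edges is set.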
    @staticmethod
    def create_edge_labels(all_edges, edge_list):
        # Mark each candidate edge with 1 if it appears in the ground-truth
        # edge_list (a list of (src, dst) tuples), else 0.
        df = pd.DataFrame(all_edges.tolist())
        df['merge'] = list(zip(df[0], df[1]))
        edge_labels = LongTensor(df['merge'].isin(edge_list).tolist())
        return edge_labels
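
    # Example (illustrative):
    #   create_edge_labels(LongTensor([[0, 1], [0, 2]]), [(0, 1)])
    #   -> tensor([1, 0])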
    @staticmethod
    def get_all_pairs_torch(n, ordered=False):
        # Unordered pairs i < j of 0..n-1; with ordered=True, both (i, j)
        # and (j, i) are returned.
        if ordered:
            return torch.cat([torch.combinations(torch.arange(n)), torch.combinations(torch.arange(n - 1, -1, -1))])
        return torch.combinations(torch.arange(n))
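
    # Example (illustrative): get_all_pairs_torch(3) yields (0, 1), (0, 2),
    # (1, 2); with ordered=True the reversed pairs (2, 1), (2, 0), (1, 0)
    # are appended as well.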
    @staticmethod
    def get_block(h_in, h_out):
        # Linear -> BatchNorm -> ReLU -> Dropout building block for the heads.
        return nn.Sequential(
            nn.Linear(h_in, h_out),
            nn.BatchNorm1d(h_out),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
    @staticmethod
    def validate_args(args):
        # Sanity-check the configuration before building any modules.
        assert isinstance(args['hidden_layer_sizes'], list)
        assert isinstance(args['num_heads'], list)
        assert len(args['hidden_layer_sizes']) == len(args['num_heads'])
        assert isinstance(args['use_caption'], bool)
        assert isinstance(args['add_constraint'], bool)
        assert isinstance(args['use_max_freq_feat'], bool)
        assert isinstance(args['max_freq_emb_size'], int)
        return args
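
    # Illustrative args (a sketch: the key set matches the checks above and
    # the lookups in __init__, but the specific values are assumptions):
    # args = {
    #     'lm_name': 'bert-base-uncased',
    #     'cache_dir': None,
    #     'hidden_layer_sizes': [128, 128],
    #     'num_heads': [4, 4],
    #     'use_caption': True,
    #     'add_constraint': True,
    #     'use_max_freq_feat': False,
    #     'max_freq_emb_size': 32,
    # }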
    def __init__(self, args: dict):
        super(GNN_2_Model, self).__init__()
        self.args = self.validate_args(args)
        # Pretrained language model that encodes each cell (and the caption).
        config = AutoConfig.from_pretrained(args['lm_name'], cache_dir=args['cache_dir'])
        self.encoder = AutoModel.from_pretrained(args['lm_name'], config=config, cache_dir=args['cache_dir'])
        in_dim = config.hidden_size
        # Learned fallback vector used to initialise the auxiliary nodes.
        self.default_embedding = nn.Embedding(1, in_dim)
        # Positional ids: 0 = auxiliary node, 1/2/3 = first/second/later
        # row or column.
        self.positional_embeddings = nn.Embedding(4, config.hidden_size)
        if self.args['use_max_freq_feat']:
            self.max_freq_feat_embedding = nn.Embedding(6, self.args['max_freq_emb_size'])
            in_dim += self.args['max_freq_emb_size']
        # Stack of GAT layers; layer l consumes the concatenated heads of l-1.
        self.gat_layers = nn.ModuleList()
        self.gat_layers.append(
            GATConv(in_dim, self.args['hidden_layer_sizes'][0], num_heads=self.args['num_heads'][0], residual=False))
        for l in range(1, len(self.args['hidden_layer_sizes'])):
            self.gat_layers.append(
                GATConv(self.args['hidden_layer_sizes'][l - 1] * self.args['num_heads'][l - 1],
                        self.args['hidden_layer_sizes'][l],
                        num_heads=self.args['num_heads'][l], residual=True))
        out_dim = self.args['hidden_layer_sizes'][-1] * self.args['num_heads'][-1]
        self.dropout = nn.Dropout(0.2)
        # 4-way classification head over row/column nodes (comp_and_gid) and
        # a binary edge scorer over concatenated cell-pair embeddings.
        self.comp_and_gid_layer = nn.Sequential(self.get_block(out_dim, 256), nn.Linear(256, 4))
        self.edge_layer = nn.Sequential(self.get_block(2 * out_dim, 256), nn.Linear(256, 1))
    def _encoder_forward(self, input_ids, attention_mask):
        # Pad the variable-length token id lists to a common length and take
        # the encoder's first-token ([CLS]) representation of each sequence.
        lm_inp = {'input_ids': input_ids, 'attention_mask': attention_mask}
        max_len = max(len(s) for s in lm_inp['input_ids'])
        for k in lm_inp.keys():
            lm_inp[k] = [s + [0] * (max_len - len(s)) for s in lm_inp[k]]
            lm_inp[k] = LongTensor(lm_inp[k]).to(device)
        return self.encoder(**lm_inp)[0][:, 0]
    def concat_max_freq_feat_embedding(self, inps, h):
        # Broadcast each row's / column's max-frequency feature to its cells,
        # clip the counts at 5, embed them, and append to the node features.
        all_row_max_freq, all_col_max_freq = [], []
        for x in inps:
            row_max_freq, col_max_freq = x['max_freq_feat'][:x['num_rows']], x['max_freq_feat'][-x['num_cols']:]
            row_max_freq = LongTensor(row_max_freq).unsqueeze(1).expand(x['num_rows'], x['num_cols']).reshape(-1)
            col_max_freq = LongTensor(col_max_freq).unsqueeze(0).expand(x['num_rows'], x['num_cols']).reshape(-1)
            all_row_max_freq += row_max_freq.tolist() + x['max_freq_feat'][:x['num_rows']] + [0] * x['num_cols']
            all_col_max_freq += col_max_freq.tolist() + [0] * x['num_rows'] + x['max_freq_feat'][-x['num_cols']:]
            if self.args['use_caption']:
                all_row_max_freq.append(0)
                all_col_max_freq.append(0)
        all_row_max_freq, all_col_max_freq = LongTensor(all_row_max_freq), LongTensor(all_col_max_freq)
        all_row_max_freq[all_row_max_freq > 5] = 5
        all_col_max_freq[all_col_max_freq > 5] = 5
        max_freq_feats = self.max_freq_feat_embedding(all_row_max_freq.to(device)) + self.max_freq_feat_embedding(all_col_max_freq.to(device))
        h = torch.cat([h, max_freq_feats], dim=1)
        return h
    def calc_constraint_loss(self, inps, row_col_gid_logits):
        # Soft mutual-exclusion penalties between predicted row/column
        # classes: each term p_a + p_b - 1 is positive only when two
        # incompatible labels are both assigned high probability.
        comp_gid_probs = F.softmax(row_col_gid_logits, dim=1)
        base = 0
        constraints = {'1_2': [], '1_3': [], '2_3': [], '3_3': []}
        for x in inps:
            row_probs = comp_gid_probs[base:base + x['num_rows']].unsqueeze(1)
            base += x['num_rows']
            col_probs = comp_gid_probs[base:base + x['num_cols']].unsqueeze(0)
            base += x['num_cols']
            constraints['1_2'].append((row_probs[:, :, 1:3] + col_probs[:, :, 1:3] - 1).flatten())
            row_pairs = self.get_all_pairs_torch(x['num_rows'], ordered=True)
            col_pairs = self.get_all_pairs_torch(x['num_cols'], ordered=True)
            constraints['1_3'] += [
                (row_probs[:, 0, 1][row_pairs[:, 0]] + row_probs[:, 0, 3][row_pairs[:, 1]] - 1).flatten(),
                (col_probs[0, :, 1][col_pairs[:, 0]] + col_probs[0, :, 3][col_pairs[:, 1]] - 1).flatten(),
            ]
            constraints['2_3'] += [
                (row_probs[:, :, 2] + col_probs[:, :, 3] - 1).flatten(),
                (row_probs[:, :, 3] + col_probs[:, :, 2] - 1).flatten(),
            ]
            row_col_pairs = self.get_all_pairs_torch(x['num_rows'] + x['num_cols'])
            gid_probs = torch.cat([row_probs[:, 0, 3], col_probs[0, :, 3]])
            constraints['3_3'].append(gid_probs[row_col_pairs[:, 0]] + gid_probs[row_col_pairs[:, 1]] - 1)
        constraint_loss = tensor(0.0).to(device)
        for c in constraints.keys():
            # Hinge on each violation and average within a constraint family.
            constraint_loss += F.relu(torch.cat(constraints[c])).mean()
        return constraint_loss
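
    # In effect each family contributes mean(relu(p_a + p_b - 1)), which is
    # nonzero only when two incompatible row/column labels both receive high
    # probability; the class pairings involved (1 vs 2, 1 vs 3, 2 vs 3,
    # 3 vs 3) follow the family keys above.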
    def forward(self, inps):
        # Encode every cell of every table in the batch, in chunks of 160
        # sequences to bound peak memory.
        lm_inp = {'input_ids': [], 'attention_mask': []}
        for x in inps:
            for k in lm_inp.keys():
                lm_inp[k] += x[k]
        embs = []
        for idx in range(0, len(lm_inp['input_ids']), 160):
            embs.append(self._encoder_forward(lm_inp['input_ids'][idx:idx + 160], lm_inp['attention_mask'][idx:idx + 160]))
        del lm_inp
        cell_h = torch.cat(embs)
        del embs
        if self.args['use_caption']:
            # One caption embedding per table.
            lm_inp = {'input_ids': [], 'attention_mask': []}
            for x in inps:
                for k in lm_inp.keys():
                    lm_inp[k] += x[f'caption_{k}']
            caption_h = self._encoder_forward(lm_inp['input_ids'], lm_inp['attention_mask'])
            del lm_inp
        # Boolean masks locating each node type in the flattened batch graph,
        # plus row/column positional ids (0 for auxiliary nodes, 1/2/3 for
        # the first/second/later rows and columns).
        mask_keys = ['cell', 'row', 'col']
        if self.args['use_caption']:
            mask_keys.append('caption')
        mask = {k: [] for k in mask_keys}
        row_positional_idxs, col_positional_idxs = [], []
        for x in inps:
            mask['cell'] += [1] * x['num_cells'] + [0] * (x['num_rows'] + x['num_cols'])
            mask['row'] += [0] * x['num_cells'] + [1] * x['num_rows'] + [0] * x['num_cols']
            mask['col'] += [0] * (x['num_cells'] + x['num_rows']) + [1] * x['num_cols']
            table_cells = np.arange(x['num_cells']).reshape(x['num_rows'], x['num_cols'])
            row_nums = table_cells // x['num_cols']
            row_nums[row_nums > 2] = 2
            row_nums += 1
            row_positional_idxs += row_nums.flatten().tolist() + row_nums[:, 0].tolist() + [0] * x['num_cols']
            col_nums = table_cells % x['num_cols']
            col_nums[col_nums > 2] = 2
            col_nums += 1
            col_positional_idxs += col_nums.flatten().tolist() + [0] * x['num_rows'] + col_nums[0].tolist()
            if self.args['use_caption']:
                mask['caption'] += [0] * (x['num_cells'] + x['num_rows'] + x['num_cols'])
                for k in mask_keys:
                    mask[k].append(1 if k == 'caption' else 0)
                row_positional_idxs.append(0)
                col_positional_idxs.append(0)
        for k in mask_keys:
            mask[k] = Tensor(mask[k]).bool().to(device)
        # Initialise node features: encoder outputs for cells (and caption),
        # the learned default vector for row/column nodes.
        h = self.default_embedding(LongTensor([0] * len(mask['cell'])).to(device))
        h[mask['cell']] = cell_h
        del cell_h
        if self.args['use_caption']:
            h[mask['caption']] = caption_h
            del caption_h
        h += self.positional_embeddings(LongTensor(row_positional_idxs).to(device)) + \
             self.positional_embeddings(LongTensor(col_positional_idxs).to(device))
        if self.args['use_max_freq_feat']:
            h = self.concat_max_freq_feat_embedding(inps, h)
        # Assemble one DGL graph over the whole batch; `base` offsets each
        # table's node ids.
        base = 0
        batch_all_edges, batch_table_edges, batch_edge_labels = [], [], []
        for x in inps:
            table_edges, extra_edges = self.get_edges(x['num_rows'], x['num_cols'], ret_extra_edges=True, ret_caption_edges=self.args['use_caption'])
            batch_all_edges.append(torch.cat([table_edges, extra_edges]) + base)
            batch_table_edges.append(table_edges + base)
            batch_edge_labels.append(self.create_edge_labels(table_edges, x['edge_list']))
            base += x['num_cells'] + x['num_rows'] + x['num_cols']
            if self.args['use_caption']:
                base += 1
        batch_all_edges = torch.cat(batch_all_edges)
        batch_table_edges = torch.cat(batch_table_edges).to(device)
        batch_edge_labels = torch.cat(batch_edge_labels).to(device)
        batch_g = dgl.graph((batch_all_edges[:, 0], batch_all_edges[:, 1])).to(device)
        # Message passing: ELU between layers, attention heads concatenated.
        for l in range(len(self.gat_layers)):
            h = F.elu(self.gat_layers[l](batch_g, h)).flatten(1)
        h = self.dropout(h)
        # Row/column classification over the auxiliary nodes.
        batch_row_col_gid_labels = []
        for x in inps:
            batch_row_col_gid_labels += x['row_label'] + x['col_label']
        batch_row_col_gid_labels = LongTensor(batch_row_col_gid_labels).to(device)
        row_col_gid_logits = self.comp_and_gid_layer(h[mask['row'] | mask['col']])
        assert len(batch_row_col_gid_labels) == len(row_col_gid_logits)
        # Binary edge scores from concatenated endpoint embeddings.
        edge_embs = torch.cat([h[batch_table_edges[:, 0]], h[batch_table_edges[:, 1]]], dim=1)
        edge_logits = self.edge_layer(edge_embs).squeeze(-1)
        assert len(batch_edge_labels) == len(edge_logits)
        ret = (row_col_gid_logits, batch_row_col_gid_labels), (edge_logits, batch_edge_labels)
        if self.training:
            # During training a third element, the constraint loss, is added.
            if self.args['add_constraint']:
                constraint_loss = self.calc_constraint_loss(inps, row_col_gid_logits)
            else:
                constraint_loss = tensor(0.0).to(device)
            ret += ((constraint_loss,),)
        return ret
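

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original pipeline):
    # 'bert-base-uncased' is an assumed encoder choice (downloaded on first
    # run), the hyperparameters are arbitrary, and the token ids below are
    # placeholders rather than real tokenised cell text.
    args = {
        'lm_name': 'bert-base-uncased',
        'cache_dir': None,
        'hidden_layer_sizes': [64],
        'num_heads': [2],
        'use_caption': False,
        'add_constraint': False,
        'use_max_freq_feat': False,
        'max_freq_emb_size': 16,
    }
    model = GNN_2_Model(args).to(device).eval()
    # One 2 x 2 table: four cells, each "tokenised" as [CLS] tok [SEP].
    table = {
        'input_ids': [[101, 7592, 102]] * 4,
        'attention_mask': [[1, 1, 1]] * 4,
        'num_rows': 2, 'num_cols': 2, 'num_cells': 4,
        'edge_list': [(0, 1)],   # one ground-truth cell-cell edge
        'row_label': [0, 1],     # per-row class in {0, 1, 2, 3}
        'col_label': [0, 1],     # per-column class in {0, 1, 2, 3}
    }
    with torch.no_grad():
        (gid_logits, gid_labels), (edge_logits, edge_labels) = model([table])
    # Expect (2 rows + 2 cols, 4 classes) logits and one score per table edge.
    print(gid_logits.shape, edge_logits.shape)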