# demo_model/code/baseline_comparison/tabert_baseline_model.py
import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor, LongTensor
from table_bert.utils import BertTokenizer
from table_bert.vertical.vertical_attention_table_bert import VerticalAttentionTableBert
from table_bert.table import DiscomatTable
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class TabertBaselineModel(nn.Module):
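    """TaBERT baseline for DiSCoMaT-style table extraction: each table is encoded with
    VerticalAttentionTableBert, a 4-way label is predicted for every row and column
    (`comp_and_gid_layer`), and a binary link is predicted for every pair of cells that
    share a row or a column (`edge_layer`).
    """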
@staticmethod
def get_all_pair(ar):
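        """Return the Cartesian product of `ar` with itself (including self-pairs) as an (n*n, 2) array."""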
return np.array(np.meshgrid(ar, ar)).T.reshape(-1, 2)
@staticmethod
def get_edges(r, c):
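        """Enumerate all ordered cell pairs that share a row or a column in an r x c table,
        with cells numbered row-major (0 .. r*c-1). Self-pairs are included, giving
        r*c*(r+c-1) unique edges in total. For example, get_edges(2, 2) returns the 12 pairs
        (0,0),(0,1),(0,2),(1,0),(1,1),(1,3),(2,0),(2,2),(2,3),(3,1),(3,2),(3,3).
        """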
edges = np.empty((0, 2), dtype=int)
row_edges = TabertBaselineModel.get_all_pair(np.arange(c))
for i in range(r):
edges = np.concatenate((edges, row_edges + i * c), axis=0)
col_edges = TabertBaselineModel.get_all_pair(np.arange(0, r * c, c))
for i in range(c):
edges = np.concatenate((edges, col_edges + i), axis=0)
edges = np.unique(edges, axis=0)
table_edges = LongTensor(edges[np.lexsort((edges[:, 1], edges[:, 0]))])
assert len(table_edges) == r * c * (r + c - 1)
return table_edges
@staticmethod
def create_edge_labels(all_edges, edge_list):
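        """Return a 0/1 LongTensor marking which rows of `all_edges`, viewed as (src, dst) tuples,
        appear in `edge_list`."""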
df = pd.DataFrame(all_edges.tolist())
df['merge'] = list(zip(df[0], df[1]))
edge_labels = LongTensor(df['merge'].isin(edge_list))
return edge_labels
@staticmethod
def get_block(h_in, h_out):
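        """Linear -> BatchNorm -> ReLU -> Dropout(0.2) block used by the output heads."""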
return nn.Sequential(
nn.Linear(h_in, h_out),
nn.BatchNorm1d(h_out),
nn.ReLU(),
nn.Dropout(0.2),
)
@staticmethod
def get_max_possible_cells(input_ids, curr_len):
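        """Count how many of the per-cell token-id lists in `input_ids` fit, in order, into a
        single BERT sequence of 512 tokens on top of `curr_len` already-used tokens; each cell
        contributes `max(len(ids) - 2, 1)` tokens."""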
        curr_len += 1  # account for the trailing [SEP] token
num_cells = 0
for l in input_ids:
curr_len += max(len(l) - 2, 1)
            if curr_len > 512:  # would exceed BERT's 512-token limit
                break
num_cells += 1
return num_cells
def __init__(self, args: dict):
super(TabertBaselineModel, self).__init__()
self.args = args
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
self.tokenizer.add_tokens(['[EMPTY]']) # for empty table cells
self.empty_cell_input_id = self.tokenizer.convert_tokens_to_ids(['[EMPTY]'])
        # Initialize the VerticalAttentionTableBert architecture with plain bert-base-uncased
        # weights rather than TaBERT's pretrained checkpoint.
        self.vertical_attention_table_bert_model = VerticalAttentionTableBert.from_pretrained('bert-base-uncased', cache_dir=args['cache_dir'])
out_dim = 768
self.comp_and_gid_layer = nn.Sequential(nn.Dropout(0.2), self.get_block(out_dim, 256), nn.Linear(256, 4))
self.edge_layer = nn.Sequential(nn.Dropout(0.2), self.get_block(2 * out_dim, 256), nn.Linear(256, 1))
def truncate_caption(self, caption, max_len=100):
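        """Truncate `caption` to at most `max_len` wordpieces and decode it back to a string."""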
caption_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(caption)[:max_len])
return self.tokenizer.decode(caption_ids)
def forward(self, inps):
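        """Run the baseline on a batch of table dicts.

        Each element of `inps` provides `num_rows`, `num_cols`, `act_table`, `caption`,
        `pii`, `t_idx`, `edge_list`, `row_label` and `col_label`. Returns
        (row/column logits, row/column labels) and (edge logits, edge labels); during
        training, entries belonging to truncated rows are filtered out.
        """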
tables = []
contexts = []
max_r = 0
max_c = 0
for x in inps:
max_r = max(max_r, x['num_rows'])
max_c = max(max_c, x['num_cols'])
table = DiscomatTable(id=str(x['pii'])+":"+str(x['t_idx']), data=x['act_table'])
table.tokenize(self.vertical_attention_table_bert_model.tokenizer)
context = self.vertical_attention_table_bert_model.tokenizer.tokenize(x['caption'])
tables.append(table)
contexts.append(context)
context_encodings, schema_encodings, final_table_encodings, info_dict = self.vertical_attention_table_bert_model.encode_for_discomat(
contexts=contexts,
tables=tables
)
del tables, contexts
del context_encodings, schema_encodings
torch.cuda.empty_cache()
assert final_table_encodings.shape[0] == len(inps)
assert final_table_encodings.shape[1] == min(self.vertical_attention_table_bert_model.config.sample_row_num, max_r)
assert final_table_encodings.shape[2] == max_c
assert final_table_encodings.shape[3] == 768
row_col_embs, edge_embs, batch_edge_labels = [], [], []
extra_rows_mask, extra_edges_mask = [], []
for i, x in enumerate(inps):
r, c = x['num_rows'], x['num_cols']
table_mask = info_dict['tensor_dict']['table_mask'][i]
table_cell_embs = final_table_encodings[i][:r, :c, :]
r_ = min(r, final_table_encodings.shape[1])
if r > r_:
# means some rows have been truncated
# append extra rows to table_cell_embs
extra_rows = torch.zeros(r-r_, c, table_cell_embs.shape[-1], device=device)
table_cell_embs = torch.cat((table_cell_embs.to(device), extra_rows), dim=0)
assert table_cell_embs.shape[0] == r and table_cell_embs.shape[1] == c and table_cell_embs.shape[2] == 768
first_cell_embs = table_cell_embs.reshape(r*c, table_cell_embs.shape[-1])
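            # Mean-pool cell embeddings over columns (dim 1) and rows (dim 0) to get one
            # embedding per row and one per column.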
row_embs, col_embs = table_cell_embs.mean(1), table_cell_embs.mean(0)
assert row_embs.shape[0] == r and col_embs.shape[0] == c
row_col_embs += [row_embs, col_embs]
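            # Flag rows truncated by TaBERT's row sampling (their embeddings are zero);
            # columns are never truncated.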
extra_rows_mask += [0] * r_ + [1] * (r - r_) + [0] * c
table_edges = self.get_edges(r_, c)
act_table_edges = self.get_edges(r, c)
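            # Edges with both endpoints in kept rows receive the concatenation of their two
            # cell embeddings; edges touching a truncated row keep zero embeddings and are
            # masked out later.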
curr_edge_embs = torch.zeros(len(act_table_edges), 2 * table_cell_embs.shape[-1]).to(device)
curr_edge_embs[act_table_edges.max(1)[0] < r_ * c] = torch.cat([first_cell_embs[table_edges[:, 0]], first_cell_embs[table_edges[:, 1]]], dim=1)
edge_embs.append(curr_edge_embs)
batch_edge_labels.append(self.create_edge_labels(act_table_edges, x['edge_list']))
extra_edges_mask.append(act_table_edges.max(1)[0] >= r_ * c)
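        # Aggressively free the encoder outputs and the last iteration's temporaries to
        # reduce GPU memory pressure.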
keys = list(info_dict['tensor_dict'].keys())[:]
for k in keys:
del info_dict['tensor_dict'][k]
del keys
del first_cell_embs, table_cell_embs, curr_edge_embs, info_dict, table_edges, act_table_edges, final_table_encodings
if r > r_:
del extra_rows
torch.cuda.empty_cache()
extra_rows_mask = Tensor(extra_rows_mask).bool().to(device)
row_col_gid_logits = self.comp_and_gid_layer(torch.cat(row_col_embs))
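        # Pin the logits of truncated (zero-embedding) rows so their argmax is always class 0.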
row_col_gid_logits[extra_rows_mask, 0] = 1.0
row_col_gid_logits[extra_rows_mask, 1:] = 0.0
del row_col_embs
torch.cuda.empty_cache()
extra_edges_mask = torch.cat(extra_edges_mask)
edge_logits = self.edge_layer(torch.cat(edge_embs)).squeeze(-1)
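        # Edges touching a truncated row get -inf logits, i.e. they are always predicted as absent.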
edge_logits[extra_edges_mask] = float('-inf')
del edge_embs
torch.cuda.empty_cache()
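        # Labels are concatenated rows-then-columns per table, matching how row_col_embs was built above.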
batch_row_col_gid_labels = []
for x in inps:
batch_row_col_gid_labels += x['row_label'] + x['col_label']
batch_row_col_gid_labels = LongTensor(batch_row_col_gid_labels).to(device)
assert len(batch_row_col_gid_labels) == len(row_col_gid_logits)
batch_edge_labels = torch.cat(batch_edge_labels).to(device)
assert len(batch_edge_labels) == len(edge_logits)
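        # During training, drop truncated rows/edges from the loss; at eval time keep them
        # (with their pinned predictions) so metrics cover the full table.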
if self.training:
return (row_col_gid_logits[~extra_rows_mask], batch_row_col_gid_labels[~extra_rows_mask]), \
(edge_logits[~extra_edges_mask], batch_edge_labels[~extra_edges_mask])
else:
return (row_col_gid_logits, batch_row_col_gid_labels), (edge_logits, batch_edge_labels)
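if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original training pipeline.
    # Assumptions: the table_bert package and bert-base-uncased weights are available,
    # DiscomatTable accepts a list-of-lists of cell strings as `data`, and row/column
    # labels follow the 4-class scheme implied by comp_and_gid_layer. All values below
    # (pii, caption, labels, edge_list) are made-up placeholders.
    model = TabertBaselineModel({'cache_dir': None}).to(device)
    model.eval()
    toy_batch = [{
        'pii': 'S0000000000000000',                     # placeholder paper id
        't_idx': 0,                                     # table index within the paper
        'num_rows': 2,
        'num_cols': 2,
        'act_table': [['SiO2', 'Na2O'], ['70', '30']],  # assumed cell layout
        'caption': 'Glass composition (mol%)',
        'edge_list': [(0, 2), (1, 3)],                  # linked cell pairs, row-major ids
        'row_label': [0, 0],                            # assumed 4-class row labels
        'col_label': [0, 0],                            # assumed 4-class column labels
    }]
    with torch.no_grad():
        (rc_logits, rc_labels), (edge_logits, edge_labels) = model(toy_batch)
    # 2 rows + 2 cols -> 4 row/column predictions; 2*2*(2+2-1) = 12 candidate edges
    print(rc_logits.shape, edge_logits.shape)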