import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor, LongTensor

from table_bert.utils import BertTokenizer
from table_bert.vertical.vertical_attention_table_bert import VerticalAttentionTableBert
from table_bert.table import DiscomatTable


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class TabertBaselineModel(nn.Module):

    @staticmethod
    def get_all_pair(ar):
        # All ordered pairs drawn from `ar`, as an array of shape (len(ar) ** 2, 2).
        return np.array(np.meshgrid(ar, ar)).T.reshape(-1, 2)
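    # For example, get_all_pair(np.arange(2)) gives [[0, 0], [0, 1], [1, 0], [1, 1]].
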
    @staticmethod
    def get_edges(r, c):
        # Enumerate every same-row and same-column cell pair of an r x c table,
        # with cells indexed row-major from 0 to r * c - 1.
        edges = np.empty((0, 2), dtype=int)
        row_edges = TabertBaselineModel.get_all_pair(np.arange(c))
        for i in range(r):
            edges = np.concatenate((edges, row_edges + i * c), axis=0)
        col_edges = TabertBaselineModel.get_all_pair(np.arange(0, r * c, c))
        for i in range(c):
            edges = np.concatenate((edges, col_edges + i), axis=0)
        edges = np.unique(edges, axis=0)
        table_edges = LongTensor(edges[np.lexsort((edges[:, 1], edges[:, 0]))])
        # Each cell pairs with its c - 1 row mates, its r - 1 column mates, and itself.
        assert len(table_edges) == r * c * (r + c - 1)
        return table_edges
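    # For a 2 x 2 table (cells 0..3), get_edges(2, 2) yields the 12 sorted pairs
    # (0,0) (0,1) (0,2) (1,0) (1,1) (1,3) (2,0) (2,2) (2,3) (3,1) (3,2) (3,3).
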
    @staticmethod
    def create_edge_labels(all_edges, edge_list):
        # Binary label per candidate edge: 1 if the (cell, cell) pair is a gold edge.
        df = pd.DataFrame(all_edges.tolist())
        df['merge'] = list(zip(df[0], df[1]))
        edge_labels = LongTensor(df['merge'].isin(edge_list))
        return edge_labels
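    # E.g., assuming edge_list holds row-major index tuples, the 12 edges above with
    # edge_list = [(0, 1), (2, 3)] give labels [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0].
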
    @staticmethod
    def get_block(h_in, h_out):
        return nn.Sequential(
            nn.Linear(h_in, h_out),
            nn.BatchNorm1d(h_out),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
    @staticmethod
    def get_max_possible_cells(input_ids, curr_len):
        # Count how many tokenized cells fit within BERT's 512-token limit; each
        # cell contributes len(l) - 2 tokens (at least 1) on top of curr_len.
        curr_len += 1
        num_cells = 0
        for l in input_ids:
            curr_len += max(len(l) - 2, 1)
            if curr_len > 512:
                break
            num_cells += 1
        return num_cells
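    # For instance, with curr_len = 500 and cells tokenized to lengths [6, 6, 6],
    # the running length goes 501 -> 505 -> 509 -> 513, so only 2 cells fit.
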
    def __init__(self, args: dict):
        super(TabertBaselineModel, self).__init__()
        self.args = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer.add_tokens(['[EMPTY]'])
        self.empty_cell_input_id = self.tokenizer.convert_tokens_to_ids(['[EMPTY]'])
        self.vertical_attention_table_bert_model = VerticalAttentionTableBert.from_pretrained('bert-base-uncased', cache_dir=args['cache_dir'])

        out_dim = 768
        # Row/column classification head (4 classes) and edge-scoring head (1 logit).
        self.comp_and_gid_layer = nn.Sequential(nn.Dropout(0.2), self.get_block(out_dim, 256), nn.Linear(256, 4))
        self.edge_layer = nn.Sequential(nn.Dropout(0.2), self.get_block(2 * out_dim, 256), nn.Linear(256, 1))
    def truncate_caption(self, caption, max_len=100):
        # Keep at most max_len word pieces of the caption.
        caption_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(caption)[:max_len])
        return self.tokenizer.decode(caption_ids)
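    # Each element of `inps` in forward() is a per-table dict; the fields accessed
    # below are 'pii', 't_idx', 'act_table', 'caption', 'num_rows', 'num_cols',
    # 'edge_list', 'row_label', and 'col_label'.
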
    def forward(self, inps):
        tables = []
        contexts = []
        max_r = 0
        max_c = 0

        # Tokenize each table and its caption for the TaBERT encoder.
        for x in inps:
            max_r = max(max_r, x['num_rows'])
            max_c = max(max_c, x['num_cols'])
            table = DiscomatTable(id=str(x['pii']) + ":" + str(x['t_idx']), data=x['act_table'])
            table.tokenize(self.vertical_attention_table_bert_model.tokenizer)
            context = self.vertical_attention_table_bert_model.tokenizer.tokenize(x['caption'])
            tables.append(table)
            contexts.append(context)

        context_encodings, schema_encodings, final_table_encodings, info_dict = self.vertical_attention_table_bert_model.encode_for_discomat(
            contexts=contexts,
            tables=tables
        )
        del tables, contexts
        del context_encodings, schema_encodings
        torch.cuda.empty_cache()

        # final_table_encodings: (batch, encoded rows, max columns, hidden size).
        assert final_table_encodings.shape[0] == len(inps)
        assert final_table_encodings.shape[1] == min(self.vertical_attention_table_bert_model.config.sample_row_num, max_r)
        assert final_table_encodings.shape[2] == max_c
        assert final_table_encodings.shape[3] == 768

        row_col_embs, edge_embs, batch_edge_labels = [], [], []
        extra_rows_mask, extra_edges_mask = [], []
        for i, x in enumerate(inps):
            r, c = x['num_rows'], x['num_cols']
            table_mask = info_dict['tensor_dict']['table_mask'][i]
            table_cell_embs = final_table_encodings[i][:r, :c, :]
            # TaBERT encodes at most sample_row_num rows; r_ is the number actually encoded.
            r_ = min(r, final_table_encodings.shape[1])

            if r > r_:
                # Pad the rows that were not encoded with zero embeddings.
                extra_rows = torch.zeros(r - r_, c, table_cell_embs.shape[-1], device=device)
                table_cell_embs = torch.cat((table_cell_embs.to(device), extra_rows), dim=0)

            assert table_cell_embs.shape[0] == r and table_cell_embs.shape[1] == c and table_cell_embs.shape[2] == 768
            first_cell_embs = table_cell_embs.reshape(r * c, table_cell_embs.shape[-1])
            # Row embeddings average over columns; column embeddings average over rows.
            row_embs, col_embs = table_cell_embs.mean(1), table_cell_embs.mean(0)
            assert row_embs.shape[0] == r and col_embs.shape[0] == c
            row_col_embs += [row_embs, col_embs]
            extra_rows_mask += [0] * r_ + [1] * (r - r_) + [0] * c

            # Edge embeddings concatenate the two endpoint cell embeddings; edges that
            # touch a padded (unencoded) row keep zero embeddings and are masked out later.
            table_edges = self.get_edges(r_, c)
            act_table_edges = self.get_edges(r, c)
            curr_edge_embs = torch.zeros(len(act_table_edges), 2 * table_cell_embs.shape[-1]).to(device)
            curr_edge_embs[act_table_edges.max(1)[0] < r_ * c] = torch.cat([first_cell_embs[table_edges[:, 0]], first_cell_embs[table_edges[:, 1]]], dim=1)
            edge_embs.append(curr_edge_embs)
            batch_edge_labels.append(self.create_edge_labels(act_table_edges, x['edge_list']))
            extra_edges_mask.append(act_table_edges.max(1)[0] >= r_ * c)

        # Release the large intermediate tensors before running the prediction heads.
        keys = list(info_dict['tensor_dict'].keys())[:]
        for k in keys:
            del info_dict['tensor_dict'][k]
        del keys
        del first_cell_embs, table_cell_embs, curr_edge_embs, info_dict, table_edges, act_table_edges, final_table_encodings
        if r > r_:
            del extra_rows
        torch.cuda.empty_cache()
        extra_rows_mask = Tensor(extra_rows_mask).bool().to(device)

        row_col_gid_logits = self.comp_and_gid_layer(torch.cat(row_col_embs))
        # Force padded (unencoded) rows to predict class 0.
        row_col_gid_logits[extra_rows_mask, 0] = 1.0
        row_col_gid_logits[extra_rows_mask, 1:] = 0.0
        del row_col_embs
        torch.cuda.empty_cache()

        extra_edges_mask = torch.cat(extra_edges_mask)
        edge_logits = self.edge_layer(torch.cat(edge_embs)).squeeze(-1)
        # Edges touching padded rows can never be predicted positive.
        edge_logits[extra_edges_mask] = float('-inf')
        del edge_embs
        torch.cuda.empty_cache()

        batch_row_col_gid_labels = []
        for x in inps:
            batch_row_col_gid_labels += x['row_label'] + x['col_label']
        batch_row_col_gid_labels = LongTensor(batch_row_col_gid_labels).to(device)
        assert len(batch_row_col_gid_labels) == len(row_col_gid_logits)

        batch_edge_labels = torch.cat(batch_edge_labels).to(device)
        assert len(batch_edge_labels) == len(edge_logits)

        if self.training:
            # During training, padded rows and edges are excluded from the loss.
            return (row_col_gid_logits[~extra_rows_mask], batch_row_col_gid_labels[~extra_rows_mask]), \
                   (edge_logits[~extra_edges_mask], batch_edge_labels[~extra_edges_mask])
        else:
            return (row_col_gid_logits, batch_row_col_gid_labels), (edge_logits, batch_edge_labels)
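

# A minimal usage sketch (not part of the original module): the two returned
# (logits, labels) pairs could feed a 4-way cross-entropy loss for rows/columns and
# a binary logistic loss for edges, e.g. nn.CrossEntropyLoss()(row_col_gid_logits,
# row_col_gid_labels) and nn.BCEWithLogitsLoss()(edge_logits, edge_labels.float()).
if __name__ == '__main__':
    # Quick illustration of the edge enumeration: a 2 x 3 table has
    # 2 * 3 * (2 + 3 - 1) = 24 same-row/same-column cell pairs.
    print(TabertBaselineModel.get_edges(2, 3).shape)  # torch.Size([24, 2])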