import os
import sys

from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm  # for our progress bar
from transformers import AdamW
import torch
from torch.utils.data import DataLoader

from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM
from datasets.osm_sample_loader import PbfMapDataset
from transformers.models.bert.modeling_bert import BertForMaskedLM

import numpy as np
import argparse
import pdb

DEBUG = False
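
# Example invocation (assuming this file is saved as train_mlm.py; the flags
# shown match the argparse defaults below):
#   python train_mlm.py --bert_option bert-base --batch_size 12 --lr 5e-5 \
#       --epochs 10 --sep_between_neighbors
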
def training(args):
    """Pretrain SpatialBertForMaskedLM on OSM points from London and California."""
    num_workers = args.num_workers
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr
    save_interval = args.save_interval
    max_token_len = args.max_token_len
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone
    bert_option = args.bert_option
    if_no_spatial_distance = args.no_spatial_distance

    assert bert_option in ['bert-base', 'bert-large']

    london_file_path = '../data/sql_output/osm-point-london.json'
    california_file_path = '../data/sql_output/osm-point-california.json'

    if args.model_save_dir is None:
        sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
        freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
        context_pathstr = '_nocontext' if if_no_spatial_distance else '_withcontext'
        model_save_dir = ('/data2/zekun/spatial_bert_weights/mlm_mem_lr' + "{:.0e}".format(lr)
                          + sep_pathstr + context_pathstr + '/' + bert_option + freeze_pathstr
                          + '_mlm_mem_london_california_bsize' + str(batch_size))
        if not os.path.isdir(model_save_dir):
            os.makedirs(model_save_dir)
    else:
        model_save_dir = args.model_save_dir

    print('model_save_dir', model_save_dir)
    print('\n')

    if bert_option == 'bert-base':
        bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        config = SpatialBertConfig(use_spatial_distance_embedding=not if_no_spatial_distance)
    elif bert_option == 'bert-large':
        bert_model = BertForMaskedLM.from_pretrained('bert-large-uncased')
        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
        config = SpatialBertConfig(use_spatial_distance_embedding=not if_no_spatial_distance,
                                   hidden_size=1024, intermediate_size=4096,
                                   num_attention_heads=16, num_hidden_layers=24)
    else:
        raise NotImplementedError

    model = SpatialBertForMaskedLM(config)
    model.load_state_dict(bert_model.state_dict(), strict=False)  # load sentence position embedding weights as well
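    # With strict=False, parameters that exist only in SpatialBertForMaskedLM
    # (e.g. the spatial distance embedding) keep their random initialization.
    # A minimal sanity-check sketch (re-running the copy is idempotent):
    if DEBUG:
        load_result = model.load_state_dict(bert_model.state_dict(), strict=False)
        print('keys left at random init:', load_result.missing_keys)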

    if bert_option == 'bert-large' and freeze_backbone:
        print('freezing backbone weights')
        # Freeze everything, then unfreeze the MLM head and the last three encoder layers.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.cls.parameters():
            param.requires_grad = True
        for layer_idx in (21, 22, 23):
            for param in model.bert.encoder.layer[layer_idx].parameters():
                param.requires_grad = True
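        # Minimal sketch: confirm the freeze took effect before training starts.
        if DEBUG:
            n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            n_total = sum(p.numel() for p in model.parameters())
            print(f'trainable parameters: {n_trainable}/{n_total}')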

    london_train_dataset = PbfMapDataset(data_file_path=london_file_path,
                                         tokenizer=tokenizer,
                                         max_token_len=max_token_len,
                                         distance_norm_factor=distance_norm_factor,
                                         spatial_dist_fill=spatial_dist_fill,
                                         with_type=with_type,
                                         sep_between_neighbors=sep_between_neighbors,
                                         label_encoder=None,
                                         mode=None)
    california_train_dataset = PbfMapDataset(data_file_path=california_file_path,
                                             tokenizer=tokenizer,
                                             max_token_len=max_token_len,
                                             distance_norm_factor=distance_norm_factor,
                                             spatial_dist_fill=spatial_dist_fill,
                                             with_type=with_type,
                                             sep_between_neighbors=sep_between_neighbors,
                                             label_encoder=None,
                                             mode=None)

    train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
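    # Minimal sketch: report dataset sizes before building the loader (assumes
    # PbfMapDataset implements __len__, as the shuffling DataLoader below requires).
    if DEBUG:
        print('train samples:', len(london_train_dataset), '+', len(california_train_dataset),
              '=', len(train_dataset))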

    # Shuffle during normal training; keep sample order fixed in DEBUG mode for reproducibility.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              shuffle=not DEBUG, pin_memory=True, drop_last=True)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.train()

    # initialize optimizer
    optim = AdamW(model.parameters(), lr=lr)
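    # Note: the transformers.AdamW used above is deprecated in recent transformers
    # releases; torch.optim.AdamW(model.parameters(), lr=lr) is a drop-in
    # replacement if the import fails.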

    print('start training...')

    for epoch in range(epochs):
        # set up a progress bar over the dataloader
        loop = tqdm(train_loader, leave=True)
        step = 0
        for batch in loop:
            # clear gradients accumulated from the previous step
            optim.zero_grad()
            # pull all tensor batches required for training
            input_ids = batch['masked_input'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_list_x = batch['norm_lng_list'].to(device)
            position_list_y = batch['norm_lat_list'].to(device)
            sent_position_ids = batch['sent_position_ids'].to(device)
            labels = batch['pseudo_sentence'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                            position_list_x=position_list_x, position_list_y=position_list_y, labels=labels)

            loss = outputs.loss
            loss.backward()
            optim.step()

            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix({'loss': loss.item()})

            if DEBUG:
                print('ep' + str(epoch) + '_iter' + str(step).zfill(5), loss.item())

            step += 1

            # save a checkpoint every save_interval steps and at the end of each epoch
            if step % save_interval == 0 or step == loop.total:
                save_path = os.path.join(model_save_dir, 'mlm_mem_keeppos_ep' + str(epoch)
                                         + '_iter' + str(step).zfill(5)
                                         + '_' + "{:.4f}".format(loss.item()) + '.pth')
                torch.save(model.state_dict(), save_path)
                print('saving model checkpoint to', save_path)
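
# Note: the checkpoints saved above are plain state_dicts. A minimal resume
# sketch (checkpoint_path stands for one of the .pth files written by training()):
#   model = SpatialBertForMaskedLM(config)
#   model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
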
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_interval', type=int, default=2000)
    parser.add_argument('--max_token_len', type=int, default=300)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=20)
    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--freeze_backbone', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')
    parser.add_argument('--bert_option', type=str, default='bert-base')
    parser.add_argument('--model_save_dir', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if model_save_dir is given but does not yet exist, create it
    if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
        os.makedirs(args.model_save_dir)

    training(args)


if __name__ == '__main__':
    main()