File size: 3,404 Bytes
4ac075d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
#!/usr/bin/env python3
from absl import app, flags, logging
import sh
import torch as th
import pytorch_lightning as pl
import nlp
import transformers
# Command-line configuration. absl's app.run parses these into FLAGS before
# main() runs. Bounds reject values that would otherwise only fail much later
# (inside dataset slicing, DataLoader construction or the SGD constructor).
flags.DEFINE_boolean(
    'debug', False,
    'Debug mode: load only batch_size examples per split and enable '
    'Lightning fast_dev_run.')
flags.DEFINE_integer(
    'epochs', 10, 'Maximum number of training epochs.', lower_bound=1)
flags.DEFINE_integer(
    'batch_size', 8,
    'Examples per batch (also the per-split example count in --debug mode).',
    lower_bound=1)
flags.DEFINE_float('lr', 1e-2, 'SGD learning rate.', lower_bound=0.0)
flags.DEFINE_float('momentum', .9, 'SGD momentum factor.', lower_bound=0.0)
flags.DEFINE_string(
    'model', 'bert-base-uncased',
    'Pretrained BERT name or path passed to transformers from_pretrained.')
flags.DEFINE_integer(
    'seq_length', 32,
    'Token sequence length; inputs are padded to this length.',
    lower_bound=1)
flags.DEFINE_integer(
    'percent', 5,
    'Percentage of each IMDB split to load when not in --debug mode.',
    lower_bound=1, upper_bound=100)
FLAGS = flags.FLAGS

# NOTE(review): this runs at import time and wipes any previous run's logs
# (TensorBoard events/checkpoints are written under 'logs'). Consider moving
# it into main() if this module is ever imported rather than executed.
sh.rm('-r', '-f', 'logs')
sh.mkdir('logs')
class IMDBSentimentClassifier(pl.LightningModule):
    """IMDB sentiment classifier: a pretrained BERT fine-tuned with SGD.

    All hyper-parameters (model name, sequence length, batch size, learning
    rate, momentum, dataset fraction, debug mode) are read from the
    module-level absl ``FLAGS``.
    """

    def __init__(self):
        super().__init__()
        self.model = transformers.BertForSequenceClassification.from_pretrained(FLAGS.model)
        # Unreduced (per-example) losses: training averages per batch, while
        # validation concatenates them across batches and averages once.
        self.loss = th.nn.CrossEntropyLoss(reduction='none')

    def prepare_data(self):
        """Load, tokenize and tensorize the IMDB train and test splits.

        The results are stored on ``self.train_ds`` and ``self.test_ds``.
        """
        # NOTE(review): in multi-device runs Lightning may call prepare_data on
        # one process only, so assigning state here assumes a single-device
        # run. Confirm before enabling distributed training.
        tokenizer = transformers.BertTokenizer.from_pretrained(FLAGS.model)

        def _tokenize(x):
            # Pad every example to FLAGS.seq_length so batches stack into a
            # single tensor.
            x['input_ids'] = tokenizer.batch_encode_plus(
                x['text'],
                max_length=FLAGS.seq_length,
                pad_to_max_length=True)['input_ids']
            return x

        def _prepare_ds(split):
            # Debug mode takes the first batch_size examples; otherwise the
            # first FLAGS.percent percent of the split.
            n = FLAGS.batch_size if FLAGS.debug else f'{FLAGS.percent}%'
            ds = nlp.load_dataset('imdb', split=f'{split}[:{n}]')
            ds = ds.map(_tokenize, batched=True)
            ds.set_format(type='torch', columns=['input_ids', 'label'])
            return ds

        self.train_ds, self.test_ds = map(_prepare_ds, ('train', 'test'))

    def forward(self, input_ids):
        """Return classification logits for a batch of token ids.

        Token id 0 is treated as padding and excluded via the attention mask.
        """
        # NOTE(review): assumes the tokenizer's pad id is 0. Confirm if
        # FLAGS.model points at a non-standard vocabulary.
        mask = (input_ids != 0).float()
        # Index instead of tuple-unpacking. Older transformers return a plain
        # tuple; newer ones return a ModelOutput, whose iteration yields key
        # names rather than tensors. Both support [0] -> logits.
        outputs = self.model(input_ids, attention_mask=mask)
        return outputs[0]

    def training_step(self, batch, batch_idx):
        """Compute the batch-mean cross-entropy and log it as ``train_loss``."""
        logits = self.forward(batch['input_ids'])
        loss = self.loss(logits, batch['label']).mean()
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        """Return per-example losses and 0/1 correctness values for the batch."""
        logits = self.forward(batch['input_ids'])
        loss = self.loss(logits, batch['label'])
        acc = (logits.argmax(-1) == batch['label']).float()
        return {'loss': loss, 'acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-example loss and accuracy over the whole epoch."""
        loss = th.cat([o['loss'] for o in outputs], 0).mean()
        acc = th.cat([o['acc'] for o in outputs], 0).mean()
        out = {'val_loss': loss, 'val_acc': acc}
        return {**out, 'log': out}

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return th.utils.data.DataLoader(
            self.train_ds,
            batch_size=FLAGS.batch_size,
            drop_last=True,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader over the test split, in a fixed order.

        Shuffling is unnecessary here because the epoch metrics are plain
        means; a fixed order keeps evaluation reproducible.
        """
        return th.utils.data.DataLoader(
            self.test_ds,
            batch_size=FLAGS.batch_size,
            drop_last=False,
            shuffle=False,
        )

    def configure_optimizers(self):
        """SGD with momentum over all parameters, configured from FLAGS."""
        return th.optim.SGD(
            self.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
        )
def main(_):
    """absl entry point: build the classifier and a Lightning trainer, then fit.

    The argv list that absl passes in is ignored; configuration comes from
    FLAGS.
    """
    classifier = IMDBSentimentClassifier()
    # Use one GPU when CUDA is available, otherwise train on CPU.
    gpu_count = 1 if th.cuda.is_available() else 0
    tb_logger = pl.loggers.TensorBoardLogger('logs/', name='imdb', version=0)
    trainer = pl.Trainer(
        default_root_dir='logs',
        gpus=gpu_count,
        max_epochs=FLAGS.epochs,
        fast_dev_run=FLAGS.debug,
        logger=tb_logger,
    )
    trainer.fit(classifier)


if __name__ == '__main__':
    # app.run parses command-line flags into FLAGS, then calls main(argv).
    app.run(main)
|