Thomas Wolf commited on
Commit
0fa5994
2 Parent(s): 7b8e6d6 fad545a

Merge pull request #7 from pushpankar/master

Browse files

Add support for finetuning more than 2 classes

Files changed (1) hide show
  1. torchmoji/finetuning.py +21 -9
torchmoji/finetuning.py CHANGED
@@ -14,6 +14,7 @@ import numpy as np
14
  import torch
15
  import torch.nn as nn
16
  import torch.optim as optim
 
17
  from torch.autograd import Variable
18
  from torch.utils.data import Dataset, DataLoader
19
  from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -360,15 +361,18 @@ def evaluate_using_acc(model, test_gen):
360
 
361
  # Validate on test_data
362
  model.eval()
363
- correct_count = 0.0
364
- total_y = sum(len(y) for _, y in test_gen)
365
  for i, data in enumerate(test_gen):
366
  x, y = data
367
  outs = model(x)
368
- pred = (outs >= 0).long()
369
- added_counts = (pred == y).double().sum()
370
- correct_count += added_counts
371
- return correct_count/total_y
 
 
 
 
372
 
373
 
374
  def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
@@ -486,6 +490,14 @@ def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs,
486
  if verbose >= 2:
487
  print("Loaded weights from {}".format(checkpoint_path))
488
 
 
 
 
 
 
 
 
 
489
  def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
490
  checkpoint_path, patience):
491
  """ Analog to Keras fit_generator function.
@@ -509,7 +521,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
509
  torch.save(model.state_dict(), checkpoint_path)
510
 
511
  model.eval()
512
- best_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
513
  print("original val loss", best_loss)
514
 
515
  epoch_without_impr = 0
@@ -521,7 +533,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
521
  model.train()
522
  optim_op.zero_grad()
523
  output = model(X_train)
524
- loss = loss_op(output, y_train.float())
525
  loss.backward()
526
  clip_grad_norm(model.parameters(), 1)
527
  optim_op.step()
@@ -533,7 +545,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
533
  acc = evaluate_using_acc(model, val_gen)
534
  print("val acc", acc)
535
 
536
- val_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
537
  print("val loss", val_loss)
538
  if best_loss is not None and val_loss >= best_loss:
539
  epoch_without_impr += 1
 
14
  import torch
15
  import torch.nn as nn
16
  import torch.optim as optim
17
+ from sklearn.metrics import accuracy_score
18
  from torch.autograd import Variable
19
  from torch.utils.data import Dataset, DataLoader
20
  from torch.utils.data.sampler import BatchSampler, SequentialSampler
 
361
 
362
  # Validate on test_data
363
  model.eval()
364
+ accs = []
 
365
  for i, data in enumerate(test_gen):
366
  x, y = data
367
  outs = model(x)
368
+ if model.nb_classes > 2:
369
+ pred = torch.max(outs, 1)[1]
370
+ acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
371
+ else:
372
+ pred = (outs >= 0).long()
373
+ acc = (pred == y).double().sum() / len(pred)
374
+ accs.append(acc)
375
+ return np.mean(accs)
376
 
377
 
378
  def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
 
490
  if verbose >= 2:
491
  print("Loaded weights from {}".format(checkpoint_path))
492
 
493
+
494
+ def calc_loss(loss_op, pred, yv):
495
+ if type(loss_op) is nn.CrossEntropyLoss:
496
+ return loss_op(pred.squeeze(), yv.squeeze())
497
+ else:
498
+ return loss_op(pred.squeeze(), yv.squeeze().float())
499
+
500
+
501
  def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
502
  checkpoint_path, patience):
503
  """ Analog to Keras fit_generator function.
 
521
  torch.save(model.state_dict(), checkpoint_path)
522
 
523
  model.eval()
524
+ best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
525
  print("original val loss", best_loss)
526
 
527
  epoch_without_impr = 0
 
533
  model.train()
534
  optim_op.zero_grad()
535
  output = model(X_train)
536
+ loss = calc_loss(loss_op, output, y_train)
537
  loss.backward()
538
  clip_grad_norm(model.parameters(), 1)
539
  optim_op.step()
 
545
  acc = evaluate_using_acc(model, val_gen)
546
  print("val acc", acc)
547
 
548
+ val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
549
  print("val loss", val_loss)
550
  if best_loss is not None and val_loss >= best_loss:
551
  epoch_without_impr += 1