Merge pull request #7 from pushpankar/master
Add support for finetuning more than 2 classes
- torchmoji/finetuning.py +21 -9
torchmoji/finetuning.py
CHANGED
@@ -14,6 +14,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from sklearn.metrics import accuracy_score
 from torch.autograd import Variable
 from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -360,15 +361,18 @@ def evaluate_using_acc(model, test_gen):
 
     # Validate on test_data
     model.eval()
-
-    total_y = sum(len(y) for _, y in test_gen)
+    accs = []
     for i, data in enumerate(test_gen):
         x, y = data
         outs = model(x)
-
-
-
-
+        if model.nb_classes > 2:
+            pred = torch.max(outs, 1)[1]
+            acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
+        else:
+            pred = (outs >= 0).long()
+            acc = (pred == y).double().sum() / len(pred)
+        accs.append(acc)
+    return np.mean(accs)
 
 
 def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
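For context on the new evaluation logic: with more than two classes the model emits one logit per class, so the prediction is an argmax over dim 1 and accuracy comes from sklearn, while the binary path keeps the old threshold-at-zero rule. A minimal standalone sketch on dummy tensors (names like outs_multi are illustrative, not from the repo):

import torch
from sklearn.metrics import accuracy_score

# Multiclass branch: (batch, nb_classes) logits, predict by argmax over dim 1.
outs_multi = torch.tensor([[2.0, 0.1, -1.0], [0.2, 1.5, 0.3]])
y_multi = torch.tensor([0, 1])
pred = torch.max(outs_multi, 1)[1]  # indices of the max logit per row
print(accuracy_score(y_multi.numpy(), pred.numpy()))  # 1.0

# Binary branch: a single logit per sample, thresholded at 0.
outs_bin = torch.tensor([[0.7], [-0.3]])
y_bin = torch.tensor([[1], [0]])
pred = (outs_bin >= 0).long()
print((pred == y_bin).double().sum().item() / len(pred))  # 1.0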
@@ -486,6 +490,14 @@ def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs,
         if verbose >= 2:
             print("Loaded weights from {}".format(checkpoint_path))
 
+
+def calc_loss(loss_op, pred, yv):
+    if type(loss_op) is nn.CrossEntropyLoss:
+        return loss_op(pred.squeeze(), yv.squeeze())
+    else:
+        return loss_op(pred.squeeze(), yv.squeeze().float())
+
+
 def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
               checkpoint_path, patience):
     """ Analog to Keras fit_generator function.
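The calc_loss helper dispatches on the loss type because nn.CrossEntropyLoss expects integer class indices (long dtype), while element-wise binary losses expect float targets, hence the .float() cast in the else branch. A quick self-contained illustration; BCEWithLogitsLoss here is only a stand-in for whatever binary loss the caller passes:

import torch
import torch.nn as nn

# CrossEntropyLoss: (batch, nb_classes) logits plus long class indices.
ce = nn.CrossEntropyLoss()
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 1])  # long dtype, no cast needed
print(ce(logits, targets).item())

# BCE-style losses: targets must be float, matching the else branch's cast.
bce = nn.BCEWithLogitsLoss()
logit = torch.randn(4)
labels = torch.tensor([1, 0, 0, 1])
print(bce(logit, labels.float()).item())  # raises without .float()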
@@ -509,7 +521,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
         torch.save(model.state_dict(), checkpoint_path)
 
     model.eval()
-    best_loss = np.mean([loss_op
+    best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
     print("original val loss", best_loss)
 
     epoch_without_impr = 0
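Note the .data.cpu().numpy()[0] indexing: in PyTorch before 0.4 a loss was a one-element Variable, so [0] extracted the scalar. A modern equivalent of this validation-loss average would use .item(); a runnable sketch under that assumption, with a tiny stand-in model and data rather than repo code:

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(10, 3)  # stand-in for the torchMoji model
loss_op = nn.CrossEntropyLoss()
val_gen = [(torch.randn(4, 10), torch.randint(0, 3, (4,))) for _ in range(3)]

model.eval()
with torch.no_grad():  # replaces the old Variable wrapping
    best_loss = np.mean([loss_op(model(xv), yv).item() for xv, yv in val_gen])
print("original val loss", best_loss)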
@@ -521,7 +533,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
             model.train()
             optim_op.zero_grad()
             output = model(X_train)
-            loss = loss_op
+            loss = calc_loss(loss_op, output, y_train)
             loss.backward()
             clip_grad_norm(model.parameters(), 1)
             optim_op.step()
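The training step now routes through the same calc_loss dispatch. For anyone porting this forward: clip_grad_norm was renamed clip_grad_norm_ in PyTorch 0.4. A self-contained sketch of one optimization step using the CrossEntropy branch, with a dummy model and batch (assumed names, not repo code):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 3)
loss_op = nn.CrossEntropyLoss()
optim_op = optim.Adam(model.parameters(), lr=1e-3)
X_train, y_train = torch.randn(4, 10), torch.randint(0, 3, (4,))

model.train()
optim_op.zero_grad()
output = model(X_train)
loss = loss_op(output.squeeze(), y_train.squeeze())  # CrossEntropy branch of calc_loss
loss.backward()
clip_grad_norm_(model.parameters(), 1)  # renamed from clip_grad_norm in 0.4+
optim_op.step()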
@@ -533,7 +545,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
         acc = evaluate_using_acc(model, val_gen)
         print("val acc", acc)
 
-        val_loss = np.mean([loss_op
+        val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
         print("val loss", val_loss)
         if best_loss is not None and val_loss >= best_loss:
             epoch_without_impr += 1
|