# Source: commit 4c1e73e (464 bytes, 15 lines)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
    """Contrastive loss for Siamese networks.

    For a pair of embeddings at Euclidean distance ``d`` (as computed by
    ``F.pairwise_distance``) the per-pair loss is::

        (1 - label) * d**2 + label * max(margin - d, 0)**2

    so ``label == 0`` marks a similar pair (pulled together) and
    ``label == 1`` marks a dissimilar pair (pushed apart until it is at
    least ``margin`` away). Unlike Hadsell et al. (2006) no 1/2 factor
    is applied.

    Args:
        margin: Distance beyond which dissimilar pairs stop contributing.
        reduction: How per-pair losses are combined: ``"mean"`` (default,
            the original behavior), ``"sum"``, or ``"none"`` (return the
            per-pair losses unreduced).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, margin=1.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.margin = margin
        self.reduction = reduction

    def forward(self, x1, x2, label):
        """Compute the contrastive loss between two batches of embeddings.

        Args:
            x1: First embeddings; passed unchanged to ``F.pairwise_distance``
                (presumably shape ``(B, D)`` — confirm against callers).
            x2: Second embeddings, same shape as ``x1``.
            label: Pair labels, 0 = similar, 1 = dissimilar. May be a tensor
                of any numeric or bool dtype, or a Python scalar. A label
                with one entry per pair (e.g. shape ``(B,)`` or ``(B, 1)``)
                is reshaped to match the distances; otherwise it is
                broadcast as before.

        Returns:
            A scalar for ``reduction`` "mean"/"sum", else the per-pair losses.
        """
        dist = F.pairwise_distance(x1, x2)
        # Cast to the distance dtype/device: bool labels would otherwise
        # fail on ``1 - label``.
        label = torch.as_tensor(label, dtype=dist.dtype, device=dist.device)
        # A (B, 1) label against (B,) distances would silently broadcast to
        # (B, B) and corrupt the reduction; align one-label-per-pair inputs.
        if label.numel() == dist.numel():
            label = label.reshape(dist.shape)

        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        loss = similar_term + dissimilar_term

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss