loss bug fix
Browse files- src/loss.py +4 -4
- src/models.py +6 -0
src/loss.py
CHANGED
@@ -15,7 +15,7 @@ def contrastive_sigmoid_loss(logits):
|
|
15 |
class CLIPLoss(nn.Module):
|
16 |
def __init__(self, logit_temperature: float = -1.0):
|
17 |
super().__init__()
|
18 |
-
self.logit_temperature = nn.Parameter(logit_temperature)
|
19 |
|
20 |
def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
|
21 |
temperature = self.logit_temperature.sigmoid()
|
@@ -29,7 +29,7 @@ class CLIPLoss(nn.Module):
|
|
29 |
class CyCLIP(nn.Module):
|
30 |
def __init__(self, logit_temperature: float = -1.0):
|
31 |
super().__init__()
|
32 |
-
self.logit_temperature = nn.Parameter(logit_temperature)
|
33 |
self.lambda_1: float = 1.0
|
34 |
self.lambda_2: float = 1.0
|
35 |
|
@@ -54,7 +54,7 @@ class CyCLIP(nn.Module):
|
|
54 |
class SigLIPLoss(nn.Module):
|
55 |
def __init__(self, logit_temperature: float = -1.0):
|
56 |
super().__init__()
|
57 |
-
self.logit_temperature = nn.Parameter(logit_temperature)
|
58 |
|
59 |
def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
|
60 |
temperature = self.logit_temperature.sigmoid()
|
@@ -65,7 +65,7 @@ class SigLIPLoss(nn.Module):
|
|
65 |
class CySigLIPLoss(nn.Module):
|
66 |
def __init__(self, logit_temperature: float = -1.0):
|
67 |
super().__init__()
|
68 |
-
self.logit_temperature = nn.Parameter(logit_temperature)
|
69 |
self.lambda_1: float = 1.0
|
70 |
self.lambda_2: float = 1.0
|
71 |
|
|
|
15 |
class CLIPLoss(nn.Module):
|
16 |
def __init__(self, logit_temperature: float = -1.0):
|
17 |
super().__init__()
|
18 |
+
self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
|
19 |
|
20 |
def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
|
21 |
temperature = self.logit_temperature.sigmoid()
|
|
|
29 |
class CyCLIP(nn.Module):
|
30 |
def __init__(self, logit_temperature: float = -1.0):
|
31 |
super().__init__()
|
32 |
+
self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
|
33 |
self.lambda_1: float = 1.0
|
34 |
self.lambda_2: float = 1.0
|
35 |
|
|
|
54 |
class SigLIPLoss(nn.Module):
|
55 |
def __init__(self, logit_temperature: float = -1.0):
|
56 |
super().__init__()
|
57 |
+
self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
|
58 |
|
59 |
def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
|
60 |
temperature = self.logit_temperature.sigmoid()
|
|
|
65 |
class CySigLIPLoss(nn.Module):
|
66 |
def __init__(self, logit_temperature: float = -1.0):
|
67 |
super().__init__()
|
68 |
+
self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
|
69 |
self.lambda_1: float = 1.0
|
70 |
self.lambda_2: float = 1.0
|
71 |
|
src/models.py
CHANGED
@@ -119,3 +119,9 @@ class TinyCLIP(PreTrainedModel):
|
|
119 |
out["loss"] = self.loss_fn(vision_output, text_output)
|
120 |
|
121 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
out["loss"] = self.loss_fn(vision_output, text_output)
|
120 |
|
121 |
return out
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
model = TinyCLIP(TinyCLIPConfig())
|
126 |
+
print(model)
|
127 |
+
print("Done!")
|