Hannes Kuchelmeister
commited on
Commit
•
ba9c868
1
Parent(s):
ad947ed
Add code so loss function uses torch.Size([x,1]) instead of torch.Size([x])
Browse filesThis was done to prevent the error message:
"Using a target size (torch.Size([64])) that is different to the input size (torch.Size([64, 1]))".
src/models/focus_module.py
CHANGED
@@ -85,7 +85,7 @@ class FocusLitModule(LightningModule):
|
|
85 |
x = batch["image"]
|
86 |
y = batch["focus_value"]
|
87 |
logits = self.forward(x)
|
88 |
-
loss = self.criterion(logits, y)
|
89 |
preds = torch.squeeze(logits)
|
90 |
return loss, preds, y
|
91 |
|
@@ -210,7 +210,7 @@ class FocusMSELitModule(LightningModule):
|
|
210 |
x = batch["image"]
|
211 |
y = batch["focus_value"]
|
212 |
logits = self.forward(x)
|
213 |
-
loss = self.criterion(logits, y)
|
214 |
preds = torch.squeeze(logits)
|
215 |
return loss, preds, y
|
216 |
|
|
|
85 |
x = batch["image"]
|
86 |
y = batch["focus_value"]
|
87 |
logits = self.forward(x)
|
88 |
+
loss = self.criterion(logits, y.unsqueeze(1))
|
89 |
preds = torch.squeeze(logits)
|
90 |
return loss, preds, y
|
91 |
|
|
|
210 |
x = batch["image"]
|
211 |
y = batch["focus_value"]
|
212 |
logits = self.forward(x)
|
213 |
+
loss = self.criterion(logits, y.unsqueeze(1))
|
214 |
preds = torch.squeeze(logits)
|
215 |
return loss, preds, y
|
216 |
|