igashov committed
Commit d1da608
1 Parent(s): b0ab0d5

fix size_nn

Files changed (2)
  1. app.py +1 -1
  2. src/linker_size_lightning.py +6 -3
app.py CHANGED
@@ -72,7 +72,7 @@ print('Loaded diffusion model')
 
 
 def sample_fn(_data):
-    output, _ = size_nn.forward(_data)
+    output, _ = size_nn.forward(_data, return_loss=False)
     probabilities = torch.softmax(output, dim=1)
    distribution = torch.distributions.Categorical(probs=probabilities)
    samples = distribution.sample()
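
For context, sample_fn draws a discrete linker size by sampling from a categorical distribution over the classifier's logits. Below is a minimal, self-contained sketch of that pattern; the logit tensor and its shape are assumptions for illustration, not the repository's actual size_nn output.

import torch

# Stand-in for size_nn's output: per-class logits for a batch.
# Assumed shape (batch_size, num_size_classes); the real values come from SizeClassifier.
logits = torch.randn(4, 10)

# Same sampling pattern as sample_fn in app.py:
probabilities = torch.softmax(logits, dim=1)                     # normalize logits to probabilities
distribution = torch.distributions.Categorical(probs=probabilities)
samples = distribution.sample()                                  # one sampled size class per batch element
print(samples.shape)                                             # torch.Size([4])
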
src/linker_size_lightning.py CHANGED
@@ -79,7 +79,7 @@ class SizeClassifier(pl.LightningModule):
     def test_dataloader(self):
         return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
 
-    def forward(self, data):
+    def forward(self, data, return_loss=True):
         h = data['one_hot']
         x = data['positions']
         fragment_mask = data['fragment_mask']
@@ -103,8 +103,11 @@ class SizeClassifier(pl.LightningModule):
         output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
         output = output.view(bs, n_nodes, -1).mean(1)
 
-        true = self.get_true_labels(linker_mask)
-        loss = cross_entropy(output, true, weight=self.loss_weights)
+        if return_loss:
+            true = self.get_true_labels(linker_mask)
+            loss = cross_entropy(output, true, weight=self.loss_weights)
+        else:
+            loss = None
 
         return output, loss
 
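
The change lets callers skip the loss when no ground-truth labels are in the batch (as in app.py's sample_fn). A rough sketch of the pattern is below; SizeClassifier, get_true_labels, loss_weights and return_loss are taken from the diff, while the toy module, its feature/label keys, and the shapes are placeholders, not the repository's real model.

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy

class ToySizeClassifier(nn.Module):
    """Simplified stand-in illustrating the return_loss pattern in SizeClassifier.forward."""

    def __init__(self, in_dim=16, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(in_dim, num_classes)
        self.loss_weights = None  # SizeClassifier passes class weights to cross_entropy

    def forward(self, data, return_loss=True):
        # 'features' is a placeholder key; the real model reads one_hot, positions, masks, etc.
        output = self.linear(data['features'])
        if return_loss:
            # Training/validation: ground-truth labels are available in the batch.
            loss = cross_entropy(output, data['labels'], weight=self.loss_weights)
        else:
            # Inference (e.g. sample_fn in app.py): no labels, so skip the loss entirely.
            loss = None
        return output, loss

# Training-style call (labels present) vs. inference call (labels absent):
batch = {'features': torch.randn(4, 16), 'labels': torch.randint(0, 10, (4,))}
out, loss = ToySizeClassifier()(batch)                     # return_loss defaults to True
out, loss = ToySizeClassifier()(batch, return_loss=False)  # loss is None, as in app.py
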