igashov commited on
Commit
b0ab0d5
1 Parent(s): 7c181a3
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import os
3
  import torch
4
  import subprocess
@@ -109,9 +110,13 @@ def generate(input_file):
109
  except Exception as e:
110
  return f'Could not read the molecule: {e}'
111
 
 
 
 
112
  positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
113
- positions = torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device)
114
- one_hot = torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device)
 
115
  print('Read and parsed molecule')
116
 
117
  dataset = [{
@@ -120,9 +125,9 @@ def generate(input_file):
120
  'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
121
  'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
122
  'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
123
- 'anchors': torch.zeros_like(charges, dtype=const.TORCH_FLOAT, device=device),
124
- 'fragment_mask': torch.ones_like(charges, dtype=const.TORCH_FLOAT, device=device),
125
- 'linker_mask': torch.zeros_like(charges, dtype=const.TORCH_FLOAT, device=device),
126
  'num_atoms': len(positions),
127
  }]
128
  dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
 
1
  import gradio as gr
2
+ import numpy as np
3
  import os
4
  import torch
5
  import subprocess
 
110
  except Exception as e:
111
  return f'Could not read the molecule: {e}'
112
 
113
+ if molecule.GetNumAtoms() > 50:
114
+ return f'Too large molecule: upper limit is 50 heavy atoms'
115
+
116
  positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
117
+ anchors = np.zeros_like(charges)
118
+ fragment_mask = np.ones_like(charges)
119
+ linker_mask = np.zeros_like(charges)
120
  print('Read and parsed molecule')
121
 
122
  dataset = [{
 
125
  'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
126
  'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
127
  'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
128
+ 'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
129
+ 'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
130
+ 'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
131
  'num_atoms': len(positions),
132
  }]
133
  dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)