Spaces:
Starting
on
A10G
Starting
on
A10G
igashov
commited on
Commit
•
b0ab0d5
1
Parent(s):
7c181a3
updates
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
import os
|
3 |
import torch
|
4 |
import subprocess
|
@@ -109,9 +110,13 @@ def generate(input_file):
|
|
109 |
except Exception as e:
|
110 |
return f'Could not read the molecule: {e}'
|
111 |
|
|
|
|
|
|
|
112 |
positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
|
113 |
-
|
114 |
-
|
|
|
115 |
print('Read and parsed molecule')
|
116 |
|
117 |
dataset = [{
|
@@ -120,9 +125,9 @@ def generate(input_file):
|
|
120 |
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
|
121 |
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
|
122 |
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
|
123 |
-
'anchors': torch.
|
124 |
-
'fragment_mask': torch.
|
125 |
-
'linker_mask': torch.
|
126 |
'num_atoms': len(positions),
|
127 |
}]
|
128 |
dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
import os
|
4 |
import torch
|
5 |
import subprocess
|
|
|
110 |
except Exception as e:
|
111 |
return f'Could not read the molecule: {e}'
|
112 |
|
113 |
+
if molecule.GetNumAtoms() > 50:
|
114 |
+
return f'Too large molecule: upper limit is 50 heavy atoms'
|
115 |
+
|
116 |
positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
|
117 |
+
anchors = np.zeros_like(charges)
|
118 |
+
fragment_mask = np.ones_like(charges)
|
119 |
+
linker_mask = np.zeros_like(charges)
|
120 |
print('Read and parsed molecule')
|
121 |
|
122 |
dataset = [{
|
|
|
125 |
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
|
126 |
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
|
127 |
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
|
128 |
+
'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
|
129 |
+
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
130 |
+
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
131 |
'num_atoms': len(positions),
|
132 |
}]
|
133 |
dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
|