igashov commited on
Commit
c438a2a
1 Parent(s): bec2844

update COM

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -160,6 +160,12 @@ def generate(input_file, n_steps):
160
  print('Generated linker')
161
  x = chain[0][:, :, :ddpm.n_dims]
162
  h = chain[0][:, :, ddpm.n_dims:]
 
 
 
 
 
 
163
  names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
164
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
165
  print('Saved XYZ files')
 
160
  print('Generated linker')
161
  x = chain[0][:, :, :ddpm.n_dims]
162
  h = chain[0][:, :, ddpm.n_dims:]
163
+
164
+ pos_masked = data['positions'] * data['fragment_mask']
165
+ N = data['fragment_mask'].sum(1, keepdims=True)
166
+ mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
167
+ x = x + mean * node_mask
168
+
169
  names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
170
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
171
  print('Saved XYZ files')