igashov commited on
Commit
df1ea66
·
1 Parent(s): c381edc

Stop downloading models

Browse files
Files changed (1) hide show
  1. app.py +0 -10
app.py CHANGED
@@ -5,7 +5,6 @@ import gradio as gr
5
  import numpy as np
6
  import os
7
  import torch
8
- import subprocess
9
  import output
10
 
11
  from rdkit import Chem
@@ -53,24 +52,15 @@ args = parser.parse_args()
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  print(f'Device: {device}')
55
  os.makedirs("results", exist_ok=True)
56
- os.makedirs("models", exist_ok=True)
57
 
58
  size_gnn_path = 'models/geom_size_gnn.ckpt'
59
- if not os.path.exists(size_gnn_path):
60
- print('Downloading SizeGNN model...')
61
- link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
62
- subprocess.run(f'wget {link} -O {size_gnn_path}', shell=True)
63
  size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
64
  print('Loaded SizeGNN model')
65
 
66
 
67
  diffusion_models = {}
68
  for model_name, metadata in MODELS_METADATA.items():
69
- link = metadata['link']
70
  diffusion_path = metadata['path']
71
- if not os.path.exists(diffusion_path):
72
- print(f'Downloading {model_name}...')
73
- subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
74
  diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
75
  print(f'Loaded model {model_name}')
76
 
 
5
  import numpy as np
6
  import os
7
  import torch
 
8
  import output
9
 
10
  from rdkit import Chem
 
52
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
  print(f'Device: {device}')
54
  os.makedirs("results", exist_ok=True)
 
55
 
56
  size_gnn_path = 'models/geom_size_gnn.ckpt'
 
 
 
 
57
  size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
58
  print('Loaded SizeGNN model')
59
 
60
 
61
  diffusion_models = {}
62
  for model_name, metadata in MODELS_METADATA.items():
 
63
  diffusion_path = metadata['path']
 
 
 
64
  diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
65
  print(f'Loaded model {model_name}')
66