Simon Duerr committed
Commit 22e3abd (parent: 8a361d8)

fix: train path, update draw samples
README.md CHANGED
@@ -10,15 +10,15 @@ pinned: false
 license: mit
 ---
 
-# protpardelle WebDemo
+# protpardelle
 
 Code for the paper: [An all-atom protein generative model](https://www.biorxiv.org/content/10.1101/2023.05.24.542194v1.full).
 
 The code is under active development and we welcome contributions, feature requests, issues, corrections, and any questions! Where we have used or adapted code from others we have tried to give proper attribution, but please let us know if anything should be corrected.
 
-## Environment
+## Environment and setup
 
-To set up the conda environment, run `conda env create -f configs/environment.yml`.
+To set up the conda environment, run `conda env create -f configs/environment.yml`, then `conda activate delle`. You will also need to clone the [ProteinMPNN repository](https://github.com/dauparas/ProteinMPNN) to the same directory that contains the `protpardelle/` repository. You may also need to set the `home_dir` variable in the configs you use to the path to the directory containing the `protpardelle/` directory.
 
 ## Inference
 
@@ -28,12 +28,16 @@ To draw 8 samples per length for lengths in `range(70, 150, 5)` from the backbon
 
 `python draw_samples.py --type backbone --param n_steps --paramval 100 --minlen 70 --maxlen 150 --steplen 5 --perlen 8`
 
-We have also added the ability to provide an input PDB file and a list of (zero-indexed) indices to condition on from the PDB file. We can expect it to do better or worse depending on the problem (better on easier problems such as inpainting, worse on difficult problems such as discontiguous sidechain-only scaffolding).
+We have also added the ability to provide an input PDB file and a list of (zero-indexed) indices to condition on from the PDB file. Note also that current models are single-chain only, so multi-chain PDBs will be treated as single chains (we intend to release multi-chain models in a later update). We can expect it to do better or worse depending on the problem (better on easier problems such as inpainting, worse on difficult problems such as discontiguous scaffolding). Use this command to resample the first 25 and 71st to 80th residues of `my_pdb.pdb`.
 
-`python draw_samples.py --input_pdb --cond_idxs 0-25,40-80`
+`python draw_samples.py --input_pdb my_pdb.pdb --resample_idxs 0-25,70-80`
+
+For more control over the sampling process, including tweaking the sampling hyperparameters and more specific methods of conditioning, you can directly interface with the `model.sample()` function; we have provided examples of how to configure and run these commands in `sampling.py`.
 
 ## Training
 
+Note (Sep 2023): the lab has decided to collect usage statistics on people interested in training their own versions of Protpardelle (for funding and other purposes). To obtain a copy of the repository with training code, please complete [this Google Form](https://docs.google.com/forms/d/1WKMVbydLh6LIegc3HfwMQhgL2_qnrY7ks9FM_ylo4ts) - you will receive a link to a Google Drive zip which contains the repository with training code. After publication, the plan is to include the full training code directly in this repository.
+
 Pretrained model weights are provided, but if you are interested in training your own models, we have provided training code together with some basic online evaluation. You will need to create a Weights & Biases account.
 
 The dataset can be downloaded from [CATH](http://download.cathdb.info/cath/releases/all-releases/v4_3_0/non-redundant-data-sets/), and the train/validation/test splits used can be downloaded with
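The new `--resample_idxs` flag takes comma-separated, zero-indexed ranges such as `0-25,70-80`. As a rough illustration only (the function name and the assumption of inclusive endpoints are mine, not the repository's actual parsing code), such a spec could be expanded into a flat index list like this:

```python
def parse_resample_idxs(spec: str) -> list[int]:
    """Expand a range spec like '0-25,70-80' into a flat list of indices.

    Hypothetical helper: assumes zero-indexed ranges with inclusive endpoints.
    """
    idxs: list[int] = []
    for part in spec.split(","):
        start, end = (int(bound) for bound in part.split("-"))
        idxs.extend(range(start, end + 1))
    return idxs
```

The resulting list could then be used to mask which residues are resampled versus held fixed.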
app.py CHANGED
@@ -303,15 +303,15 @@ def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, step
     if args.type == "backbone":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/backbone.yml"
+            cfg_path = f"{args.model_checkpoint}/backbone_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
             )
             cfg_path = f"{model_directory}/configs/backbone.yml"
-        cfg = utils.load_config(cfg_path)
+        config = utils.load_config(cfg_path)
         weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
-        model = models.Protpardelle(cfg, device=device)
+        model = models.Protpardelle(config, device=device)
         model.load_state_dict(weights)
         model.to(device)
         model.eval()
@@ -319,7 +319,7 @@ def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, step
     elif args.type == "allatom":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/allatom.yml"
+            cfg_path = f"{args.model_checkpoint}/allatom_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
@@ -345,6 +345,9 @@ def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, step
     for k, v in sampling_kwargs_readme:
         f.write(f"{k}\t{v}\n")
 
+    print(f"Model loaded from {checkpoint}")
+    print(f"Beginning sampling for {date_string}...")
+
     # Draw samples
     output_files = draw_and_save_samples(
         model,
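The `cfg_path` changes above pair a user-supplied checkpoint directory with the renamed `*_pretrained.yml` configs. The selection logic shared by the `backbone` and `allatom` branches can be sketched as a pure function (the function name, defaults, and return shape are illustrative, not the app's actual signature):

```python
def select_paths(model_type: str, model_checkpoint: str = "",
                 model_directory: str = "models", epoch: int = 0) -> tuple[str, str]:
    """Pick (checkpoint, cfg_path) the way this commit does.

    A user-supplied checkpoint dir is expected to hold the
    '<type>_pretrained.yml' config next to the state dict; otherwise the
    repo's own configs/ directory is used.
    """
    if model_checkpoint:
        checkpoint = f"{model_checkpoint}/{model_type}_state_dict.pth"
        cfg_path = f"{model_checkpoint}/{model_type}_pretrained.yml"
    else:
        checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
        cfg_path = f"{model_directory}/configs/{model_type}.yml"
    return checkpoint, cfg_path
```

Factoring the branch this way also makes the rename easy to test in isolation.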
checkpoints/allatom.yml CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '/home/user/app'
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
checkpoints/backbone.yml CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '/home/user/app'
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
configs/allatom.yml CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '/home/user/app'
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
configs/backbone.yml CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '/home/user/app'
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
configs/seqdes.yml CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '/home/user/app'
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
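All shipped configs now leave `home_dir` empty, and the `draw_samples.py` hunk later in this commit fills an empty value in with the current working directory at runtime. The fallback amounts to (a sketch; the real code mutates `config.train.home_dir` in place rather than calling a helper):

```python
import os


def resolve_home_dir(home_dir: str) -> str:
    """An empty home_dir in the shipped configs means 'use the current directory'."""
    return home_dir if home_dir else os.getcwd()
```

This avoids baking the Space's `/home/user/app` path into configs that users run locally.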
core/protein_mpnn.py CHANGED
@@ -55,10 +55,11 @@ def get_mpnn_model(model_name='v_48_020', path_to_model_weights='', ca_only=Fals
     else:
         file_path = os.path.realpath(__file__)
         k = file_path.rfind("/")
+        k = file_path[:k].rfind("/")
         if ca_only:
-            model_folder_path = file_path[:k] + '/ca_model_weights/'
+            model_folder_path = file_path[:k] + '/ProteinMPNN/ca_model_weights/'
         else:
-            model_folder_path = file_path[:k] + '/vanilla_model_weights/'
+            model_folder_path = file_path[:k] + '/ProteinMPNN/vanilla_model_weights/'
 
     checkpoint_path = model_folder_path + f'{model_name}.pt'
     checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -450,7 +451,6 @@ def run_proteinmpnn(model=None, pdb_path='', pdb_path_chains='', path_to_model_w
     print(f'{num_seqs} sequences of length {total_length} generated in {dt} seconds')
     if write_output_files:
         f.close()
-
     return new_mpnn_seqs
 
 
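The extra `rfind` above makes `get_mpnn_model` climb one directory higher before descending into a `ProteinMPNN/` checkout. The two string slices are equivalent to applying `os.path.dirname` twice; a standalone sketch of the resulting path logic (the function name is mine, not the module's):

```python
import os


def mpnn_weights_dir(file_path: str, ca_only: bool = False) -> str:
    """Mimic the commit's path logic: go two levels up from this source file,
    then into the adjacent ProteinMPNN checkout's weight folder."""
    parent = os.path.dirname(os.path.dirname(file_path))  # two rfind("/") slices
    subdir = "ca_model_weights" if ca_only else "vanilla_model_weights"
    return f"{parent}/ProteinMPNN/{subdir}/"
```

Writing it as a function makes it easy to verify where the weights are expected relative to `core/protein_mpnn.py`.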
draw_samples.py CHANGED
@@ -122,18 +122,18 @@ class Manager(object):
             "--perlen", type=int, default=2, help="How many samples per sequence length"
         )
         self.parser.add_argument(
-            "--minlen", type=int, required=False, help="Minimum sequence length"
+            "--minlen", type=int, default=50, help="Minimum sequence length"
         )
         self.parser.add_argument(
             "--maxlen",
             type=int,
-            required=False,
+            default=60,
             help="Maximum sequence length, not inclusive",
         )
         self.parser.add_argument(
             "--steplen",
             type=int,
-            required=False,
+            default=5,
             help="How frequently to select sequence length, for steplen 2, would be 50, 52, 54, etc",
         )
         self.parser.add_argument(
@@ -279,15 +279,15 @@ def main():
     if args.type == "backbone":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/backbone.yml"
+            cfg_path = f"{args.model_checkpoint}/backbone_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
             )
             cfg_path = f"{model_directory}/configs/backbone.yml"
-        cfg = utils.load_config(cfg_path)
+        config = utils.load_config(cfg_path)
         weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
-        model = models.Protpardelle(cfg, device=device)
+        model = models.Protpardelle(config, device=device)
         model.load_state_dict(weights)
         model.to(device)
         model.eval()
@@ -295,7 +295,7 @@ def main():
     elif args.type == "allatom":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/allatom.yml"
+            cfg_path = f"{args.model_checkpoint}/allatom_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
@@ -310,8 +310,11 @@ def main():
     model.eval()
     model.device = device
 
+    if config.train.home_dir == '':
+        config.train.home_dir = os.getcwd()
+
     # Sampling
-    with open(base_dir + "/readme.txt", "w") as f:
+    with open(save_dir + "/readme.txt", "w") as f:
         f.write(f"Sampling run for {date_string}\n")
         f.write(f"Random seed {seed}\n")
         f.write(f"Model checkpoint: {checkpoint}\n")
@@ -341,7 +344,7 @@ def main():
     print(f"Of this, {sampling_time} seconds were for actual sampling.")
     print(f"{total_num_samples} total samples were drawn.")
 
-    with open(save_dir + "/readme.txt", "a") as f:
+    with open(save_dir + "/readme.txt", "a") as f:
         f.write(f"Total job time: {time_elapsed} seconds\n")
         f.write(f"Model run time: {sampling_time} seconds\n")
         f.write(f"Total samples drawn: {total_num_samples}\n")
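With the new argparse defaults, `draw_samples.py` no longer needs all three length flags: `--minlen 50 --maxlen 60 --steplen 5` always yields a concrete length grid. Since `--maxlen` is "not inclusive" per its help text, the grid is just a `range` (a sketch; `length_grid` is a hypothetical name, not a function in the script):

```python
def length_grid(minlen: int = 50, maxlen: int = 60, steplen: int = 5) -> list[int]:
    """Sequence lengths sampled per run; maxlen is exclusive, per the help text."""
    return list(range(minlen, maxlen, steplen))
```

So the defaults sample lengths 50 and 55, while the README's `--minlen 70 --maxlen 150 --steplen 5` example covers 70 through 145.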
protpardelle_pymol.py CHANGED
@@ -15,9 +15,9 @@ except ImportError:
 
 
 if os.environ.get("GRADIO_LOCAL") != None:
-    public_link = "http://127.0.0.1:7862"
+    public_link = "http://127.0.0.1:7860"
 else:
-    public_link = "spacesplaceholder"
+    public_link = "ProteinDesignLab/protpardelle"
 
 
 
@@ -140,6 +140,16 @@ def query_protpardelle_uncond(
 
 
 def setprotpardellelink(link:str):
+    """
+    AUTHOR
+    Simon Duerr
+    https://twitter.com/simonduerr
+    DESCRIPTION
+    Set a public link to use a locally hosted version of this space
+    USAGE
+    protpardelle_setlink link_or_username/spacename
+    """
+
     global public_link
     try:
         client = Client(link)
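The link fix above moves the local development URL to port 7860 (Gradio's default) and replaces the placeholder with the actual Space id. The selection reduces to a single environment check; a testable sketch (`resolve_public_link` and the injectable `env` parameter are my additions for illustration):

```python
import os


def resolve_public_link(env=None) -> str:
    """Local Gradio dev server when GRADIO_LOCAL is set, else the hosted Space id."""
    env = os.environ if env is None else env
    if env.get("GRADIO_LOCAL") is not None:
        return "http://127.0.0.1:7860"
    return "ProteinDesignLab/protpardelle"
```

Passing `env` as a plain dict keeps the branch easy to exercise without touching the process environment.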