AlienChen commited on
Commit
16339c9
·
verified ·
1 Parent(s): 13fd594

Update moppit.py

Browse files
Files changed (1) hide show
  1. moppit.py +5 -5
moppit.py CHANGED
@@ -22,7 +22,7 @@ if args.motifs:
22
  print(motifs)
23
 
24
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
25
- target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
26
 
27
  # Load Models
28
  solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)
@@ -35,14 +35,14 @@ if 'Non-Fouling' in args.objectives:
35
  nonfouling_model = NonfoulingModel(device=device)
36
  score_models.append(nonfouling_model)
37
  if 'Solubility' in args.objectives:
38
- solubility_model = SolubilityModelNew(device=device)
39
  score_models.append(solubility_model)
40
  if 'Half-Life' in args.objectives:
41
  halflife_model = HalfLifeModel(device=device)
42
  score_models.append(halflife_model)
43
  if 'Affinity' in args.objectives:
44
- affinity_predictor = load_affinity_predictor('./classifier_ckpt/binding_affinity_unpooled.pt', device)
45
- affinity_model = AffinityModel(affinity_predictor, target_sequence)
46
  score_models.append(affinity_model)
47
  if 'Motif' in args.objectives or 'Specificity' in args.objectives:
48
  bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
@@ -50,7 +50,7 @@ if 'Motif' in args.objectives or 'Specificity' in args.objectives:
50
  motif_penalty = True
51
  else:
52
  motif_penalty = False
53
- motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=motif_penalty)
54
  score_models.append(motif_model)
55
 
56
  objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
 
22
  print(motifs)
23
 
24
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
25
+ target_sequence = tokenizer(target, return_tensors='pt').to(device)
26
 
27
  # Load Models
28
  solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)
 
35
  nonfouling_model = NonfoulingModel(device=device)
36
  score_models.append(nonfouling_model)
37
  if 'Solubility' in args.objectives:
38
+ solubility_model = SolubilityModel(device=device)
39
  score_models.append(solubility_model)
40
  if 'Half-Life' in args.objectives:
41
  halflife_model = HalfLifeModel(device=device)
42
  score_models.append(halflife_model)
43
  if 'Affinity' in args.objectives:
44
+ affinity_predictor = load_affinity_predictor('../classifier_ckpt/wt_affinity.pt', device)
45
+ affinity_model = AffinityModel(affinity_predictor, target_sequence, device)
46
  score_models.append(affinity_model)
47
  if 'Motif' in args.objectives or 'Specificity' in args.objectives:
48
  bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
 
50
  motif_penalty = True
51
  else:
52
  motif_penalty = False
53
+ motif_model = MotifModel(bindevaluator, target_sequence['input_ids'], motifs, penalty=motif_penalty)
54
  score_models.append(motif_model)
55
 
56
  objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'