Niksa Praljak committed
Commit 66d2e5f
1 Parent(s): d5de529

Update PenCL argparse and Finish Facilitator script

README.md CHANGED
@@ -62,7 +62,8 @@ cd BioM3_PenCL
 ```bash
 python run_PenCL_inference.py \
     --json_path "stage1_config.json" \
-    --model_path "./weights/PenCL/BioM3_PenCL_epoch20.bin"
+    --model_path "./weights/PenCL/BioM3_PenCL_epoch20.bin" \
+    --output_path "test_PenCL_embeddings.pt"
 ```
 
 ### Example Input Data
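For context, a minimal sketch of the resulting two-stage flow. The Stage 1 command is the README example above; the Stage 2 command is assembled from the example values in the new argparse help strings of `run_Facilitator_sample.py` below, and the Facilitator weights directory is an assumption:

```bash
# Stage 1: PenCL inference writes the embedding bundle
python run_PenCL_inference.py \
    --json_path "stage1_config.json" \
    --model_path "./weights/PenCL/BioM3_PenCL_epoch20.bin" \
    --output_path "test_PenCL_embeddings.pt"

# Stage 2: the Facilitator consumes that bundle
# (the "./weights/Facilitator/" location is an assumed path)
python run_Facilitator_sample.py \
    --input_data_path "test_PenCL_embeddings.pt" \
    --output_data_path "Facilitator_test_outputs.pt" \
    --model_path "./weights/Facilitator/BioM3_Facilitator_epoch20.bin" \
    --json_path "stage2_config.json"
```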
run_Facilitator_sample.py CHANGED
@@ -1,117 +1,103 @@
+import argparse
 import yaml
 from argparse import Namespace
 import json
 import pandas as pd
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-import pytorch_lightning as pl
-import Stage1_source.preprocess as prep
 import Stage1_source.model as mod
-import Stage1_source.PL_wrapper as PL_wrap
-
 
-# Step 1: Load JSON configuration
+# Step 1: Load JSON Configuration
 def load_json_config(json_path):
-    """
-    Load JSON configuration file.
-    """
     with open(json_path, "r") as f:
         config = json.load(f)
-    # print("Loaded JSON config:", config)
     return config
 
 # Step 2: Convert JSON dictionary to Namespace
 def convert_to_namespace(config_dict):
-    """
-    Recursively convert a dictionary to an argparse Namespace.
-    """
     for key, value in config_dict.items():
-        if isinstance(value, dict): # Recursively handle nested dictionaries
+        if isinstance(value, dict):
             config_dict[key] = convert_to_namespace(value)
     return Namespace(**config_dict)
 
-def prepare_model(args) -> nn.Module:
-    """
-    Prepare the model and PyTorch Lightning Trainer using a flat args object.
-    """
+# Step 3: Load Pre-trained Model
+def prepare_model(config_args, model_path) -> nn.Module:
     model = mod.Facilitator(
-        in_dim=args.emb_dim,
-        hid_dim=args.hid_dim,
-        out_dim=args.emb_dim,
-        dropout=args.dropout
+        in_dim=config_args.emb_dim,
+        hid_dim=config_args.hid_dim,
+        out_dim=config_args.emb_dim,
+        dropout=config_args.dropout
     )
-    weights_path = f"{save_dir}/BioM3_Facilitator_epoch20.bin"  # BioM3_PenCL_epoch20.bin"
-    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
+    model.load_state_dict(torch.load(model_path, map_location="cpu"))
     model.eval()
     print("Model loaded successfully with weights!")
    return model
 
+# Step 4: Compute MMD Loss
 def compute_mmd_loss(x, y, kernel="rbf", sigma=1.0):
-    """
-    Compute the MMD loss between two sets of embeddings.
-    Args:
-        x: Tensor of shape [N, D]
-        y: Tensor of shape [N, D]
-        kernel: Kernel function, default is 'rbf' (Gaussian kernel)
-        sigma: Bandwidth for the Gaussian kernel
-    """
     def rbf_kernel(a, b, sigma):
-        """
-        Compute the RBF kernel between two tensors.
-        """
         pairwise_distances = torch.cdist(a, b, p=2) ** 2
         return torch.exp(-pairwise_distances / (2 * sigma ** 2))
 
-    # Compute RBF kernel matrices
-    K_xx = rbf_kernel(x, x, sigma)  # Kernel within x
-    K_yy = rbf_kernel(y, y, sigma)  # Kernel within y
-    K_xy = rbf_kernel(x, y, sigma)  # Kernel between x and y
+    K_xx = rbf_kernel(x, x, sigma)
+    K_yy = rbf_kernel(y, y, sigma)
+    K_xy = rbf_kernel(x, y, sigma)
 
-    # Compute MMD loss
     mmd_loss = K_xx.mean() - 2 * K_xy.mean() + K_yy.mean()
     return mmd_loss
 
+# Step 5: Argument Parser Function
+def parse_arguments():
+    parser = argparse.ArgumentParser(description="BioM3 Facilitator Model (Stage 2)")
+    parser.add_argument('--input_data_path', type=str, required=True,
+                        help="Path to the input embeddings (e.g., PenCL_test_outputs.pt)")
+    parser.add_argument('--output_data_path', type=str, required=True,
+                        help="Path to save the output embeddings (e.g., Facilitator_test_outputs.pt)")
+    parser.add_argument('--model_path', type=str, required=True,
+                        help="Path to the Facilitator model weights (e.g., BioM3_Facilitator_epoch20.bin)")
+    parser.add_argument('--json_path', type=str, required=True,
+                        help="Path to the JSON configuration file (stage2_config.json)")
+    return parser.parse_args()
+
+# Main Execution
 if __name__ == '__main__':
-
-    json_path = f"{save_dir}/stage2_config.json"
-    # Load and convert JSON config
-    json_path = f"{save_dir}/stage2_config.json"
-    config_dict = load_json_config(json_path)
-    args = convert_to_namespace(config_dict)
+    # Parse arguments
+    args = parse_arguments()
 
-    # load model
-    model = prepare_model(args=args)
+    # Load configuration
+    config_dict = load_json_config(args.json_path)
+    config_args = convert_to_namespace(config_dict)
 
-    # load test dataset
-    embedding_dataset = torch.load('./PenCL_test_outputs.pt')
+    # Load model
+    model = prepare_model(config_args=config_args, model_path=args.model_path)
 
-    # Run inference and store z_t, z_p
+    # Load input embeddings
+    embedding_dataset = torch.load(args.input_data_path)
 
+    # Run inference to get facilitated embeddings
     with torch.no_grad():
         z_t = embedding_dataset['z_t']
         z_p = embedding_dataset['z_p']
         z_c = model(z_t)
         embedding_dataset['z_c'] = z_c
 
-    # Compute MSE between embeddings
-    mse_zc_zp = F.mse_loss(z_c, z_p)  # MSE between facilitated embeddings and protein embeddings
-    mse_zt_zp = F.mse_loss(z_t, z_p)  # MSE between text embeddings and protein embeddings
+    # Compute evaluation metrics
+    # 1. MSE between embeddings
+    mse_zc_zp = F.mse_loss(z_c, z_p)
+    mse_zt_zp = F.mse_loss(z_t, z_p)
 
-    # Compute Norms (L2 magnitudes) for a given batch (e.g., first 5 embeddings)
+    # 2. Compute L2 norms for first batch
     batch_idx = 0
     norm_z_t = torch.norm(z_t[batch_idx], p=2).item()
     norm_z_p = torch.norm(z_p[batch_idx], p=2).item()
     norm_z_c = torch.norm(z_c[batch_idx], p=2).item()
 
-    # Compute MMD between embeddings
-    MMD_zc_zp = model.compute_mmd(z_c, z_p)
-    MMD_zp_zt = model.compute_mmd(z_p, z_t)
+    # 3. Compute MMD between embeddings
+    mmd_zc_zp = model.compute_mmd(z_c, z_p)
+    mmd_zp_zt = model.compute_mmd(z_p, z_t)
 
-    # Print Results
+    # Print results
     print("\n=== Facilitator Model Output ===")
     print(f"Shape of z_t (Text Embeddings): {z_t.shape}")
     print(f"Shape of z_p (Protein Embeddings): {z_p.shape}")
@@ -127,13 +113,9 @@ if __name__ == '__main__':
     print(f"MSE between Text Embeddings (z_t) and Protein Embeddings (z_p): {mse_zt_zp:.6f}")
 
     print("\n=== Max Mean Discrepancy (MMD) Results ===")
-    print(f"MMD between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): {MMD_zc_zp:.6f}")
-    print(f"MMD between Text Embeddings (z_t) and Protein Embeddings (z_p): {MMD_zp_zt:.6f}")
-
-    print("\nFacilitator Model successfully computed facilitated embeddings!")
-
-    # save output embeddings
-
-    torch.save(embedding_dataset, 'Facilitator_test_outputs.pt')
-
+    print(f"MMD between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): {mmd_zc_zp:.6f}")
+    print(f"MMD between Text Embeddings (z_t) and Protein Embeddings (z_p): {mmd_zp_zt:.6f}")
 
+    # Save output embeddings
+    torch.save(embedding_dataset, args.output_data_path)
+    print(f"\nFacilitator embeddings saved to {args.output_data_path}")
run_PenCL_inference.py CHANGED
@@ -56,6 +56,9 @@ def parse_arguments():
                         help="Path to the JSON configuration file (stage1_config.json)")
     parser.add_argument('--model_path', type=str, required=True,
                         help="Path to the pre-trained model weights (pytorch_model.bin)")
+    parser.add_argument('--output_path', type=str, required=True,
+                        help="Path to save output embeddings")
+
     return parser.parse_args()
 
 # Step 6: Compute Homology Probabilities
@@ -90,7 +93,9 @@ if __name__ == '__main__':
     # Run inference and store z_t, z_p
     z_t_list = []
     z_p_list = []
-
+    text_list = []
+    protein_list = []
+
     with torch.no_grad():
         for idx in range(len(test_dataset)):
             batch = test_dataset[idx]
@@ -100,10 +105,24 @@ if __name__ == '__main__':
             z_p = outputs['seq_joint_latent']  # Protein latent
             z_t_list.append(z_t)
             z_p_list.append(z_p)
+
+            protein_sequence = test_dataset.protein_sequence_list[idx]
+            text_prompt = test_dataset.text_captions_list[idx]
+            text_list.append(text_prompt)
+            protein_list.append(protein_sequence)
+
 
     # Stack all latent vectors
     z_t_tensor = torch.vstack(z_t_list)  # Shape: (num_samples, latent_dim)
     z_p_tensor = torch.vstack(z_p_list)  # Shape: (num_samples, latent_dim)
+
+    # Prepare embedding dict.
+    embedding_dict = {
+        'sequence': protein_list,
+        'text_prompts': text_list,
+        'z_t': z_t_tensor,
+        'z_p': z_p_tensor
+    }
 
     # Compute Dot Product scores
     dot_product_scores = torch.matmul(z_p_tensor, z_t_tensor.T)  # Dot product
@@ -138,4 +157,5 @@ if __name__ == '__main__':
 
     print("\n=== Homology Matrix (Dot Product of Normalized z_p) ===")
     print(homology_matrix)
-
+
+    torch.save(embedding_dict, config_args_parser.output_path)
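A minimal sketch for consuming the saved bundle, assuming the `test_PenCL_embeddings.pt` file name from the README example (the dictionary keys come from `embedding_dict` above):

```python
import torch

# Load the embedding bundle written by run_PenCL_inference.py
data = torch.load("test_PenCL_embeddings.pt")

print(data["z_t"].shape)        # text embeddings, (num_samples, latent_dim)
print(data["z_p"].shape)        # protein embeddings, (num_samples, latent_dim)
print(data["text_prompts"][0])  # first text caption
print(data["sequence"][0])      # first protein sequence
```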
stage2_config.json ADDED
@@ -0,0 +1,16 @@
+{
+    "model_checkpoint_path": "/project/andrewferguson/niksapraljak/Project_ProtARDM/logs/Stage2_facilitator/Stage2_facilitator/checkpoints/Stage2_MMD_Pfam_Swiss_epoch20_ckpt/last.ckpt",
+    "model_type": "pfam",
+    "fast_dev_run": 0,
+    "loss_type": "MMD",
+    "dataset_type": "default",
+    "precision": "32",
+    "stage1_dataset_path": "None",
+    "stage2_output_path": "None",
+    "seed": 42,
+    "num_workers": 12,
+    "dropout": 0.0,
+    "batch_size": 64,
+    "emb_dim": 512,
+    "hid_dim": 1024
+}
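Of these fields, `run_Facilitator_sample.py` above appears to read only `emb_dim`, `hid_dim`, and `dropout` (inside `prepare_model`); the remaining training-oriented entries such as `model_checkpoint_path` are carried in the Namespace but unused at inference time. A quick sketch, reusing the script's own helpers:

```python
# Reuses load_json_config / convert_to_namespace from run_Facilitator_sample.py
config_args = convert_to_namespace(load_json_config("stage2_config.json"))
print(config_args.emb_dim, config_args.hid_dim, config_args.dropout)  # -> 512 1024 0.0
```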