BrightBlueCheese committed on
Commit 5b2887e
1 Parent(s): 33898af
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -33,6 +33,9 @@ sys.path.append( '../')
33
  import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
34
  import auto_evaluator_sl
35
 
36
  torch.manual_seed(1004)
37
  np.random.seed(1004)
38
 
@@ -48,26 +51,143 @@ with open(file_path, 'w') as file:
48
  ###
49
  # solute_or_solvent = 'solute'
50
  solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
51
- ver_ft = 0 # version control for FT model & evaluation data # Or it will overwrite the models and results
52
- batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10] # [train, valid(test)]
53
- # since 'solute' has a very small dataset, I think 10 for train and 10 for valid(test) should be the maximum values.
54
- lr = 0.0001
55
- epochs = 7
56
- use_freeze = False # Freeze the model or not # False means not freezing
57
- overwrite_level_2 = True
58
- ###
59
- max_seq_length = 512
60
- tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
61
- max_seq_length=max_seq_length,
62
- )
63
- max_length = max_seq_length
64
- num_workers = 2
65
 
66
  # I just reused our previous research code with some modifications.
67
  dir_main = "."
68
- name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
69
-
70
- dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"
71
 
72
  max_seq_length = 512
73
 
@@ -79,75 +199,95 @@ num_workers = 2
79
 
80
  ## FT
81
 
82
- ver_ft = 0
83
  dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
84
  # name_model_ft = 'Solvent.pt'
85
  name_model_ft = f"{solute_or_solvent}.pt"
86
 
87
- # Load dataset for finetune
88
- batch_size_for_train = batch_size_pair[0]
89
- batch_size_for_valid = batch_size_pair[1]
90
-
91
- data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
92
- solute_or_solvent=solute_or_solvent,
93
- tokenizer=tokenizer,
94
- max_seq_length=max_length,
95
- batch_size_train=batch_size_for_train,
96
- batch_size_valid=batch_size_for_valid,
97
- # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
98
- num_device=num_workers,
99
- )
100
-
101
- data_module.prepare_data()
102
- data_module.setup()
103
- steps_per_epoch = len(data_module.test_dataloader())
104
-
105
- # Load model and optimizer for finetune
106
- learning_rate = lr
107
-
108
-
109
- model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
110
-
111
-
112
- model_ft = model_finetune_sl.CustomFinetuneModel(
113
- model_mtr=model_mtr,
114
- steps_per_epoch=steps_per_epoch,
115
- warmup_epochs=1,
116
- max_epochs=epochs,
117
- learning_rate=learning_rate,
118
- # dataset_dict=dataset_dict,
119
- use_freeze=use_freeze,
120
- )
121
-
122
- # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
123
-
124
- trainer = L.Trainer(
125
- default_root_dir=dir_model_ft_to_save,
126
- # profiler=profiler,
127
- # logger=csv_logger,
128
- accelerator='auto',
129
- devices='auto',
130
- # accelerator='gpu',
131
- # devices=[0],
132
- min_epochs=1,
133
- max_epochs=epochs,
134
- precision=32,
135
- # callbacks=[checkpoint_callback]
136
- )
137
-
138
-
139
  # Predict
140
  local_model_ft = utils_sl.load_model_ft_with(
141
- class_model_ft=model_ft,
142
  dir_model_ft=dir_model_ft_to_save,
143
  name_model_ft=name_model_ft
144
  )
145
 
146
- result = trainer.predict(local_model_ft, data_module)
147
- result_pred = list()
148
- result_label = list()
149
- for bat in range(len(result)):
150
- result_pred.append(result[bat][0].squeeze())
151
- result_label.append(result[bat][1])
152
 
153
- st.write(result_pred)
 
33
  import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
34
  import auto_evaluator_sl
35
 
36
+ from torch.utils.data import Dataset, DataLoader
37
+ from transformers import DataCollatorWithPadding, LlamaModel, LlamaConfig
+ from torch import nn
+ import torchmetrics  # nn, torchmetrics, LlamaModel, LlamaConfig are required by the ChemLlama classes inlined below
38
+
39
  torch.manual_seed(1004)
40
  np.random.seed(1004)
41
 
 
51
  ###
52
  # solute_or_solvent = 'solute'
53
  solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
54
+
55
+
56
+ class ChemLlama(nn.Module):
57
+ def __init__(
58
+ self,
59
+ max_position_embeddings=512,
60
+ vocab_size=591,
61
+ pad_token_id=0,
62
+ bos_token_id=12,
63
+ eos_token_id=13,
64
+ hidden_size=768,
65
+ intermediate_size=768,
66
+ num_labels=105,
67
+ attention_dropout=0.144,
68
+ num_hidden_layers=7,
69
+ num_attention_heads=8,
70
+ learning_rate=0.0001,
71
+ ):
72
+ super(ChemLlama, self).__init__()
73
+
74
+ self.hidden_size = hidden_size
75
+ self.intermediate_size = intermediate_size
76
+ self.num_labels = num_labels
77
+ self.vocab_size = vocab_size
78
+ self.pad_token_id = pad_token_id
79
+ self.bos_token_id = bos_token_id
80
+ self.eos_token_id = eos_token_id
81
+ self.num_hidden_layers = num_hidden_layers
82
+ self.num_attention_heads = num_attention_heads
83
+ self.attention_dropout = attention_dropout
84
+ self.max_position_embeddings = max_position_embeddings
85
+
86
+ self.mae = torchmetrics.MeanAbsoluteError()
87
+ self.mse = torchmetrics.MeanSquaredError()
88
+
89
+ self.config_llama = LlamaConfig(
90
+ max_position_embeddings=self.max_position_embeddings,
91
+ vocab_size=self.vocab_size,
92
+ hidden_size=self.hidden_size,
93
+ intermediate_size=self.intermediate_size,
94
+ num_hidden_layers=self.num_hidden_layers,
95
+ num_attention_heads=self.num_attention_heads,
96
+ attention_dropout=self.attention_dropout,
97
+ pad_token_id=self.pad_token_id,
98
+ bos_token_id=self.bos_token_id,
99
+ eos_token_id=self.eos_token_id,
100
+ )
101
+
102
+ self.loss_fn = nn.L1Loss()
103
+
104
+ self.llama = LlamaModel(self.config_llama)
105
+ self.gelu = nn.GELU()
106
+ self.score = nn.Linear(self.hidden_size, self.num_labels)
107
+
108
+ def forward(self, input_ids, attention_mask, labels=None):
109
+
110
+ transformer_outputs = self.llama(
111
+ input_ids=input_ids, attention_mask=attention_mask
112
+ )
113
+
114
+ hidden_states = transformer_outputs[0]
115
+ hidden_states = self.gelu(hidden_states)
116
+ logits = self.score(hidden_states)
117
+
118
+ if input_ids is not None:
119
+ batch_size = input_ids.shape[0]
120
+ else:
121
+ raise ValueError("input_ids is required; inputs_embeds is not part of this forward signature")
122
+
123
+ if self.config_llama.pad_token_id is None and batch_size != 1:
124
+ raise ValueError(
125
+ "Cannot handle batch sizes > 1 if no padding token is defined."
126
+ )
127
+ if self.config_llama.pad_token_id is None:
128
+ sequence_lengths = -1
129
+ else:
130
+ if input_ids is not None:
131
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
132
+ sequence_lengths = (
133
+ torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
134
+ - 1
135
+ )
136
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
137
+ sequence_lengths = sequence_lengths.to(logits.device)
138
+ else:
139
+ sequence_lengths = -1
140
+ # raise ValueError(len(sequence_lengths), sequence_lengths)
141
+
142
+ pooled_logits = logits[
143
+ torch.arange(batch_size, device=logits.device), sequence_lengths
144
+ ]
145
+ return pooled_logits
146
+
147
+
148
+ chemllama_mtr = ChemLlama()
149
+
150
+ class ChemLlama_FT(nn.Module):
151
+ def __init__(
152
+ self,
153
+ model_mtr,
154
+ linear_param:int=64,
155
+ use_freeze:bool=True,
156
+ *args, **kwargs
157
+ ):
158
+ super(ChemLlama_FT, self).__init__()
159
+ # self.save_hyperparameters()
160
+
161
+ self.model_mtr = model_mtr
162
+ if use_freeze:
163
+ for param in self.model_mtr.parameters(): param.requires_grad = False  # plain nn.Module has no .freeze()
164
+ # for name, param in model_mtr.named_parameters():
165
+ # param.requires_grad = False
166
+ # print(name, param.requires_grad)
167
+
168
+ self.gelu = nn.GELU()
169
+ self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
170
+ self.linear2 = nn.Linear(linear_param, linear_param)
171
+ self.regression = nn.Linear(linear_param, 5)
172
+
173
+ self.loss_fn = nn.L1Loss()
174
+
175
+ def forward(self, input_ids, attention_mask, labels=None):
176
+ x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
177
+ x = self.gelu(x)
178
+ x = self.linear1(x)
179
+ x = self.gelu(x)
180
+ x = self.linear2(x)
181
+ x = self.gelu(x)
182
+ x = self.regression(x)
183
+
184
+ return x
185
+
186
+ chemllama_ft = ChemLlama_FT(model_mtr=chemllama_mtr)
187
+
188
 
189
  # I just reused our previous research code with some modifications.
190
  dir_main = "."
191
 
192
  max_seq_length = 512
193
 
 
199
 
200
  ## FT
201
 
 
202
  dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
203
  # name_model_ft = 'Solvent.pt'
204
  name_model_ft = f"{solute_or_solvent}.pt"
205
 
206
+ # # Load dataset for finetune
207
+ # batch_size_for_train = batch_size_pair[0]
208
+ # batch_size_for_valid = batch_size_pair[1]
209
+
210
+ # data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
211
+ # solute_or_solvent=solute_or_solvent,
212
+ # tokenizer=tokenizer,
213
+ # max_seq_length=max_length,
214
+ # batch_size_train=batch_size_for_train,
215
+ # batch_size_valid=batch_size_for_valid,
216
+ # # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
217
+ # num_device=num_workers,
218
+ # )
219
+
220
+ # data_module.prepare_data()
221
+ # data_module.setup()
222
+ # steps_per_epoch = len(data_module.test_dataloader())
223
+
224
+ # # Load model and optimizer for finetune
225
+ # learning_rate = lr
226
+
227
+
228
+ # model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
229
+
230
+
231
+ # model_ft = model_finetune_sl.CustomFinetuneModel(
232
+ # model_mtr=model_mtr,
233
+ # steps_per_epoch=steps_per_epoch,
234
+ # warmup_epochs=1,
235
+ # max_epochs=epochs,
236
+ # learning_rate=learning_rate,
237
+ # # dataset_dict=dataset_dict,
238
+ # use_freeze=use_freeze,
239
+ # )
240
+
241
+ # # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
242
+
243
+ # trainer = L.Trainer(
244
+ # default_root_dir=dir_model_ft_to_save,
245
+ # # profiler=profiler,
246
+ # # logger=csv_logger,
247
+ # accelerator='auto',
248
+ # devices='auto',
249
+ # # accelerator='gpu',
250
+ # # devices=[0],
251
+ # min_epochs=1,
252
+ # max_epochs=epochs,
253
+ # precision=32,
254
+ # # callbacks=[checkpoint_callback]
255
+ # )
256
+
257
+ device = 'cpu'
258
  # Predict
259
  local_model_ft = utils_sl.load_model_ft_with(
260
+ class_model_ft=chemllama_ft,
261
  dir_model_ft=dir_model_ft_to_save,
262
  name_model_ft=name_model_ft
263
+ ).to(device)
264
+
265
+ # result = trainer.predict(local_model_ft, data_module)
266
+ # result_pred = list()
267
+ # result_label = list()
268
+ # for bat in range(len(result)):
269
+ # result_pred.append(result[bat][0].squeeze())
270
+ # result_label.append(result[bat][1])
271
+
272
+ with open('./smiles_str.txt', 'r') as file:
273
+ smiles_str = file.readline()
274
+
275
+ dataset_test = datamodule_finetune_sl.CustomLlamaDatasetAbraham(
276
+ df=pd.DataFrame([smiles_str]),
277
+ tokenizer=tokenizer,
278
+ max_seq_length=max_length
279
  )
280
 
281
+ dataloader_test = DataLoader(dataset_test, shuffle=False, collate_fn=DataCollatorWithPadding(tokenizer))
282
+
283
+ list_predictions = []
284
+ local_model_ft.eval()
285
+ with torch.inference_mode():
286
+ for i, v_batch in enumerate(dataloader_test):
287
+ v_input_ids = v_batch['input_ids'].to(device)
288
+ v_attention_mask = v_batch['attention_mask'].to(device)
289
+ # v_y_labels = v_batch['labels'].to(device)
290
+ v_y_logits = local_model_ft(input_ids=v_input_ids, attention_mask=v_attention_mask)
291
+ list_predictions.append(v_y_logits[0][0].tolist())
292
 
293
+ st.write(list_predictions)
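
The pooled_logits indexing in the ChemLlama.forward above grabs, for each row, the logits at the last non-padding token. A minimal standalone sketch of that trick, with a made-up pad_token_id and random logits rather than the repository's tokenizer and model:

import torch

pad_token_id = 0
# toy batch: two right-padded sequences of length 6
input_ids = torch.tensor([[5, 8, 9, 0, 0, 0],
                          [7, 3, 2, 4, 1, 0]])
logits = torch.randn(2, 6, 105)  # [batch, seq_len, num_labels]

# index of the first pad token minus one == last real token;
# the modulo keeps the index valid when a row contains no padding at all
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled.shape)  # torch.Size([2, 105])
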
.ipynb_checkpoints/app_old-checkpoint.py ADDED
@@ -0,0 +1,153 @@
1
+
2
+ import streamlit as st
3
+ # from git import Repo
4
+
5
+ # Repo.clone_from('https://huggingface.co/ttmn/SolLlama-mtr', './SolLlama-mtr')
6
+ import subprocess
7
+
8
+ def git_clone(repo_url, destination_dir):
9
+ try:
10
+ subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
11
+ print("Cloning successful!")
12
+ except subprocess.CalledProcessError as e:
13
+ print("Cloning failed:", e)
14
+
15
+ # Example usage
16
+ repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
17
+ destination_dir = "./SolLlama-mtr"
18
+
19
+ git_clone(repo_url, destination_dir)
20
+
21
+ import sys
22
+ import os
23
+ import torch
24
+ import numpy as np
25
+ import pandas as pd
26
+ import warnings
27
+ import lightning as L
28
+ torch.set_float32_matmul_precision('high')
29
+ warnings.filterwarnings("ignore", module="pl_bolts")
30
+
31
+ sys.path.append( '../')
32
+
33
+ import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
34
+ import auto_evaluator_sl
35
+
36
+ torch.manual_seed(1004)
37
+ np.random.seed(1004)
38
+
39
+ smiles_str = st.text_area('Enter SMILES string')
40
+ file_path = './smiles_str.txt'
41
+
42
+ # Open the file in write mode ('w') and write the content
43
+ with open(file_path, 'w') as file:
44
+ file.write(smiles_str)
45
+
46
+ # smiles_str = "CC02"
47
+
48
+ ###
49
+ # solute_or_solvent = 'solute'
50
+ solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
51
+ ver_ft = 0 # version control for FT model & evaluation data # Or it will overwrite the models and results
52
+ batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10] # [train, valid(test)]
53
+ # since 'solute' has a very small dataset, I think 10 for train and 10 for valid(test) should be the maximum values.
54
+ lr = 0.0001
55
+ epochs = 7
56
+ use_freeze = False # Freeze the model or not # False means not freezing
57
+ overwrite_level_2 = True
58
+ ###
59
+ max_seq_length = 512
60
+ tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
61
+ max_seq_length=max_seq_length,
62
+ )
63
+ max_length = max_seq_length
64
+ num_workers = 2
65
+
66
+ # I just reused our previous research code with some modifications.
67
+ dir_main = "."
68
+ name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
69
+
70
+ dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"
71
+
72
+ max_seq_length = 512
73
+
74
+ tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
75
+ max_seq_length=max_seq_length,
76
+ )
77
+ max_length = max_seq_length
78
+ num_workers = 2
79
+
80
+ ## FT
81
+
82
+ ver_ft = 0
83
+ dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
84
+ # name_model_ft = 'Solvent.pt'
85
+ name_model_ft = f"{solute_or_solvent}.pt"
86
+
87
+ # Load dataset for finetune
88
+ batch_size_for_train = batch_size_pair[0]
89
+ batch_size_for_valid = batch_size_pair[1]
90
+
91
+ data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
92
+ solute_or_solvent=solute_or_solvent,
93
+ tokenizer=tokenizer,
94
+ max_seq_length=max_length,
95
+ batch_size_train=batch_size_for_train,
96
+ batch_size_valid=batch_size_for_valid,
97
+ # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
98
+ num_device=num_workers,
99
+ )
100
+
101
+ data_module.prepare_data()
102
+ data_module.setup()
103
+ steps_per_epoch = len(data_module.test_dataloader())
104
+
105
+ # Load model and optimizer for finetune
106
+ learning_rate = lr
107
+
108
+
109
+ model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
110
+
111
+
112
+ model_ft = model_finetune_sl.CustomFinetuneModel(
113
+ model_mtr=model_mtr,
114
+ steps_per_epoch=steps_per_epoch,
115
+ warmup_epochs=1,
116
+ max_epochs=epochs,
117
+ learning_rate=learning_rate,
118
+ # dataset_dict=dataset_dict,
119
+ use_freeze=use_freeze,
120
+ )
121
+
122
+ # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
123
+
124
+ trainer = L.Trainer(
125
+ default_root_dir=dir_model_ft_to_save,
126
+ # profiler=profiler,
127
+ # logger=csv_logger,
128
+ accelerator='auto',
129
+ devices='auto',
130
+ # accelerator='gpu',
131
+ # devices=[0],
132
+ min_epochs=1,
133
+ max_epochs=epochs,
134
+ precision=32,
135
+ # callbacks=[checkpoint_callback]
136
+ )
137
+
138
+
139
+ # Predict
140
+ local_model_ft = utils_sl.load_model_ft_with(
141
+ class_model_ft=model_ft,
142
+ dir_model_ft=dir_model_ft_to_save,
143
+ name_model_ft=name_model_ft
144
+ )
145
+
146
+ result = trainer.predict(local_model_ft, data_module)
147
+ result_pred = list()
148
+ result_label = list()
149
+ for bat in range(len(result)):
150
+ result_pred.append(result[bat][0].squeeze())
151
+ result_label.append(result[bat][1])
152
+
153
+ st.write(result_pred)
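
app_old re-clones the SolLlama-mtr repository on every run. A small guard like the following sketch (not part of the commit; git_clone_once is a hypothetical helper) would skip the clone when the destination already exists:

import os
import subprocess

def git_clone_once(repo_url: str, destination_dir: str) -> None:
    # skip the clone if the repo was already fetched on a previous run
    if os.path.isdir(os.path.join(destination_dir, '.git')):
        print(f"{destination_dir} already exists, skipping clone")
        return
    subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)

git_clone_once("https://huggingface.co/ttmn/SolLlama-mtr", "./SolLlama-mtr")
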
.ipynb_checkpoints/chemllama_mtr-checkpoint.py ADDED
@@ -0,0 +1,212 @@
1
+ import lightning as L
2
+ import torch
3
+ import torchmetrics
4
+
5
+ from torch import nn
6
+ from transformers import LlamaModel, LlamaConfig
7
+ from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
8
+
9
+ class ChemLlama(L.LightningModule):
10
+ def __init__(
11
+ self,
12
+ max_position_embeddings,
13
+ vocab_size,
14
+ pad_token_id,
15
+ bos_token_id,
16
+ eos_token_id,
17
+ steps_per_epoch=None, #
18
+ warmup_epochs=None, #
19
+ max_epochs=None, #
20
+ hidden_size=384,
21
+ intermediate_size=464,
22
+ num_labels=105,
23
+ attention_dropout=0.144,
24
+ num_hidden_layers=3,
25
+ num_attention_heads=12,
26
+ learning_rate=0.0001,
27
+ ):
28
+ super(ChemLlama, self).__init__()
29
+ self.save_hyperparameters()
30
+
31
+ self.hidden_size = hidden_size
32
+ self.intermediate_size = intermediate_size
33
+ self.num_labels = num_labels
34
+ self.vocab_size = vocab_size
35
+ self.pad_token_id = pad_token_id
36
+ self.bos_token_id = bos_token_id
37
+ self.eos_token_id = eos_token_id
38
+ self.steps_per_epoch = steps_per_epoch #
39
+ self.warmup_epochs = warmup_epochs #
40
+ self.max_epochs = max_epochs #
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.num_attention_heads = num_attention_heads
43
+ self.attention_dropout = attention_dropout
44
+ self.max_position_embeddings = max_position_embeddings
45
+ self.learning_rate = learning_rate
46
+
47
+ self.mae = torchmetrics.MeanAbsoluteError()
48
+ self.mse = torchmetrics.MeanSquaredError()
49
+
50
+ self.config_llama = LlamaConfig(
51
+ max_position_embeddings=self.max_position_embeddings,
52
+ vocab_size=self.vocab_size,
53
+ hidden_size=self.hidden_size,
54
+ intermediate_size=self.intermediate_size,
55
+ num_hidden_layers=self.num_hidden_layers,
56
+ num_attention_heads=self.num_attention_heads,
57
+ attention_dropout=self.attention_dropout,
58
+ pad_token_id=self.pad_token_id,
59
+ bos_token_id=self.bos_token_id,
60
+ eos_token_id=self.eos_token_id,
61
+ )
62
+
63
+ self.loss_fn = nn.L1Loss()
64
+
65
+ self.llama = LlamaModel(self.config_llama)
66
+ self.gelu = nn.GELU()
67
+ self.score = nn.Linear(self.hidden_size, self.num_labels)
68
+
69
+ def forward(self, input_ids, attention_mask, labels=None):
70
+
71
+ transformer_outputs = self.llama(
72
+ input_ids=input_ids, attention_mask=attention_mask
73
+ )
74
+
75
+ hidden_states = transformer_outputs[0]
76
+ hidden_states = self.gelu(hidden_states)
77
+ logits = self.score(hidden_states)
78
+
79
+ if input_ids is not None:
80
+ batch_size = input_ids.shape[0]
81
+ else:
82
+ raise ValueError("input_ids is required; inputs_embeds is not part of this forward signature")
83
+
84
+ if self.config_llama.pad_token_id is None and batch_size != 1:
85
+ raise ValueError(
86
+ "Cannot handle batch sizes > 1 if no padding token is defined."
87
+ )
88
+ if self.config_llama.pad_token_id is None:
89
+ sequence_lengths = -1
90
+ else:
91
+ if input_ids is not None:
92
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
93
+ sequence_lengths = (
94
+ torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
95
+ - 1
96
+ )
97
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
98
+ sequence_lengths = sequence_lengths.to(logits.device)
99
+ else:
100
+ sequence_lengths = -1
101
+ # raise ValueError(len(sequence_lengths), sequence_lengths)
102
+
103
+ pooled_logits = logits[
104
+ torch.arange(batch_size, device=logits.device), sequence_lengths
105
+ ]
106
+ return pooled_logits
107
+
108
+ def training_step(self, batch, batch_idx):
109
+
110
+ loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)
111
+
112
+ # mae = self.mae(logits, labels)
113
+ # mse = self.mse(logits, labels)
114
+ self.log_dict(
115
+ {
116
+ "train_loss": loss,
117
+ # "train_mae": mae,
118
+ # "train_mse": mse
119
+ },
120
+ on_step=True,
121
+ on_epoch=True,
122
+ prog_bar=True,
123
+ sync_dist=True,
124
+ # logger=True,
125
+ )
126
+ # on_step = True will use lots of computational resources
127
+
128
+ # return loss
129
+ return {"loss": loss, "logits": logits, "labels": labels}
130
+
131
+ def train_epoch_end(self, outputs):
132
+ # avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
133
+ scores = torch.cat([x["logits"] for x in outputs])
134
+ labels = torch.cat([x["labels"] for x in outputs])
135
+ self.log_dict(
136
+ {
137
+ "train_mae": self.mae(scores, labels),
138
+ "train_mse": self.mse(scores, labels)
139
+ },
140
+ on_step=True,
141
+ on_epoch=True,
142
+ prog_bar=True,
143
+ sync_dist=True,
144
+ )
145
+
146
+ def validation_step(self, batch, batch_idx):
147
+
148
+ loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)
149
+ # self.log("val_loss", loss)
150
+ self.log("val_loss", loss, sync_dist=True)
151
+ return loss
152
+
153
+ def test_step(self, batch, batch_idx):
154
+
155
+ loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)
156
+ # self.log("val_loss", loss)
157
+ self.log("test_loss", loss, sync_dist=True,)
158
+ return loss
159
+
160
+ def _common_step(self, batch, batch_idx):
161
+
162
+ logits = self.forward(
163
+ input_ids=batch["input_ids"].squeeze(),
164
+ attention_mask=batch["attention_mask"].squeeze(),
165
+ )
166
+
167
+ labels = batch["labels"].squeeze()
168
+ loss = self.loss_fn(logits, labels)
169
+
170
+ # print(f"logits : {logits.shape}")
171
+ # print(f"labels : {labels.shape}")
172
+
173
+ return loss, logits, labels
174
+
175
+ # def configure_optimizers(self): # Scheduler here too!
176
+ # # since configure_optimizers and the model are included in the same class, use self.parameters()
177
+ # return torch.optim.AdamW(
178
+ # params=self.parameters(),
179
+ # lr=self.learning_rate,
180
+ # betas=(0.9, 0.999),
181
+ # weight_decay=0.01,
182
+ # )
183
+
184
+ # # The below is for warm-up scheduler
185
+ # https://lightning.ai/forums/t/how-to-use-warmup-lr-cosineannealinglr-in-lightning/1980
186
+ # https://github.com/Lightning-AI/pytorch-lightning/issues/328
187
+ def configure_optimizers(self): # Scheduler here too!
188
+ # since configure_optimizers and the model are included in the same class, use self.parameters()
189
+ optimizer = torch.optim.AdamW(
190
+ params=self.parameters(),
191
+ lr=self.learning_rate,
192
+ betas=(0.9, 0.999),
193
+ weight_decay=0.01,
194
+ )
195
+
196
+ # "warmup_epochs //4 only not max_epochs" will work
197
+ scheduler = LinearWarmupCosineAnnealingLR(
198
+ optimizer,
199
+ warmup_epochs=self.warmup_epochs*self.steps_per_epoch // torch.cuda.device_count(), # // num_device in this case
200
+ max_epochs=self.max_epochs*self.steps_per_epoch // torch.cuda.device_count(),
201
+ )
202
+
203
+ return {
204
+ "optimizer": optimizer,
205
+ "lr_scheduler": {
206
+ "scheduler": scheduler,
207
+ "interval": "step",
208
+ "frequency": 1,
209
+ "reduce_on_plateau": False,
210
+ "monitor": "val_loss",
211
+ }
212
+ }
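
configure_optimizers above converts warmup_epochs and max_epochs into step counts because the scheduler dict uses interval "step". Outside Lightning, roughly the same warmup-then-cosine schedule can be sketched with plain torch schedulers standing in for pl_bolts' LinearWarmupCosineAnnealingLR; the model and step counts below are illustrative only:

import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = torch.nn.Linear(105, 5)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,
                              betas=(0.9, 0.999), weight_decay=0.01)

steps_per_epoch, warmup_epochs, max_epochs = 100, 1, 7  # illustrative numbers
warmup_steps = warmup_epochs * steps_per_epoch
total_steps = max_epochs * steps_per_epoch

scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-3, total_iters=warmup_steps),  # linear warmup
        CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps),    # cosine decay
    ],
    milestones=[warmup_steps],
)

for step in range(total_steps):  # stepping once per batch is what interval="step" amounts to
    optimizer.step()
    scheduler.step()
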
app.py CHANGED
@@ -33,6 +33,9 @@ sys.path.append( '../')
33
  import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
34
  import auto_evaluator_sl
35
 
36
  torch.manual_seed(1004)
37
  np.random.seed(1004)
38
 
@@ -48,26 +51,143 @@ with open(file_path, 'w') as file:
48
  ###
49
  # solute_or_solvent = 'solute'
50
  solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
51
- ver_ft = 0 # version control for FT model & evaluation data # Or it will overwrite the models and results
52
- batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10] # [train, valid(test)]
53
- # since 'solute' has a very small dataset, I think 10 for train and 10 for valid(test) should be the maximum values.
54
- lr = 0.0001
55
- epochs = 7
56
- use_freeze = False # Freeze the model or not # False means not freezing
57
- overwrite_level_2 = True
58
- ###
59
- max_seq_length = 512
60
- tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
61
- max_seq_length=max_seq_length,
62
- )
63
- max_length = max_seq_length
64
- num_workers = 2
65
 
66
  # I just reused our previous research code with some modifications.
67
  dir_main = "."
68
- name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
69
-
70
- dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"
71
 
72
  max_seq_length = 512
73
 
@@ -79,75 +199,95 @@ num_workers = 2
79
 
80
  ## FT
81
 
82
- ver_ft = 0
83
  dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
84
  # name_model_ft = 'Solvent.pt'
85
  name_model_ft = f"{solute_or_solvent}.pt"
86
 
87
- # Load dataset for finetune
88
- batch_size_for_train = batch_size_pair[0]
89
- batch_size_for_valid = batch_size_pair[1]
90
-
91
- data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
92
- solute_or_solvent=solute_or_solvent,
93
- tokenizer=tokenizer,
94
- max_seq_length=max_length,
95
- batch_size_train=batch_size_for_train,
96
- batch_size_valid=batch_size_for_valid,
97
- # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
98
- num_device=num_workers,
99
- )
100
-
101
- data_module.prepare_data()
102
- data_module.setup()
103
- steps_per_epoch = len(data_module.test_dataloader())
104
-
105
- # Load model and optimizer for finetune
106
- learning_rate = lr
107
-
108
-
109
- model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
110
-
111
-
112
- model_ft = model_finetune_sl.CustomFinetuneModel(
113
- model_mtr=model_mtr,
114
- steps_per_epoch=steps_per_epoch,
115
- warmup_epochs=1,
116
- max_epochs=epochs,
117
- learning_rate=learning_rate,
118
- # dataset_dict=dataset_dict,
119
- use_freeze=use_freeze,
120
- )
121
-
122
- # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
123
-
124
- trainer = L.Trainer(
125
- default_root_dir=dir_model_ft_to_save,
126
- # profiler=profiler,
127
- # logger=csv_logger,
128
- accelerator='auto',
129
- devices='auto',
130
- # accelerator='gpu',
131
- # devices=[0],
132
- min_epochs=1,
133
- max_epochs=epochs,
134
- precision=32,
135
- # callbacks=[checkpoint_callback]
136
- )
137
-
138
-
139
  # Predict
140
  local_model_ft = utils_sl.load_model_ft_with(
141
- class_model_ft=model_ft,
142
  dir_model_ft=dir_model_ft_to_save,
143
  name_model_ft=name_model_ft
144
  )
145
 
146
- result = trainer.predict(local_model_ft, data_module)
147
- result_pred = list()
148
- result_label = list()
149
- for bat in range(len(result)):
150
- result_pred.append(result[bat][0].squeeze())
151
- result_label.append(result[bat][1])
152
 
153
- st.write(result_pred)
 
33
  import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
34
  import auto_evaluator_sl
35
 
36
+ from torch.utils.data import Dataset, DataLoader
37
+ from transformers import DataCollatorWithPadding, LlamaModel, LlamaConfig
+ from torch import nn
+ import torchmetrics  # nn, torchmetrics, LlamaModel, LlamaConfig are required by the ChemLlama classes inlined below
38
+
39
  torch.manual_seed(1004)
40
  np.random.seed(1004)
41
 
 
51
  ###
52
  # solute_or_solvent = 'solute'
53
  solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
54
+
55
+
56
+ class ChemLlama(nn.Module):
57
+ def __init__(
58
+ self,
59
+ max_position_embeddings=512,
60
+ vocab_size=591,
61
+ pad_token_id=0,
62
+ bos_token_id=12,
63
+ eos_token_id=13,
64
+ hidden_size=768,
65
+ intermediate_size=768,
66
+ num_labels=105,
67
+ attention_dropout=0.144,
68
+ num_hidden_layers=7,
69
+ num_attention_heads=8,
70
+ learning_rate=0.0001,
71
+ ):
72
+ super(ChemLlama, self).__init__()
73
+
74
+ self.hidden_size = hidden_size
75
+ self.intermediate_size = intermediate_size
76
+ self.num_labels = num_labels
77
+ self.vocab_size = vocab_size
78
+ self.pad_token_id = pad_token_id
79
+ self.bos_token_id = bos_token_id
80
+ self.eos_token_id = eos_token_id
81
+ self.num_hidden_layers = num_hidden_layers
82
+ self.num_attention_heads = num_attention_heads
83
+ self.attention_dropout = attention_dropout
84
+ self.max_position_embeddings = max_position_embeddings
85
+
86
+ self.mae = torchmetrics.MeanAbsoluteError()
87
+ self.mse = torchmetrics.MeanSquaredError()
88
+
89
+ self.config_llama = LlamaConfig(
90
+ max_position_embeddings=self.max_position_embeddings,
91
+ vocab_size=self.vocab_size,
92
+ hidden_size=self.hidden_size,
93
+ intermediate_size=self.intermediate_size,
94
+ num_hidden_layers=self.num_hidden_layers,
95
+ num_attention_heads=self.num_attention_heads,
96
+ attention_dropout=self.attention_dropout,
97
+ pad_token_id=self.pad_token_id,
98
+ bos_token_id=self.bos_token_id,
99
+ eos_token_id=self.eos_token_id,
100
+ )
101
+
102
+ self.loss_fn = nn.L1Loss()
103
+
104
+ self.llama = LlamaModel(self.config_llama)
105
+ self.gelu = nn.GELU()
106
+ self.score = nn.Linear(self.hidden_size, self.num_labels)
107
+
108
+ def forward(self, input_ids, attention_mask, labels=None):
109
+
110
+ transformer_outputs = self.llama(
111
+ input_ids=input_ids, attention_mask=attention_mask
112
+ )
113
+
114
+ hidden_states = transformer_outputs[0]
115
+ hidden_states = self.gelu(hidden_states)
116
+ logits = self.score(hidden_states)
117
+
118
+ if input_ids is not None:
119
+ batch_size = input_ids.shape[0]
120
+ else:
121
+ raise ValueError("input_ids is required; inputs_embeds is not part of this forward signature")
122
+
123
+ if self.config_llama.pad_token_id is None and batch_size != 1:
124
+ raise ValueError(
125
+ "Cannot handle batch sizes > 1 if no padding token is defined."
126
+ )
127
+ if self.config_llama.pad_token_id is None:
128
+ sequence_lengths = -1
129
+ else:
130
+ if input_ids is not None:
131
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
132
+ sequence_lengths = (
133
+ torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
134
+ - 1
135
+ )
136
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
137
+ sequence_lengths = sequence_lengths.to(logits.device)
138
+ else:
139
+ sequence_lengths = -1
140
+ # raise ValueError(len(sequence_lengths), sequence_lengths)
141
+
142
+ pooled_logits = logits[
143
+ torch.arange(batch_size, device=logits.device), sequence_lengths
144
+ ]
145
+ return pooled_logits
146
+
147
+
148
+ chemllama_mtr = ChemLlama()
149
+
150
+ class ChemLlama_FT(nn.Module):
151
+ def __init__(
152
+ self,
153
+ model_mtr,
154
+ linear_param:int=64,
155
+ use_freeze:bool=True,
156
+ *args, **kwargs
157
+ ):
158
+ super(ChemLlama_FT, self).__init__()
159
+ # self.save_hyperparameters()
160
+
161
+ self.model_mtr = model_mtr
162
+ if use_freeze:
163
+ for param in self.model_mtr.parameters(): param.requires_grad = False  # plain nn.Module has no .freeze()
164
+ # for name, param in model_mtr.named_parameters():
165
+ # param.requires_grad = False
166
+ # print(name, param.requires_grad)
167
+
168
+ self.gelu = nn.GELU()
169
+ self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
170
+ self.linear2 = nn.Linear(linear_param, linear_param)
171
+ self.regression = nn.Linear(linear_param, 5)
172
+
173
+ self.loss_fn = nn.L1Loss()
174
+
175
+ def forward(self, input_ids, attention_mask, labels=None):
176
+ x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
177
+ x = self.gelu(x)
178
+ x = self.linear1(x)
179
+ x = self.gelu(x)
180
+ x = self.linear2(x)
181
+ x = self.gelu(x)
182
+ x = self.regression(x)
183
+
184
+ return x
185
+
186
+ chemllama_ft = ChemLlama_FT(model_mtr=chemllama_mtr)
187
+
188
 
189
  # I just reused our previous research code with some modifications.
190
  dir_main = "."
191
 
192
  max_seq_length = 512
193
 
 
199
 
200
  ## FT
201
 
 
202
  dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
203
  # name_model_ft = 'Solvent.pt'
204
  name_model_ft = f"{solute_or_solvent}.pt"
205
 
206
+ # # Load dataset for finetune
207
+ # batch_size_for_train = batch_size_pair[0]
208
+ # batch_size_for_valid = batch_size_pair[1]
209
+
210
+ # data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
211
+ # solute_or_solvent=solute_or_solvent,
212
+ # tokenizer=tokenizer,
213
+ # max_seq_length=max_length,
214
+ # batch_size_train=batch_size_for_train,
215
+ # batch_size_valid=batch_size_for_valid,
216
+ # # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
217
+ # num_device=num_workers,
218
+ # )
219
+
220
+ # data_module.prepare_data()
221
+ # data_module.setup()
222
+ # steps_per_epoch = len(data_module.test_dataloader())
223
+
224
+ # # Load model and optimizer for finetune
225
+ # learning_rate = lr
226
+
227
+
228
+ # model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
229
+
230
+
231
+ # model_ft = model_finetune_sl.CustomFinetuneModel(
232
+ # model_mtr=model_mtr,
233
+ # steps_per_epoch=steps_per_epoch,
234
+ # warmup_epochs=1,
235
+ # max_epochs=epochs,
236
+ # learning_rate=learning_rate,
237
+ # # dataset_dict=dataset_dict,
238
+ # use_freeze=use_freeze,
239
+ # )
240
+
241
+ # # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
242
+
243
+ # trainer = L.Trainer(
244
+ # default_root_dir=dir_model_ft_to_save,
245
+ # # profiler=profiler,
246
+ # # logger=csv_logger,
247
+ # accelerator='auto',
248
+ # devices='auto',
249
+ # # accelerator='gpu',
250
+ # # devices=[0],
251
+ # min_epochs=1,
252
+ # max_epochs=epochs,
253
+ # precision=32,
254
+ # # callbacks=[checkpoint_callback]
255
+ # )
256
+
257
+ device = 'cpu'
258
  # Predict
259
  local_model_ft = utils_sl.load_model_ft_with(
260
+ class_model_ft=chemllama_ft,
261
  dir_model_ft=dir_model_ft_to_save,
262
  name_model_ft=name_model_ft
263
+ ).to(device)
264
+
265
+ # result = trainer.predict(local_model_ft, data_module)
266
+ # result_pred = list()
267
+ # result_label = list()
268
+ # for bat in range(len(result)):
269
+ # result_pred.append(result[bat][0].squeeze())
270
+ # result_label.append(result[bat][1])
271
+
272
+ with open('./smiles_str.txt', 'r') as file:
273
+ smiles_str = file.readline()
274
+
275
+ dataset_test = datamodule_finetune_sl.CustomLlamaDatasetAbraham(
276
+ df=pd.DataFrame([smiles_str]),
277
+ tokenizer=tokenizer,
278
+ max_seq_length=max_length
279
  )
280
 
281
+ dataloader_test = DataLoader(dataset_test, shuffle=False, collate_fn=DataCollatorWithPadding(tokenizer))
282
+
283
+ list_predictions = []
284
+ local_model_ft.eval()
285
+ with torch.inference_mode():
286
+ for i, v_batch in enumerate(dataloader_test):
287
+ v_input_ids = v_batch['input_ids'].to(device)
288
+ v_attention_mask = v_batch['attention_mask'].to(device)
289
+ # v_y_labels = v_batch['labels'].to(device)
290
+ v_y_logits = local_model_ft(input_ids=v_input_ids, attention_mask=v_attention_mask)
291
+ list_predictions.append(v_y_logits[0][0].tolist())
292
 
293
+ st.write(list_predictions)
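
The new inference path wraps a one-row DataFrame in a Dataset and lets DataCollatorWithPadding batch and pad the encodings. A self-contained sketch of that collation step, using a generic Hugging Face tokenizer as a stand-in for tokenizer_sl.fn_load_tokenizer_llama:

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer

# a tiny "dataset": one encoding per SMILES string
smiles = ["CCO", "c1ccccc1O"]
encodings = [tokenizer(s, truncation=True, max_length=512) for s in smiles]

# the collator pads each batch to its longest sequence and returns tensors
loader = DataLoader(encodings, batch_size=2, shuffle=False,
                    collate_fn=DataCollatorWithPadding(tokenizer))

with torch.inference_mode():
    for batch in loader:
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
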
app_old.py ADDED
@@ -0,0 +1,153 @@
1
+
2
+ import streamlit as st
3
+ # from git import Repo
4
+
5
+ # Repo.clone_from('https://huggingface.co/ttmn/SolLlama-mtr', './SolLlama-mtr')
6
+ import subprocess
7
+
8
+ def git_clone(repo_url, destination_dir):
9
+ try:
10
+ subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
11
+ print("Cloning successful!")
12
+ except subprocess.CalledProcessError as e:
13
+ print("Cloning failed:", e)
14
+
15
+ # Example usage
16
+ repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
17
+ destination_dir = "./SolLlama-mtr"
18
+
19
+ git_clone(repo_url, destination_dir)
20
+
21
+ import sys
22
+ import os
23
+ import torch
24
+ import numpy as np
25
+ import pandas as pd
26
+ import warnings
27
+ import lightning as L
28
+ torch.set_float32_matmul_precision('high')
29
+ warnings.filterwarnings("ignore", module="pl_bolts")
30
+
31
+ sys.path.append( '../')
32
+
33
+ import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
34
+ import auto_evaluator_sl
35
+
36
+ torch.manual_seed(1004)
37
+ np.random.seed(1004)
38
+
39
+ smiles_str = st.text_area('Enter SMILES string')
40
+ file_path = './smiles_str.txt'
41
+
42
+ # Open the file in write mode ('w') and write the content
43
+ with open(file_path, 'w') as file:
44
+ file.write(smiles_str)
45
+
46
+ # smiles_str = "CC02"
47
+
48
+ ###
49
+ # solute_or_solvent = 'solute'
50
+ solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
51
+ ver_ft = 0 # version control for FT model & evaluation data # Or it will overwrite the models and results
52
+ batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10] # [train, valid(test)]
53
+ # since 'solute' has a very small dataset, I think 10 for train and 10 for valid(test) should be the maximum values.
54
+ lr = 0.0001
55
+ epochs = 7
56
+ use_freeze = False # Freeze the model or not # False means not freezing
57
+ overwrite_level_2 = True
58
+ ###
59
+ max_seq_length = 512
60
+ tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
61
+ max_seq_length=max_seq_length,
62
+ )
63
+ max_length = max_seq_length
64
+ num_workers = 2
65
+
66
+ # I just reused our previous research code with some modifications.
67
+ dir_main = "."
68
+ name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
69
+
70
+ dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"
71
+
72
+ max_seq_length = 512
73
+
74
+ tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
75
+ max_seq_length=max_seq_length,
76
+ )
77
+ max_length = max_seq_length
78
+ num_workers = 2
79
+
80
+ ## FT
81
+
82
+ ver_ft = 0
83
+ dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
84
+ # name_model_ft = 'Solvent.pt'
85
+ name_model_ft = f"{solute_or_solvent}.pt"
86
+
87
+ # Load dataset for finetune
88
+ batch_size_for_train = batch_size_pair[0]
89
+ batch_size_for_valid = batch_size_pair[1]
90
+
91
+ data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
92
+ solute_or_solvent=solute_or_solvent,
93
+ tokenizer=tokenizer,
94
+ max_seq_length=max_length,
95
+ batch_size_train=batch_size_for_train,
96
+ batch_size_valid=batch_size_for_valid,
97
+ # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
98
+ num_device=num_workers,
99
+ )
100
+
101
+ data_module.prepare_data()
102
+ data_module.setup()
103
+ steps_per_epoch = len(data_module.test_dataloader())
104
+
105
+ # Load model and optimizer for finetune
106
+ learning_rate = lr
107
+
108
+
109
+ model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
110
+
111
+
112
+ model_ft = model_finetune_sl.CustomFinetuneModel(
113
+ model_mtr=model_mtr,
114
+ steps_per_epoch=steps_per_epoch,
115
+ warmup_epochs=1,
116
+ max_epochs=epochs,
117
+ learning_rate=learning_rate,
118
+ # dataset_dict=dataset_dict,
119
+ use_freeze=use_freeze,
120
+ )
121
+
122
+ # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
123
+
124
+ trainer = L.Trainer(
125
+ default_root_dir=dir_model_ft_to_save,
126
+ # profiler=profiler,
127
+ # logger=csv_logger,
128
+ accelerator='auto',
129
+ devices='auto',
130
+ # accelerator='gpu',
131
+ # devices=[0],
132
+ min_epochs=1,
133
+ max_epochs=epochs,
134
+ precision=32,
135
+ # callbacks=[checkpoint_callback]
136
+ )
137
+
138
+
139
+ # Predict
140
+ local_model_ft = utils_sl.load_model_ft_with(
141
+ class_model_ft=model_ft,
142
+ dir_model_ft=dir_model_ft_to_save,
143
+ name_model_ft=name_model_ft
144
+ )
145
+
146
+ result = trainer.predict(local_model_ft, data_module)
147
+ result_pred = list()
148
+ result_label = list()
149
+ for bat in range(len(result)):
150
+ result_pred.append(result[bat][0].squeeze())
151
+ result_label.append(result[bat][1])
152
+
153
+ st.write(result_pred)
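
In the old script, trainer.predict returns one (predictions, labels) pair per batch. A short post-processing sketch, with random tensors standing in for the real outputs, that flattens those per-batch results into single tensors:

import torch

# toy stand-in for trainer.predict output: one (predictions, labels) tuple per batch
result = [(torch.randn(10, 5), torch.randn(10, 5)) for _ in range(3)]

result_pred = torch.cat([pred.squeeze() for pred, _ in result])   # [30, 5]
result_label = torch.cat([label for _, label in result])          # [30, 5]
print(result_pred.shape, result_label.shape)
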