BrightBlueCheese committed
Commit 5b2887e
Parent(s): 33898af

app

Files changed:
- .ipynb_checkpoints/app-checkpoint.py  +218 -78
- .ipynb_checkpoints/app_old-checkpoint.py  +153 -0
- .ipynb_checkpoints/chemllama_mtr-checkpoint.py  +212 -0
- app.py  +218 -78
- app_old.py  +153 -0
.ipynb_checkpoints/app-checkpoint.py
CHANGED
(Jupyter checkpoint copy of app.py; the diff is identical to the app.py diff shown further below.)
.ipynb_checkpoints/app_old-checkpoint.py
ADDED
(Jupyter checkpoint copy of app_old.py; the added file is identical to the app_old.py listing shown further below.)
.ipynb_checkpoints/chemllama_mtr-checkpoint.py
ADDED
@@ -0,0 +1,212 @@
import lightning as L
import torch
import torchmetrics

from torch import nn
from transformers import LlamaModel, LlamaConfig
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

class ChemLlama(L.LightningModule):
    def __init__(
        self,
        max_position_embeddings,
        vocab_size,
        pad_token_id,
        bos_token_id,
        eos_token_id,
        steps_per_epoch=None, #
        warmup_epochs=None, #
        max_epochs=None, #
        hidden_size=384,
        intermediate_size=464,
        num_labels=105,
        attention_dropout=0.144,
        num_hidden_layers=3,
        num_attention_heads=12,
        learning_rate=0.0001,
    ):
        super(ChemLlama, self).__init__()
        self.save_hyperparameters()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_labels = num_labels
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.steps_per_epoch = steps_per_epoch #
        self.warmup_epochs = warmup_epochs #
        self.max_epochs = max_epochs #
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.learning_rate = learning_rate

        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError()

        self.config_llama = LlamaConfig(
            max_position_embeddings=self.max_position_embeddings,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_dropout=self.attention_dropout,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
        )

        self.loss_fn = nn.L1Loss()

        self.llama = LlamaModel(self.config_llama)
        self.gelu = nn.GELU()
        self.score = nn.Linear(self.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):

        transformer_outputs = self.llama(
            input_ids=input_ids, attention_mask=attention_mask
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.gelu(hidden_states)
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config_llama.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config_llama.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = (
                    torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
                    - 1
                )
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
        # raise ValueError(len(sequence_lengths), sequence_lengths)

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]
        return pooled_logits

    def training_step(self, batch, batch_idx):

        loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)

        # mae = self.mae(logits, labels)
        # mse = self.mse(logits, labels)
        self.log_dict(
            {
                "train_loss": loss,
                # "train_mae": mae,
                # "train_mse": mse
            },
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            # logger=True,
        )
        # on_step=True will use lots of computational resources

        # return loss
        return {"loss": loss, "logits": logits, "labels": labels}

    def train_epoch_end(self, outputs):
        # avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        scores = torch.cat([x["logits"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        self.log_dict(
            {
                "train_mae": self.mae(scores, labels),
                "train_mse": self.mse(scores, labels)
            },
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch, batch_idx):

        loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)
        # self.log("val_loss", loss)
        self.log("val_loss", loss, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):

        loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)
        # self.log("val_loss", loss)
        self.log("test_loss", loss, sync_dist=True,)
        return loss

    def _common_step(self, batch, batch_idx):

        logits = self.forward(
            input_ids=batch["input_ids"].squeeze(),
            attention_mask=batch["attention_mask"].squeeze(),
        )

        labels = batch["labels"].squeeze()
        loss = self.loss_fn(logits, labels)

        # print(f"logits : {logits.shape}")
        # print(f"labels : {labels.shape}")

        return loss, logits, labels

    # def configure_optimizers(self):  # Scheduler here too!
    #     # since configure_optimizers and the model are included in the same class.. self.parameters()
    #     return torch.optim.AdamW(
    #         params=self.parameters(),
    #         lr=self.learning_rate,
    #         betas=(0.9, 0.999),
    #         weight_decay=0.01,
    #     )

    # # The below is for warm-up scheduler
    # https://lightning.ai/forums/t/how-to-use-warmup-lr-cosineannealinglr-in-lightning/1980
    # https://github.com/Lightning-AI/pytorch-lightning/issues/328
    def configure_optimizers(self):  # Scheduler here too!
        # since configure_optimizers and the model are included in the same class.. self.parameters()
        optimizer = torch.optim.AdamW(
            params=self.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01,
        )

        # "warmup_epochs //4 only not max_epochs" will work
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs*self.steps_per_epoch // torch.cuda.device_count(), # // num_device in this case
            max_epochs=self.max_epochs*self.steps_per_epoch // torch.cuda.device_count(),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
                "reduce_on_plateau": False,
                "monitor": "val_loss",
            }
        }
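Because __init__ calls save_hyperparameters(), the constructor arguments are stored inside the .ckpt file and load_from_checkpoint can rebuild the model without re-passing them, which is what the app relied on before this commit. A minimal usage sketch, assuming the checkpoint path from the diffs above exists locally; the token ids below are placeholders, not real tokenizer_sl output:

import torch
from chemllama_mtr import ChemLlama  # the module listed above

# Restore the pretrained MTR model; the architecture is rebuilt from the
# hyperparameters saved in the checkpoint by save_hyperparameters().
ckpt_path = "./SolLlama-mtr/ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
model_mtr = ChemLlama.load_from_checkpoint(ckpt_path, map_location="cpu")
model_mtr.eval()

# Placeholder tokenized batch: one sequence, right-padded (pad id 0 assumed here).
input_ids = torch.tensor([[12, 45, 67, 89, 13, 0, 0, 0]])
attention_mask = (input_ids != 0).long()

with torch.inference_mode():
    pooled = model_mtr(input_ids=input_ids, attention_mask=attention_mask)

# forward() pools the logits at the last non-pad position, so this is one
# vector of num_labels regression targets for the molecule.
print(pooled.shape)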
app.py
CHANGED
@@ -33,6 +33,9 @@ sys.path.append( '../')
 import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
 import auto_evaluator_sl
 
+from torch.utils.data import Dataset, DataLoader
+from transformers import DataCollatorWithPadding
+
 torch.manual_seed(1004)
 np.random.seed(1004)
 
@@ -48,26 +51,143 @@ with open(file_path, 'w') as file:
 ###
 # solute_or_solvent = 'solute'
 solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
-ver_ft = 0 # version control for FT model & evaluation data # Or it will overwrite the models and results
-batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10] # [train, valid(test)]
-# since 'solute' has a very small dataset. So I think 10 for train and 10 for valid(test) should be the maximum values.
-lr = 0.0001
-epochs = 7
-use_freeze = False # Freeze the model or not # False means not freezing
-overwrite_level_2 = True
-###
-max_seq_length = 512
-tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
-    max_seq_length=max_seq_length,
-)
-max_length = max_seq_length
-num_workers = 2
+
+
+class ChemLlama(nn.Module):
+    def __init__(
+        self,
+        max_position_embeddings=512,
+        vocab_size=591,
+        pad_token_id=0,
+        bos_token_id=12,
+        eos_token_id=13,
+        hidden_size=768,
+        intermediate_size=768,
+        num_labels=105,
+        attention_dropout=0.144,
+        num_hidden_layers=7,
+        num_attention_heads=8,
+        learning_rate=0.0001,
+    ):
+        super(ChemLlama, self).__init__()
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_labels = num_labels
+        self.vocab_size = vocab_size
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.attention_dropout = attention_dropout
+        self.max_position_embeddings = max_position_embeddings
+
+        self.mae = torchmetrics.MeanAbsoluteError()
+        self.mse = torchmetrics.MeanSquaredError()
+
+        self.config_llama = LlamaConfig(
+            max_position_embeddings=self.max_position_embeddings,
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.intermediate_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            attention_dropout=self.attention_dropout,
+            pad_token_id=self.pad_token_id,
+            bos_token_id=self.bos_token_id,
+            eos_token_id=self.eos_token_id,
+        )
+
+        self.loss_fn = nn.L1Loss()
+
+        self.llama = LlamaModel(self.config_llama)
+        self.gelu = nn.GELU()
+        self.score = nn.Linear(self.hidden_size, self.num_labels)
+
+    def forward(self, input_ids, attention_mask, labels=None):
+
+        transformer_outputs = self.llama(
+            input_ids=input_ids, attention_mask=attention_mask
+        )
+
+        hidden_states = transformer_outputs[0]
+        hidden_states = self.gelu(hidden_states)
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config_llama.pad_token_id is None and batch_size != 1:
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
+        if self.config_llama.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = (
+                    torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
+                    - 1
+                )
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+        # raise ValueError(len(sequence_lengths), sequence_lengths)
+
+        pooled_logits = logits[
+            torch.arange(batch_size, device=logits.device), sequence_lengths
+        ]
+        return pooled_logits
+
+
+chemllama_mtr = ChemLlama()
+
+class ChemLlama_FT(nn.Module):
+    def __init__(
+        self,
+        model_mtr,
+        linear_param:int=64,
+        use_freeze:bool=True,
+        *args, **kwargs
+    ):
+        super(CustomFinetuneModel, self).__init__()
+        # self.save_hyperparameters()
+
+        self.model_mtr = model_mtr
+        if use_freeze:
+            self.model_mtr.freeze()
+        # for name, param in model_mtr.named_parameters():
+        #     param.requires_grad = False
+        #     print(name, param.requires_grad)
+
+        self.gelu = nn.GELU()
+        self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
+        self.linear2 = nn.Linear(linear_param, linear_param)
+        self.regression = nn.Linear(linear_param, 5)
+
+        self.loss_fn = nn.L1Loss()
+
+    def forward(self, input_ids, attention_mask, labels=None):
+        x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
+        x = self.gelu(x)
+        x = self.linear1(x)
+        x = self.gelu(x)
+        x = self.linear2(x)
+        x = self.gelu(x)
+        x = self.regression(x)
+
+        return x
+
+chemllama_ft = ChemLlama_FT(model_mtr=chemllama_mtr)
+
 
 # I just reused our previous research code with some modifications.
 dir_main = "."
-name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
-
-dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"
 
 max_seq_length = 512
 
@@ -79,75 +199,95 @@ num_workers = 2
 
 ## FT
 
-ver_ft = 0
 dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
 # name_model_ft = 'Solvent.pt'
 name_model_ft = f"{solute_or_solvent}.pt"
 
-# Load dataset for finetune
-batch_size_for_train = batch_size_pair[0]
-batch_size_for_valid = batch_size_pair[1]
-
-data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
-    solute_or_solvent=solute_or_solvent,
-    tokenizer=tokenizer,
-    max_seq_length=max_length,
-    batch_size_train=batch_size_for_train,
-    batch_size_valid=batch_size_for_valid,
-    # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
-    num_device=num_workers,
-)
-
-data_module.prepare_data()
-data_module.setup()
-steps_per_epoch = len(data_module.test_dataloader())
-
-# Load model and optimizer for finetune
-learning_rate = lr
-
-
-model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
-
-
-model_ft = model_finetune_sl.CustomFinetuneModel(
-    model_mtr=model_mtr,
-    steps_per_epoch=steps_per_epoch,
-    warmup_epochs=1,
-    max_epochs=epochs,
-    learning_rate=learning_rate,
-    # dataset_dict=dataset_dict,
-    use_freeze=use_freeze,
-)
-
-# 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
-
-trainer = L.Trainer(
-    default_root_dir=dir_model_ft_to_save,
-    # profiler=profiler,
-    # logger=csv_logger,
-    accelerator='auto',
-    devices='auto',
-    # accelerator='gpu',
-    # devices=[0],
-    min_epochs=1,
-    max_epochs=epochs,
-    precision=32,
-    # callbacks=[checkpoint_callback]
-)
-
-
+# # Load dataset for finetune
+# batch_size_for_train = batch_size_pair[0]
+# batch_size_for_valid = batch_size_pair[1]
+
+# data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
+#     solute_or_solvent=solute_or_solvent,
+#     tokenizer=tokenizer,
+#     max_seq_length=max_length,
+#     batch_size_train=batch_size_for_train,
+#     batch_size_valid=batch_size_for_valid,
+#     # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
+#     num_device=num_workers,
+# )
+
+# data_module.prepare_data()
+# data_module.setup()
+# steps_per_epoch = len(data_module.test_dataloader())
+
+# # Load model and optimizer for finetune
+# learning_rate = lr
+
+
+# model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
+
+
+# model_ft = model_finetune_sl.CustomFinetuneModel(
+#     model_mtr=model_mtr,
+#     steps_per_epoch=steps_per_epoch,
+#     warmup_epochs=1,
+#     max_epochs=epochs,
+#     learning_rate=learning_rate,
+#     # dataset_dict=dataset_dict,
+#     use_freeze=use_freeze,
+# )
+
+# # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
+
+# trainer = L.Trainer(
+#     default_root_dir=dir_model_ft_to_save,
+#     # profiler=profiler,
+#     # logger=csv_logger,
+#     accelerator='auto',
+#     devices='auto',
+#     # accelerator='gpu',
+#     # devices=[0],
+#     min_epochs=1,
+#     max_epochs=epochs,
+#     precision=32,
+#     # callbacks=[checkpoint_callback]
+# )
+
+device = 'cpu'
 # Predict
 local_model_ft = utils_sl.load_model_ft_with(
-    class_model_ft=model_ft,
+    class_model_ft=chemllama_ft,
     dir_model_ft=dir_model_ft_to_save,
     name_model_ft=name_model_ft
+).to(device)
+
+# result = trainer.predict(local_model_ft, data_module)
+# result_pred = list()
+# result_label = list()
+# for bat in range(len(result)):
+#     result_pred.append(result[bat][0].squeeze())
+#     result_label.append(result[bat][1])
+
+with open('./smiles_str.txt', 'r') as file:
+    smiles_str = file.readline()
+
+dataset_test = datamodule_finetune_sl.CustomLlamaDatasetAbraham(
+    df=pd.DataFrame([smiles_str]),
+    tokenizer=tokenizer,
+    max_seq_length=max_length
 )
 
-result = trainer.predict(local_model_ft, data_module)
-result_pred = list()
-result_label = list()
-for bat in range(len(result)):
-    result_pred.append(result[bat][0].squeeze())
-    result_label.append(result[bat][1])
+dataloader_test = DataLoader(dataset_test, shuffle=False, collate_fn=DataCollatorWithPadding(tokenizer))
+
+list_predictions = []
+local_model_ft.eval()
+with torch.inference_mode():
+    for i, v_batch in enumerate(dataloader_test):
+        v_input_ids = v_batch['input_ids'].to(device)
+        v_attention_mask = v_batch['attention_mask'].to(device)
+        # v_y_labels = v_batch['labels'].to(device)
+        v_y_logits = local_model_ft(input_ids=v_input_ids, attention_mask=v_attention_mask)
+        list_predictions.append(v_y_logits[0][0].tolist())
 
-st.write(result_pred)
+st.write(list_predictions)
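The new prediction path builds chemllama_ft in-process and hands it to utils_sl.load_model_ft_with, whose implementation is not part of this commit. A hypothetical stand-in, assuming the saved Solute.pt / Solvent.pt files hold a plain state_dict for the fine-tuned network (the function name and behavior below are illustrative guesses, not the repo's code):

import os
import torch

def load_model_ft_with_sketch(class_model_ft, dir_model_ft, name_model_ft):
    # Guessed equivalent of utils_sl.load_model_ft_with: load the saved weights
    # into the freshly constructed module and return it, mirroring how app.py
    # then chains .to(device) on the result.
    path = os.path.join(dir_model_ft, name_model_ft)   # e.g. ./SolLlama-mtr/Solute.pt
    state_dict = torch.load(path, map_location="cpu")
    class_model_ft.load_state_dict(state_dict)
    return class_model_ft

If the .pt files instead store a pickled module, the helper would presumably return torch.load(path) directly; the diff does not show which convention the repository uses.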
app_old.py
ADDED
@@ -0,0 +1,153 @@

import streamlit as st
# from git import Repo

# Repo.clone_from('https://huggingface.co/ttmn/SolLlama-mtr', './SolLlama-mtr')
import subprocess

def git_clone(repo_url, destination_dir):
    try:
        subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
        print("Cloning successful!")
    except subprocess.CalledProcessError as e:
        print("Cloning failed:", e)

# Example usage
repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
destination_dir = "./SolLlama-mtr"

git_clone(repo_url, destination_dir)

import sys
import os
import torch
import numpy as np
import pandas as pd
import warnings
import lightning as L
torch.set_float32_matmul_precision('high')
warnings.filterwarnings("ignore", module="pl_bolts")

sys.path.append( '../')

import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
import auto_evaluator_sl

torch.manual_seed(1004)
np.random.seed(1004)

smiles_str = st.text_area('Enter SMILE string')
file_path = './smiles_str.txt'

# Open the file in write mode ('w') and write the content
with open(file_path, 'w') as file:
    file.write(smiles_str)

# smiles_str = "CC02"

###
# solute_or_solvent = 'solute'
solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
ver_ft = 0 # version control for FT model & evaluation data # Or it will overwrite the models and results
batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10] # [train, valid(test)]
# since 'solute' has a very small dataset. So I think 10 for train and 10 for valid(test) should be the maximum values.
lr = 0.0001
epochs = 7
use_freeze = False # Freeze the model or not # False means not freezing
overwrite_level_2 = True
###
max_seq_length = 512
tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
    max_seq_length=max_seq_length,
)
max_length = max_seq_length
num_workers = 2

# I just reused our previous research code with some modifications.
dir_main = "."
name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"

dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"

max_seq_length = 512

tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
    max_seq_length=max_seq_length,
)
max_length = max_seq_length
num_workers = 2

## FT

ver_ft = 0
dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
# name_model_ft = 'Solvent.pt'
name_model_ft = f"{solute_or_solvent}.pt"

# Load dataset for finetune
batch_size_for_train = batch_size_pair[0]
batch_size_for_valid = batch_size_pair[1]

data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
    solute_or_solvent=solute_or_solvent,
    tokenizer=tokenizer,
    max_seq_length=max_length,
    batch_size_train=batch_size_for_train,
    batch_size_valid=batch_size_for_valid,
    # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
    num_device=num_workers,
)

data_module.prepare_data()
data_module.setup()
steps_per_epoch = len(data_module.test_dataloader())

# Load model and optimizer for finetune
learning_rate = lr


model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)


model_ft = model_finetune_sl.CustomFinetuneModel(
    model_mtr=model_mtr,
    steps_per_epoch=steps_per_epoch,
    warmup_epochs=1,
    max_epochs=epochs,
    learning_rate=learning_rate,
    # dataset_dict=dataset_dict,
    use_freeze=use_freeze,
)

# 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'

trainer = L.Trainer(
    default_root_dir=dir_model_ft_to_save,
    # profiler=profiler,
    # logger=csv_logger,
    accelerator='auto',
    devices='auto',
    # accelerator='gpu',
    # devices=[0],
    min_epochs=1,
    max_epochs=epochs,
    precision=32,
    # callbacks=[checkpoint_callback]
)


# Predict
local_model_ft = utils_sl.load_model_ft_with(
    class_model_ft=model_ft,
    dir_model_ft=dir_model_ft_to_save,
    name_model_ft=name_model_ft
)

result = trainer.predict(local_model_ft, data_module)
result_pred = list()
result_label = list()
for bat in range(len(result)):
    result_pred.append(result[bat][0].squeeze())
    result_label.append(result[bat][1])

st.write(result_pred)