|
|
|
import streamlit as st |
|
|
|
|
|
|
|
import subprocess |
|
|
|
def git_clone(repo_url, destination_dir):
    """Clone *repo_url* into *destination_dir* unless it is already populated.

    Streamlit re-executes the entire script on every widget interaction, so
    an unconditional ``git clone`` would fail on every rerun once the target
    directory exists.  Prints progress to stdout (visible in the server log).

    Returns:
        bool: True when the repository is available (freshly cloned or
        already present), False when cloning failed.
    """
    import os

    # Skip the clone when a previous run already populated the directory.
    if os.path.isdir(destination_dir) and os.listdir(destination_dir):
        print(f"Skipping clone: '{destination_dir}' already exists.")
        return True
    try:
        # '--' prevents a repo URL from being parsed as a git option.
        subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
        print("Cloning successful!")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the git binary itself is missing from PATH.
        print("Cloning failed:", e)
        return False
|
|
|
|
|
# Hugging Face model repo holding the pre-trained MTR checkpoint and the
# fine-tuned per-target weights used later in this script.
repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
destination_dir = "./SolLlama-mtr"

# Fetch the checkpoint repository before anything below tries to load it.
git_clone(repo_url, destination_dir)
|
|
|
import sys |
|
import os |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import warnings |
|
import lightning as L |
|
# Allow TF32 matmuls on Ampere+ GPUs (speed over strict fp32 precision).
torch.set_float32_matmul_precision('high')
# pl_bolts emits noisy deprecation warnings; silence only that module.
warnings.filterwarnings("ignore", module="pl_bolts")

# Project modules live one directory up; this must run BEFORE the imports
# below, so the import block cannot be hoisted to the top of the file.
sys.path.append( '../')

import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
import auto_evaluator_sl

# Fix both torch and numpy RNGs for reproducible predictions.
torch.manual_seed(1004)
np.random.seed(1004)
|
|
|
# Free-text SMILES input; st.text_area returns "" until the user types,
# so an empty file is written on the first render.
smiles_str = st.text_area('Enter SMILE string')
file_path = './smiles_str.txt'

# Persist the input for the downstream data module, which reads it from
# disk.  Explicit UTF-8 avoids the platform-dependent default encoding.
with open(file_path, 'w', encoding='utf-8') as file:
    file.write(smiles_str)
|
|
|
|
|
|
|
|
|
|
|
# Which property target to predict; this drives both the fine-tuned
# checkpoint name ('Solute.pt' / 'Solvent.pt') and the batch sizes.
solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
ver_ft = 0

# (train, valid) batch sizes per target.
if solute_or_solvent == 'Solute':
    batch_size_pair = [64, 64]
else:
    batch_size_pair = [10, 10]

# Fine-tuning hyper-parameters (only epochs/lr/use_freeze are consumed below).
lr = 0.0001
epochs = 7
use_freeze = False
overwrite_level_2 = True

# Tokenizer shared by the pre-trained backbone and the fine-tuned head.
max_seq_length = 512
tokenizer = tokenizer_sl.fn_load_tokenizer_llama(max_seq_length=max_seq_length)
max_length = max_seq_length
num_workers = 2
|
|
|
|
|
# Location of the pre-trained multitask-regression (MTR) checkpoint inside
# the repository cloned at startup.
dir_main = "."
name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"

# NOTE: a verbatim duplicate of the tokenizer setup (max_seq_length,
# tokenizer, max_length, num_workers — identical values) used to sit here;
# the earlier assignments already provide those names.
|
|
|
|
|
|
|
# Fine-tuned weights live in the cloned repo, one .pt file per target.
# NOTE(review): ver_ft is assigned here (and earlier) but never read.
ver_ft = 0
dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
name_model_ft = f"{solute_or_solvent}.pt"

# Unpack the (train, valid) batch-size pair chosen with the target above.
batch_size_for_train, batch_size_for_valid = batch_size_pair
|
|
|
# Build the fine-tune data module around the tokenizer and batch sizes
# selected above.  It is expected to read the SMILES string persisted to
# ./smiles_str.txt earlier in this script — TODO confirm in
# datamodule_finetune_sl.
data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
    solute_or_solvent=solute_or_solvent,
    tokenizer=tokenizer,
    max_seq_length=max_length,
    batch_size_train=batch_size_for_train,
    batch_size_valid=batch_size_for_valid,
    num_device=num_workers,
)

data_module.prepare_data()
data_module.setup()
# NOTE(review): steps_per_epoch is derived from the *test* dataloader,
# presumably the prediction set used by trainer.predict below — confirm
# this matches what CustomFinetuneModel's scheduler expects.
steps_per_epoch = len(data_module.test_dataloader())
|
|
|
|
|
# Optimiser learning rate forwarded to the fine-tune wrapper.
learning_rate = lr

# Restore the pre-trained ChemLlama MTR backbone from the cloned checkpoint.
model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)

# Wrap the backbone with the regression fine-tuning head.  Only predict()
# is called in this script, so the scheduler-related arguments
# (steps_per_epoch, warmup_epochs, max_epochs) configure training machinery
# that is not exercised here.
model_ft = model_finetune_sl.CustomFinetuneModel(
    model_mtr=model_mtr,
    steps_per_epoch=steps_per_epoch,
    warmup_epochs=1,
    max_epochs=epochs,
    learning_rate=learning_rate,
    use_freeze=use_freeze,
)
|
|
|
|
|
|
|
# Lightning trainer; in this script it is only used for .predict() below,
# so the epoch bounds are inert here and appear to be carried over from a
# training script.
trainer = L.Trainer(
    default_root_dir=dir_model_ft_to_save,
    # Let Lightning pick GPU/CPU and device count automatically.
    accelerator='auto',
    devices='auto',
    min_epochs=1,
    max_epochs=epochs,
    # Full fp32 inference (TF32 matmuls still allowed via the global flag).
    precision=32,
)
|
|
|
|
|
|
|
# Load the fine-tuned weights (<Solute|Solvent>.pt) into the wrapper model.
local_model_ft = utils_sl.load_model_ft_with(
    class_model_ft=model_ft,
    dir_model_ft=dir_model_ft_to_save,
    name_model_ft=name_model_ft
)

# Run inference; predict() yields one (prediction, label) pair per batch.
result = trainer.predict(local_model_ft, data_module)

# Flatten the per-batch outputs.  Labels are collected for parity with the
# evaluation scripts, but only the predictions are displayed below.
result_pred = [batch_out[0].squeeze() for batch_out in result]
result_label = [batch_out[1] for batch_out in result]

st.write(result_pred)