# SolLlama / app_old.py
import streamlit as st
# from git import Repo
# Repo.clone_from('https://huggingface.co/ttmn/SolLlama-mtr', './SolLlama-mtr')
import subprocess
def git_clone(repo_url, destination_dir):
    try:
        subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
        print("Cloning successful!")
    except subprocess.CalledProcessError as e:
        print("Cloning failed:", e)
# Example usage
repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
destination_dir = "./SolLlama-mtr"
git_clone(repo_url, destination_dir)
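# The cloned SolLlama-mtr repo provides both the pretrained MTR checkpoint and the
# finetuned <Solute|Solvent>.pt weights that are loaded further below.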
import sys
import os
import torch
import numpy as np
import pandas as pd
import warnings
import lightning as L
torch.set_float32_matmul_precision('high')
warnings.filterwarnings("ignore", module="pl_bolts")
sys.path.append('../')
import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
import auto_evaluator_sl
torch.manual_seed(1004)
np.random.seed(1004)
smiles_str = st.text_area('Enter SMILES string')
file_path = './smiles_str.txt'
# Open the file in write mode ('w') and write the content
with open(file_path, 'w') as file:
    file.write(smiles_str)
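# The SMILES string is written to smiles_str.txt, which the prediction pipeline below
# presumably reads. Added guard (not in the original script): stop the Streamlit run
# here when no SMILES has been entered yet, so the model loading and prediction steps
# below are skipped.
if not smiles_str:
    st.stop()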
# smiles_str = "CC02"
###
# solute_or_solvent = 'solute'
solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
ver_ft = 0  # version tag for the FT model & evaluation data; otherwise existing models and results would be overwritten
batch_size_pair = [64, 64] if solute_or_solvent == 'Solute' else [10, 10]  # [train, valid(test)]
# since the 'solute' dataset is very small, I think 10 for train and 10 for valid(test) should be the maximum values
lr = 0.0001
epochs = 7
use_freeze = False  # whether to freeze the model or not; False means not freezing
overwrite_level_2 = True
###
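# SMILES tokenizer; the 512-token max_seq_length is reused as max_length for the datamodule below.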
max_seq_length = 512
tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
    max_seq_length=max_seq_length,
)
max_length = max_seq_length
num_workers = 2
# I just reused our previous research code with some modifications.
dir_main = "."
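# Pretrained ChemLlama MTR backbone checkpoint, pulled in by the git clone at the top of the script.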
name_model_mtr = "ChemLlama_Medium_30m_vloss_val_loss=0.029_ep_epoch=04.ckpt"
dir_model_mtr = f"{dir_main}/SolLlama-mtr/{name_model_mtr}"
## FT
ver_ft = 0
dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
# name_model_ft = 'Solvent.pt'
name_model_ft = f"{solute_or_solvent}.pt"
# Load dataset for finetune
batch_size_for_train = batch_size_pair[0]
batch_size_for_valid = batch_size_pair[1]
data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
    solute_or_solvent=solute_or_solvent,
    tokenizer=tokenizer,
    max_seq_length=max_length,
    batch_size_train=batch_size_for_train,
    batch_size_valid=batch_size_for_valid,
    # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
    num_device=num_workers,
)
data_module.prepare_data()
data_module.setup()
steps_per_epoch = len(data_module.test_dataloader())
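# steps_per_epoch is taken from the test dataloader since only prediction is run here;
# it is presumably consumed by the finetune model's LR scheduler (see warmup_epochs below).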
# Load model and optimizer for finetune
learning_rate = lr
model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
model_ft = model_finetune_sl.CustomFinetuneModel(
    model_mtr=model_mtr,
    steps_per_epoch=steps_per_epoch,
    warmup_epochs=1,
    max_epochs=epochs,
    learning_rate=learning_rate,
    # dataset_dict=dataset_dict,
    use_freeze=use_freeze,
)
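# model_ft only defines the architecture and hyperparameters; the actual finetuned
# weights are presumably loaded into it below via utils_sl.load_model_ft_with.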
# 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
trainer = L.Trainer(
    default_root_dir=dir_model_ft_to_save,
    # profiler=profiler,
    # logger=csv_logger,
    accelerator='auto',
    devices='auto',
    # accelerator='gpu',
    # devices=[0],
    min_epochs=1,
    max_epochs=epochs,
    precision=32,
    # callbacks=[checkpoint_callback]
)
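# The Trainer is only used for inference here (trainer.predict below); no fit() call is made,
# so min_epochs / max_epochs have no effect in this script.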
# Predict
local_model_ft = utils_sl.load_model_ft_with(
    class_model_ft=model_ft,
    dir_model_ft=dir_model_ft_to_save,
    name_model_ft=name_model_ft
)
result = trainer.predict(local_model_ft, data_module)
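# Each element of `result` holds one batch: index 0 is the prediction tensor, index 1 the labels.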
result_pred = list()
result_label = list()
for bat in range(len(result)):
    result_pred.append(result[bat][0].squeeze())
    result_label.append(result[bat][1])
st.write(result_pred)