# SolLlama / app.py
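# Streamlit demo for SolLlama: given a solute or solvent SMILES string, the app
# predicts five solvation-related descriptors (E, S, A, B, V) with a fine-tuned
# LLaMA-based multitask regression (MTR) model.
# To run locally (assuming the dependencies and the local *_sl modules are present):
#   streamlit run app.py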
import streamlit as st
# from git import Repo
# Repo.clone_from('https://huggingface.co/ttmn/SolLlama-mtr', './SolLlama-mtr')
import subprocess
def git_clone(repo_url, destination_dir):
    try:
        subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
        print("Cloning successful!")
    except subprocess.CalledProcessError as e:
        print("Cloning failed:", e)
# Clone the SolLlama-mtr model repository at app startup; it holds the fine-tuned checkpoints loaded below
repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
destination_dir = "./SolLlama-mtr"
git_clone(repo_url, destination_dir)
import sys
import os
import torch
from torch import nn
import torchmetrics
from transformers import LlamaModel, LlamaConfig
import numpy as np
import pandas as pd
import warnings
import lightning as L
torch.set_float32_matmul_precision('high')
warnings.filterwarnings("ignore", module="pl_bolts")
sys.path.append('../')
import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
import auto_evaluator_sl
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding
torch.manual_seed(1004)
np.random.seed(1004)
st.title('Sol-LLaMA')
# file_path = './smiles_str.txt'
# # Open the file in write mode ('w') and write the content
# with open(file_path, 'w') as file:
# file.write(smiles_str)
# smiles_str = "CC02"
###
# solute_or_solvent = 'solute'
solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
if solute_or_solvent == 'Solute':
    smiles_str = st.text_area('Enter SMILES string', value='Clc1ccc(cc1)N(=O)=O')
elif solute_or_solvent == 'Solvent':
    smiles_str = st.text_area('Enter SMILES string', value='ClCCCl')
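# Default inputs: Clc1ccc(cc1)N(=O)=O is 1-chloro-4-nitrobenzene (solute) and
# ClCCCl is 1,2-dichloroethane (solvent).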
class ChemLlama(nn.Module):
    def __init__(
        self,
        max_position_embeddings=512,
        vocab_size=591,
        pad_token_id=0,
        bos_token_id=12,
        eos_token_id=13,
        hidden_size=768,
        intermediate_size=768,
        num_labels=105,
        attention_dropout=0.144,
        num_hidden_layers=7,
        num_attention_heads=8,
        learning_rate=0.0001,
    ):
        super(ChemLlama, self).__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_labels = num_labels
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings

        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError()

        self.config_llama = LlamaConfig(
            max_position_embeddings=self.max_position_embeddings,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_dropout=self.attention_dropout,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
        )

        self.loss_fn = nn.L1Loss()

        self.llama = LlamaModel(self.config_llama)
        self.gelu = nn.GELU()
        self.score = nn.Linear(self.hidden_size, self.num_labels)
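    # forward() pools the logits at the last non-padding token of each sequence,
    # the same pooling scheme used by Hugging Face sequence-classification heads.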
    def forward(self, input_ids, attention_mask, labels=None):
        transformer_outputs = self.llama(
            input_ids=input_ids, attention_mask=attention_mask
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.gelu(hidden_states)
        logits = self.score(hidden_states)

        # input_ids is a required argument here, so the batch size always comes from it
        batch_size = input_ids.shape[0]

        if self.config_llama.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config_llama.pad_token_id is None:
            sequence_lengths = -1
        else:
            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
            sequence_lengths = (
                torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
                - 1
            )
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)

        # raise ValueError(len(sequence_lengths), sequence_lengths)
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        return pooled_logits
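# The MTR backbone is instantiated with random weights here; the trained weights
# presumably arrive when the fine-tuned state dict is loaded below via utils_sl.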
chemllama_mtr = ChemLlama()
class ChemLlama_FT(nn.Module):
    def __init__(
        self,
        model_mtr,
        linear_param: int = 64,
        use_freeze: bool = True,
        *args, **kwargs
    ):
        super(ChemLlama_FT, self).__init__()
        # self.save_hyperparameters()
        self.model_mtr = model_mtr

        if use_freeze:
            # self.model_mtr.freeze()
            for name, param in model_mtr.named_parameters():
                param.requires_grad = False
                print(name, param.requires_grad)

        self.gelu = nn.GELU()
        self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
        self.linear2 = nn.Linear(linear_param, linear_param)
        self.regression = nn.Linear(linear_param, 5)

        self.loss_fn = nn.L1Loss()

    def forward(self, input_ids, attention_mask, labels=None):
        x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
        x = self.gelu(x)
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.gelu(x)
        x = self.regression(x)
        return x
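# Fine-tuning head: the (frozen) MTR backbone feeds two GELU-activated linear
# layers and a final 5-output regression layer, one output per predicted descriptor.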
chemllama_ft = ChemLlama_FT(model_mtr=chemllama_mtr)
# I just reused our previous research code with some modifications.
dir_main = "."
max_seq_length = 512
tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
max_seq_length=max_seq_length,
)
max_length = max_seq_length
num_workers = 2
## FT
dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
# name_model_ft = 'Solvent.pt'
name_model_ft = f"{solute_or_solvent}.pt"
# # Load dataset for finetune
# batch_size_for_train = batch_size_pair[0]
# batch_size_for_valid = batch_size_pair[1]
# data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
# solute_or_solvent=solute_or_solvent,
# tokenizer=tokenizer,
# max_seq_length=max_length,
# batch_size_train=batch_size_for_train,
# batch_size_valid=batch_size_for_valid,
# # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
# num_device=num_workers,
# )
# data_module.prepare_data()
# data_module.setup()
# steps_per_epoch = len(data_module.test_dataloader())
# # Load model and optimizer for finetune
# learning_rate = lr
# model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
# model_ft = model_finetune_sl.CustomFinetuneModel(
# model_mtr=model_mtr,
# steps_per_epoch=steps_per_epoch,
# warmup_epochs=1,
# max_epochs=epochs,
# learning_rate=learning_rate,
# # dataset_dict=dataset_dict,
# use_freeze=use_freeze,
# )
# # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
# trainer = L.Trainer(
# default_root_dir=dir_model_ft_to_save,
# # profiler=profiler,
# # logger=csv_logger,
# accelerator='auto',
# devices='auto',
# # accelerator='gpu',
# # devices=[0],
# min_epochs=1,
# max_epochs=epochs,
# precision=32,
# # callbacks=[checkpoint_callback]
# )
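# Inference only: the commented-out block above is the original fine-tuning /
# training pipeline kept for reference; the deployed app just loads the saved
# checkpoint and runs predictions on CPU.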
device = 'cpu'
# Predict
local_model_ft = utils_sl.load_model_ft_with(
class_model_ft=chemllama_ft,
dir_model_ft=dir_model_ft_to_save,
name_model_ft=name_model_ft
).to(device)
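# utils_sl.load_model_ft_with is a project-specific helper; presumably it does
# something along these lines (a sketch, not the actual implementation):
#   state_dict = torch.load(os.path.join(dir_model_ft, name_model_ft), map_location='cpu')
#   class_model_ft.load_state_dict(state_dict)
#   return class_model_ft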
# result = trainer.predict(local_model_ft, data_module)
# result_pred = list()
# result_label = list()
# for bat in range(len(result)):
# result_pred.append(result[bat][0].squeeze())
# result_label.append(result[bat][1])
# with open('./smiles_str.txt', 'r') as file:
# smiles_str = file.readline()
if st.button('Predict'):
    # st.write(data)
    dataset_test = datamodule_finetune_sl.CustomLlamaDatasetAbraham(
        df=pd.DataFrame([smiles_str]),
        tokenizer=tokenizer,
        max_seq_length=max_length
    )

    st.subheader(f'Prediction - {solute_or_solvent}')

    dataloader_test = DataLoader(dataset_test, shuffle=False, collate_fn=DataCollatorWithPadding(tokenizer))

    list_predictions = []
    local_model_ft.eval()
    with torch.inference_mode():
        for i, v_batch in enumerate(dataloader_test):
            v_input_ids = v_batch['input_ids'].to(device)
            v_attention_mask = v_batch['attention_mask'].to(device)
            # v_y_labels = v_batch['labels'].to(device)
            v_y_logits = local_model_ft(input_ids=v_input_ids, attention_mask=v_attention_mask)
            # list_predictions.append(v_y_logits[0][0].tolist())
            list_predictions.append(v_y_logits[0].tolist())

    list_predictions[0].insert(0, smiles_str)
    df = pd.DataFrame(list_predictions, columns=["SMILES", "E", "S", "A", "B", "V"])
    st.dataframe(df)
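# Each row of the displayed table holds the input SMILES followed by the five
# predicted descriptors E, S, A, B and V (presumably the Abraham solvation
# parameters, matching the CustomLlamaDatasetAbraham dataset used above).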
# https://docs.streamlit.io/develop/api-reference/data/st.dataframe
# https://docs.streamlit.io/get-started/tutorials/create-an-app
# https://www.datacamp.com/tutorial/streamlit
# https://huggingface.co/spaces/ttmn/chemllama-demo