import os
import subprocess

import streamlit as st
|
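# Fetch the fine-tuned SolLlama weights: the Hugging Face model repo is
# cloned with plain git into ./SolLlama-mtr next to this script.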
def git_clone(repo_url, destination_dir):
    """Clone `repo_url` into `destination_dir`, reporting success or failure."""
    try:
        subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
        print("Cloning successful!")
    except subprocess.CalledProcessError as e:
        print("Cloning failed:", e)
|
repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
destination_dir = "./SolLlama-mtr"

# Streamlit re-executes this script on every interaction, so only clone once;
# re-cloning into an existing directory would fail on every rerun.
if not os.path.isdir(destination_dir):
    git_clone(repo_url, destination_dir)
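# Hedged alternative: if the `huggingface_hub` package is installed, the same
# files could be fetched without git (sketch, not used here):
#   from huggingface_hub import snapshot_download
#   snapshot_download(repo_id="ttmn/SolLlama-mtr", local_dir=destination_dir)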
|
import sys
import warnings

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchmetrics
import lightning as L
from transformers import LlamaModel, LlamaConfig, DataCollatorWithPadding

torch.set_float32_matmul_precision('high')
warnings.filterwarnings("ignore", module="pl_bolts")

# Local helper modules live one directory up.
sys.path.append('../')

import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
import auto_evaluator_sl

# Fix seeds so inference is reproducible across reruns.
torch.manual_seed(1004)
np.random.seed(1004)

st.title('Sol-LLaMA')
|
solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
if solute_or_solvent == 'Solute':
    smiles_str = st.text_area('Enter SMILES string', value='Clc1ccc(cc1)N(=O)=O')
elif solute_or_solvent == 'Solvent':
    smiles_str = st.text_area('Enter SMILES string', value='ClCCCl')
|
class ChemLlama(nn.Module):
    """LLaMA encoder with a linear head for multi-task regression (MTR)."""

    def __init__(
        self,
        max_position_embeddings=512,
        vocab_size=591,
        pad_token_id=0,
        bos_token_id=12,
        eos_token_id=13,
        hidden_size=768,
        intermediate_size=768,
        num_labels=105,
        attention_dropout=0.144,
        num_hidden_layers=7,
        num_attention_heads=8,
        learning_rate=0.0001,
    ):
        super(ChemLlama, self).__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_labels = num_labels
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings

        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError()

        self.config_llama = LlamaConfig(
            max_position_embeddings=self.max_position_embeddings,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_dropout=self.attention_dropout,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
        )

        self.loss_fn = nn.L1Loss()

        self.llama = LlamaModel(self.config_llama)
        self.gelu = nn.GELU()
        self.score = nn.Linear(self.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        transformer_outputs = self.llama(
            input_ids=input_ids, attention_mask=attention_mask
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.gelu(hidden_states)
        logits = self.score(hidden_states)

        # `input_ids` is required here, so the batch size comes straight from it.
        # (The original code also referenced an undefined `inputs_embeds`; that
        # dead branch is removed.)
        batch_size = input_ids.shape[0]

        if self.config_llama.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config_llama.pad_token_id is None:
            sequence_lengths = -1
        else:
            # Index of the last non-pad token in each sequence: the position
            # just before the first pad token (the modulo wraps sequences
            # that contain no padding at all back to the final position).
            sequence_lengths = (
                torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1)
                - 1
            )
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)

        # Pool by taking the logits at each sequence's last real token.
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]
        return pooled_logits


# Named `model_mtr` to avoid shadowing the imported `chemllama_mtr` module.
model_mtr = ChemLlama()
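# Minimal usage sketch (assumes the tokenizer loaded below behaves like a
# standard Hugging Face tokenizer, as DataCollatorWithPadding also expects):
#   enc = tokenizer(smiles_str, return_tensors='pt')
#   model_mtr(enc['input_ids'], enc['attention_mask'])  # shape (1, 105)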
|
class ChemLlama_FT(nn.Module):
    """Fine-tuning wrapper: maps the frozen MTR backbone's outputs to 5 targets."""

    def __init__(
        self,
        model_mtr,
        linear_param: int = 64,
        use_freeze: bool = True,
        *args, **kwargs
    ):
        super(ChemLlama_FT, self).__init__()

        self.model_mtr = model_mtr
        if use_freeze:
            # Freeze the pretrained backbone so only the head is trainable.
            for name, param in model_mtr.named_parameters():
                param.requires_grad = False

        self.gelu = nn.GELU()
        self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
        self.linear2 = nn.Linear(linear_param, linear_param)
        self.regression = nn.Linear(linear_param, 5)

        self.loss_fn = nn.L1Loss()

    def forward(self, input_ids, attention_mask, labels=None):
        x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
        x = self.gelu(x)
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.gelu(x)
        x = self.regression(x)
        return x


chemllama_ft = ChemLlama_FT(model_mtr=model_mtr)
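# Head shape flow: backbone output (num_labels=105) -> linear1 (64)
# -> linear2 (64) -> regression (5), one value per descriptor column shown
# in the results table below.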
|
dir_main = "."

max_seq_length = 512

tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
    max_seq_length=max_seq_length,
)
max_length = max_seq_length
num_workers = 2

dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
name_model_ft = f"{solute_or_solvent}.pt"
|
device = 'cpu'

local_model_ft = utils_sl.load_model_ft_with(
    class_model_ft=chemllama_ft,
    dir_model_ft=dir_model_ft_to_save,
    name_model_ft=name_model_ft
).to(device)
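# `load_model_ft_with` presumably restores the saved fine-tuned weights for
# the selected target from dir_model_ft_to_save/name_model_ft (an assumption
# based on its arguments; see utils_sl for the actual behavior).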
|
if st.button('Predict'):
    # The dataset class expects a DataFrame, so wrap the single SMILES string
    # in a one-row frame.
    dataset_test = datamodule_finetune_sl.CustomLlamaDatasetAbraham(
        df=pd.DataFrame([smiles_str]),
        tokenizer=tokenizer,
        max_seq_length=max_length
    )
    st.subheader(f'Prediction - {solute_or_solvent}')
    dataloader_test = DataLoader(dataset_test, shuffle=False, collate_fn=DataCollatorWithPadding(tokenizer))

    list_predictions = []
    local_model_ft.eval()
    with torch.inference_mode():
        for v_batch in dataloader_test:
            v_input_ids = v_batch['input_ids'].to(device)
            v_attention_mask = v_batch['attention_mask'].to(device)

            v_y_logits = local_model_ft(input_ids=v_input_ids, attention_mask=v_attention_mask)

            list_predictions.append(v_y_logits[0].tolist())

    # Prepend the input SMILES so the row reads: SMILES, then the five
    # predicted values.
    list_predictions[0].insert(0, smiles_str)
    df = pd.DataFrame(list_predictions, columns=["SMILES", "E", "S", "A", "B", "V"])

    st.dataframe(df)
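# The five outputs E, S, A, B, V appear to be Abraham solvation descriptors,
# consistent with the CustomLlamaDatasetAbraham dataset class used above.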