import streamlit as st

# from git import Repo
# Repo.clone_from('https://huggingface.co/ttmn/SolLlama-mtr', './SolLlama-mtr')
import subprocess


def git_clone(repo_url, destination_dir):
    try:
        subprocess.run(['git', 'clone', '-v', '--', repo_url, destination_dir], check=True)
        print("Cloning successful!")
    except subprocess.CalledProcessError as e:
        print("Cloning failed:", e)


# Clone the fine-tuned SolLlama weights from the Hugging Face Hub.
# (On Streamlit re-runs the destination directory already exists, so the clone
# fails and the error is simply logged by the except block above.)
repo_url = "https://huggingface.co/ttmn/SolLlama-mtr"
destination_dir = "./SolLlama-mtr"
git_clone(repo_url, destination_dir)

import sys
import os
import torch
from torch import nn
import torchmetrics
from transformers import LlamaModel, LlamaConfig
import numpy as np
import pandas as pd
import warnings
import lightning as L

torch.set_float32_matmul_precision('high')
warnings.filterwarnings("ignore", module="pl_bolts")

sys.path.append('../')
import tokenizer_sl, datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl
import auto_evaluator_sl
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding

torch.manual_seed(1004)
np.random.seed(1004)

st.title('Sol-LLaMA')

# file_path = './smiles_str.txt'
# # Open the file in write mode ('w') and write the content
# with open(file_path, 'w') as file:
#     file.write(smiles_str)
# smiles_str = "CC02"
###
# solute_or_solvent = 'solute'
solute_or_solvent = st.selectbox('Solute or Solvent', ['Solute', 'Solvent'])
if solute_or_solvent == 'Solute':
    smiles_str = st.text_area('Enter SMILES string', value='Clc1ccc(cc1)N(=O)=O')
elif solute_or_solvent == 'Solvent':
    smiles_str = st.text_area('Enter SMILES string', value='ClCCCl')


class ChemLlama(nn.Module):
    def __init__(
        self,
        max_position_embeddings=512,
        vocab_size=591,
        pad_token_id=0,
        bos_token_id=12,
        eos_token_id=13,
        hidden_size=768,
        intermediate_size=768,
        num_labels=105,
        attention_dropout=0.144,
        num_hidden_layers=7,
        num_attention_heads=8,
        learning_rate=0.0001,
    ):
        super(ChemLlama, self).__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_labels = num_labels
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings

        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError()

        self.config_llama = LlamaConfig(
            max_position_embeddings=self.max_position_embeddings,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_dropout=self.attention_dropout,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
        )

        self.loss_fn = nn.L1Loss()

        self.llama = LlamaModel(self.config_llama)
        self.gelu = nn.GELU()
        self.score = nn.Linear(self.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):

        transformer_outputs = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.gelu(hidden_states)
        logits = self.score(hidden_states)

        # input_ids is a required argument of this forward, so the batch size
        # can be taken directly from it.
        batch_size = input_ids.shape[0]

        if self.config_llama.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )

        if self.config_llama.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = (
                    torch.eq(input_ids, self.config_llama.pad_token_id).int().argmax(-1) - 1
                )
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                # raise ValueError(len(sequence_lengths), sequence_lengths)

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        return pooled_logits


# Note: this instance shadows the imported chemllama_mtr module; the module is
# only referenced in the commented-out training code further below.
chemllama_mtr = ChemLlama()


class ChemLlama_FT(nn.Module):
    def __init__(
        self,
        model_mtr,
        linear_param: int = 64,
        use_freeze: bool = True,
        *args,
        **kwargs
    ):
        super(ChemLlama_FT, self).__init__()
        # self.save_hyperparameters()

        self.model_mtr = model_mtr
        if use_freeze:
            # self.model_mtr.freeze()
            for name, param in model_mtr.named_parameters():
                param.requires_grad = False
                print(name, param.requires_grad)

        self.gelu = nn.GELU()
        self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
        self.linear2 = nn.Linear(linear_param, linear_param)
        self.regression = nn.Linear(linear_param, 5)

        self.loss_fn = nn.L1Loss()

    def forward(self, input_ids, attention_mask, labels=None):
        x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
        x = self.gelu(x)
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.gelu(x)
        x = self.regression(x)
        return x


chemllama_ft = ChemLlama_FT(model_mtr=chemllama_mtr)

# Reused from our previous research code with some modifications.
dir_main = "."
max_seq_length = 512
tokenizer = tokenizer_sl.fn_load_tokenizer_llama(
    max_seq_length=max_seq_length,
)
max_length = max_seq_length
num_workers = 2

## FT
dir_model_ft_to_save = f"{dir_main}/SolLlama-mtr"
# name_model_ft = 'Solvent.pt'
name_model_ft = f"{solute_or_solvent}.pt"

# # Load dataset for finetune
# batch_size_for_train = batch_size_pair[0]
# batch_size_for_valid = batch_size_pair[1]
# data_module = datamodule_finetune_sl.CustomFinetuneDataModule(
#     solute_or_solvent=solute_or_solvent,
#     tokenizer=tokenizer,
#     max_seq_length=max_length,
#     batch_size_train=batch_size_for_train,
#     batch_size_valid=batch_size_for_valid,
#     # num_device=int(config.NUM_DEVICE) * config.NUM_WORKERS_MULTIPLIER,
#     num_device=num_workers,
# )
# data_module.prepare_data()
# data_module.setup()
# steps_per_epoch = len(data_module.test_dataloader())

# # Load model and optimizer for finetune
# learning_rate = lr
# model_mtr = chemllama_mtr.ChemLlama.load_from_checkpoint(dir_model_mtr)
# model_ft = model_finetune_sl.CustomFinetuneModel(
#     model_mtr=model_mtr,
#     steps_per_epoch=steps_per_epoch,
#     warmup_epochs=1,
#     max_epochs=epochs,
#     learning_rate=learning_rate,
#     # dataset_dict=dataset_dict,
#     use_freeze=use_freeze,
# )
# # 'SolLlama_solute_vloss_val_loss=0.082_ep_epoch=06.pt'
# trainer = L.Trainer(
#     default_root_dir=dir_model_ft_to_save,
#     # profiler=profiler,
#     # logger=csv_logger,
#     accelerator='auto',
#     devices='auto',
#     # accelerator='gpu',
#     # devices=[0],
#     min_epochs=1,
#     max_epochs=epochs,
#     precision=32,
#     # callbacks=[checkpoint_callback]
# )

device = 'cpu'

# Predict
local_model_ft = utils_sl.load_model_ft_with(
    class_model_ft=chemllama_ft,
    dir_model_ft=dir_model_ft_to_save,
    name_model_ft=name_model_ft
).to(device)

# result = trainer.predict(local_model_ft, data_module)
# result_pred = list()
# result_label = list()
# for bat in range(len(result)):
#     result_pred.append(result[bat][0].squeeze())
#     result_label.append(result[bat][1])
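# Illustrative sketch only: utils_sl.load_model_ft_with is project-specific code
# whose internals are not shown here. Assuming the .pt file stores a plain
# state_dict (an assumption, not the actual implementation), an equivalent
# loader would look roughly like this:
#
#   state_dict = torch.load(
#       os.path.join(dir_model_ft_to_save, name_model_ft), map_location=device
#   )
#   chemllama_ft.load_state_dict(state_dict)
#   chemllama_ft.eval()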
# with open('./smiles_str.txt', 'r') as file:
#     smiles_str = file.readline()

if st.button('Predict'):
    # st.write(data)
    dataset_test = datamodule_finetune_sl.CustomLlamaDatasetAbraham(
        df=pd.DataFrame([smiles_str]),
        tokenizer=tokenizer,
        max_seq_length=max_length
    )
    st.subheader(f'Prediction - {solute_or_solvent}')

    dataloader_test = DataLoader(
        dataset_test, shuffle=False, collate_fn=DataCollatorWithPadding(tokenizer)
    )

    list_predictions = []
    local_model_ft.eval()
    with torch.inference_mode():
        for i, v_batch in enumerate(dataloader_test):
            v_input_ids = v_batch['input_ids'].to(device)
            v_attention_mask = v_batch['attention_mask'].to(device)
            # v_y_labels = v_batch['labels'].to(device)
            v_y_logits = local_model_ft(input_ids=v_input_ids, attention_mask=v_attention_mask)

            # list_predictions.append(v_y_logits[0][0].tolist())
            list_predictions.append(v_y_logits[0].tolist())

    # Prepend the input SMILES and show the five predicted Abraham parameters.
    list_predictions[0].insert(0, smiles_str)
    df = pd.DataFrame(list_predictions, columns=["SMILES", "E", "S", "A", "B", "V"])
    st.dataframe(df)

# https://docs.streamlit.io/develop/api-reference/data/st.dataframe
# https://docs.streamlit.io/get-started/tutorials/create-an-app
# https://www.datacamp.com/tutorial/streamlit
# https://huggingface.co/spaces/ttmn/chemllama-demo
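# To try the demo locally (a minimal sketch; the script name `app.py` is an
# assumption, and the project modules imported above -- tokenizer_sl,
# datamodule_finetune_sl, model_finetune_sl, chemllama_mtr, utils_sl,
# auto_evaluator_sl -- must be importable, e.g. via the sys.path.append('../')
# call at the top of this file):
#
#   pip install streamlit torch torchmetrics transformers numpy pandas lightning
#   streamlit run app.py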