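"""Gradio demo for CGPT (Chemical Generative Pretrained Transformer).

Loads a pretrained decoder-only transformer checkpoint and a RoBERTa-based
chemical tokenizer, then completes SMILES strings via greedy decoding.
"""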
import gradio as gr
import torch
from CGPT_utils import build_decoder_only_transformer
from transformers import RobertaTokenizerFast
config = {
    "batch_size": 8,
    "num_epochs": 5,
    "lr": 10**-4,
    "seq_len": 80,
    "d_model": 128,
    "src_lang": "ChEMBL_ID",
    "tgt_format": "SMILES",
    "model_folder": "weights",
    "model_basename": "tdmodel_",
    "preload": None,
    "tokenizer_file": "tokenizer_{0}.json",
    "experiment_name": "runs/tmodel",
    "SMILES dataset": "./data/train_dataset.csv",
    "validation dataset": "./data/test_dataset.csv",
    "decoder only": True,
}
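# Training-time configuration kept for reference; this demo only reads
# "seq_len" and "d_model" below.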
def causal_mask(size):
    # Upper-triangular mask of shape (1, size, size): True where a position may
    # attend (itself and earlier tokens), False for future tokens.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
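# For example, causal_mask(3) yields:
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])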
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len):
    # Token ids for the start- and end-of-sequence markers
    sos_idx = tokenizer_tgt.encode("<s>", add_special_tokens=False)[0]
    eos_idx = tokenizer_tgt.encode("</s>", add_special_tokens=False)[0]

    # Initialize the decoder input with the sos token followed by the source prompt
    decoder_input = torch.cat([torch.empty(1, 1).fill_(sos_idx).type_as(source), source], dim=1)
    while decoder_input.size(1) < max_len:
        # Build the causal mask for the current sequence length
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask)

        # Run the decoder and project the last position onto the vocabulary
        out = model.decode(decoder_input, decoder_mask)
        prob = model.project(out[:, -1])

        # Greedily append the most likely next token
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item())], dim=1
        )
        if next_word.item() == eos_idx:
            break
    return decoder_input.squeeze(0)
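# Note: generation runs with autograd tracking enabled; wrapping calls to
# greedy_decode in torch.no_grad() would skip building the graph and save
# memory at inference time.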
# Load the tokenizer saved in the chem_tokenizer directory
chem_tokenizer_dir = "chem_tokenizer"
chem_tokenizer = RobertaTokenizerFast.from_pretrained(chem_tokenizer_dir)

# Load the checkpoint on CPU and extract the model weights
model_checkpoint = torch.load("tdmodel_04.pt", map_location=torch.device('cpu'))
model_state_dict = model_checkpoint['model_state_dict']

# Build the model; 50265 is the RoBERTa tokenizer's vocabulary size
model = build_decoder_only_transformer(50265, config['seq_len'], config['d_model'])
model.load_state_dict(model_state_dict)
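# Switch to eval mode so any dropout layers are disabled during inference
# (the model supports load_state_dict, so it is a standard nn.Module and
# .eval() is available).
model.eval()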
def inference(input_sequence):
    # Tokenize the input SMILES and add a batch dimension
    input_tokens = chem_tokenizer.encode(input_sequence, add_special_tokens=False)
    max_len = 64
    input_tensor = torch.tensor(input_tokens).unsqueeze(0)
    source_mask = causal_mask(input_tensor.size(1)).unsqueeze(0)

    # Autoregressively complete the sequence
    output_tensor = greedy_decode(model, input_tensor, source_mask, chem_tokenizer, max_len)

    # Decode the token ids and strip the special tokens
    output_sequence = chem_tokenizer.decode(output_tensor.cpu().numpy())
    output_sequence = output_sequence.replace("<s>", "").replace("</s>", "")
    return output_sequence
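# Example (output is hypothetical; it depends on the trained checkpoint):
#   inference("C1=CC=C(C=C1)C")  # -> a model-generated SMILES completion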
textinput = gr.components.Textbox(lines=1, label="Enter a SMILES", placeholder="C1=CC=C(C=C1)C")
textoutput = gr.components.Textbox()
examples = ["C1=CC=C(C=C1)C", "C1CC1C(=O)NC2=CC=CC(=C2)N", "CC(=O)OC1=CC=C"]

intf = gr.Interface(
    fn=inference,
    inputs=textinput,
    outputs=textoutput,
    examples=examples,
    title="CGPT (Chemical Generative Pretrained Transformer)",
    description=(
        "This model is a decoder-only transformer trained on the entire PubMed dataset. "
        "It completes a SMILES string from a given input prefix. "
        "d_model is 128, the sequence length is 80, and each of the 6 blocks has "
        "8 attention heads, for a total of 16,473,203 parameters."
    ),
)
intf.launch(inline=False)