import gradio as gr
import torch
from CGPT_utils import build_decoder_only_transformer
from transformers import RobertaTokenizerFast

config = {
    "batch_size": 8,
    "num_epochs": 5,
    "lr": 1e-4,
    "seq_len": 80,
    "d_model": 128,
    "src_lang": "ChEMBL_ID",
    "tgt_format": "SMILES",
    "model_folder": "weights",
    "model_basename": "tdmodel_",
    "preload": None,
    "tokenizer_file": "tokenizer_{0}.json",
    "experiment_name": "runs/tmodel",
    "SMILES dataset": './data/train_dataset.csv',
    "validation dataset": './data/test_dataset.csv',
    "decoder only": True,
}
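
# Note: only "seq_len" and "d_model" are used by this inference script; the
# other entries appear to mirror the original training configuration.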

def causal_mask(size):
    # Boolean mask that lets each position attend to itself and earlier positions only
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
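
# For illustration, causal_mask(3) produces:
#   tensor([[[ True, False, False],
#            [ True,  True, False],
#            [ True,  True,  True]]])
# i.e. token i may attend to tokens 0..i.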


def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len):
    sos_idx = tokenizer_tgt.encode("<s>", add_special_tokens=False)[0]
    eos_idx = tokenizer_tgt.encode("</s>", add_special_tokens=False)[0]

    # Start the sequence with the <s> token followed by the source prompt
    decoder_input = torch.cat([torch.empty(1, 1).fill_(sos_idx).type_as(source), source], dim=1)
    while decoder_input.size(1) < max_len:
        # Build the causal mask for the tokens generated so far
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask)

        # Run the decoder over the current sequence
        out = model.decode(decoder_input, decoder_mask)

        # Pick the most likely next token (greedy search)
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item())], dim=1
        )

        # Stop as soon as the end-of-sequence token is produced
        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)



# Directory where the chemistry tokenizer is saved
chem_tokenizer_dir = "chem_tokenizer"
chem_tokenizer = RobertaTokenizerFast.from_pretrained(chem_tokenizer_dir)

# Load the trained checkpoint (CPU is sufficient for inference)
model_checkpoint = torch.load("tdmodel_04.pt", map_location=torch.device('cpu'))
model_state_dict = model_checkpoint['model_state_dict']

# Build the model and restore the trained weights
model = build_decoder_only_transformer(50265, config['seq_len'], config['d_model'])
model.load_state_dict(model_state_dict)
model.eval()  # inference only: disable dropout


def inference(input_sequence):
    """Complete the given SMILES string with greedy decoding."""
    # Tokenize the input SMILES prompt
    input_tokens = chem_tokenizer.encode(input_sequence, add_special_tokens=False)

    max_len = 64
    input_tensor = torch.tensor(input_tokens).unsqueeze(0)  # add a batch dimension
    source_mask = causal_mask(input_tensor.size(1)).unsqueeze(0)  # causal mask over the prompt

    # Autoregressively extend the prompt with the model
    with torch.no_grad():
        output_tensor = greedy_decode(model, input_tensor, source_mask, chem_tokenizer, max_len)

    # Decode the token ids back into text and strip the special tokens
    output_sequence = chem_tokenizer.decode(output_tensor.cpu().numpy())
    output_sequence = output_sequence.replace("<s>", "").replace("</s>", "")
    return output_sequence



textinput = gr.components.Textbox(lines=1, label="Enter a SMILES", placeholder="C1=CC=C(C=C1)C")
textoutput = gr.components.Textbox()
examples = ["C1=CC=C(C=C1)C", "C1CC1C(=O)NC2=CC=CC(=C2)N", "CC(=O)OC1=CC=C"]

intf = gr.Interface(
    fn=inference,
    inputs=textinput,
    outputs=textoutput,
    examples=examples,
    title="CGPT (Chemical Generative Pretrained Transformer)",
    description=(
        "This model is a decoder-only transformer trained on the entire PubMed dataset. "
        "It completes a SMILES string from a given input SMILES string. "
        "d_model is 128, the sequence length is 80, and each of the 6 blocks has "
        "8 attention heads, for a total of 16,473,203 parameters."
    ),
)
intf.launch(inline=False)
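
# Minimal sketch of programmatic use (an assumption: you would call the helper
# above directly instead of launching the Gradio UI):
#
#     completion = inference("CC(=O)OC1=CC=C")
#     print(completion)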