tags:
- chemistry
- molecule
- drug
Roberta Zinc Decoder
This model is a GPT2 decoder model designed to reconstruct SMILES strings from embeddings created by the roberta_zinc_480m model. The decoder model was trained on 30m compounds from the ZINC Database.
The decoder model conditions generation on mean pooled embeddings from the encoder model. Mean pooled embeddings are used to allow for integration with vector databases, which require fixed length embeddings.
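
As a rough illustration of why mean pooling yields a fixed-length vector, the sketch below averages token vectors while skipping padded positions. It uses plain Python lists and made-up numbers rather than the encoder's real hidden states, and `masked_mean_pool` is a hypothetical helper that is not part of this repo.

```python
from typing import List


def masked_mean_pool(token_vectors: List[List[float]], mask: List[int]) -> List[float]:
    """Average the token vectors whose mask entry is 1, ignoring padding."""
    dim = len(token_vectors[0])
    n_real = sum(mask)  # number of non-padding tokens
    pooled = [0.0] * dim
    for vec, keep in zip(token_vectors, mask):
        if keep:
            for j, value in enumerate(vec):
                pooled[j] += value
    return [total / n_real for total in pooled]


# Two toy "sequences" of 2-dimensional token vectors, both padded to length 3.
short_seq = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last position is padding
long_seq = [[1.0, 0.0], [2.0, 0.0], [3.0, 6.0]]   # no padding

short_emb = masked_mean_pool(short_seq, [1, 1, 0])
long_emb = masked_mean_pool(long_seq, [1, 1, 1])

print(short_emb)  # [2.0, 3.0]
print(long_emb)   # [2.0, 2.0]
```

Both inputs end up as vectors of the same length regardless of how many real tokens they contain, which is the property a vector database needs.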
Condition embeddings are passed to the decoder model through the `encoder_hidden_states` argument. The standard `GPT2LMHeadModel` does not support generation with encoder hidden states, so this repo includes a custom `ConditionalGPT2LMHeadModel`. See the example below for how to instantiate the model.
import torch
from transformers import AutoModelForCausalLM, RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", max_len=256)
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')
encoder_model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
encoder_model.eval()  # inference mode (disables dropout)

commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'  # pin the remote model code to a fixed revision
decoder_model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
                                                     trust_remote_code=True, revision=commit_hash)
decoder_model.eval()

smiles = ['Brc1cc2c(NCc3ccccc3)ncnc2s1',
          'Brc1cc2c(NCc3ccccn3)ncnc2s1',
          'Brc1cc2c(NCc3cccs3)ncnc2s1',
          'Brc1cc2c(NCc3ccncc3)ncnc2s1',
          'Brc1cc2c(Nc3ccccc3)ncnc2s1']

inputs = collator(tokenizer(smiles))
with torch.no_grad():  # inference only, no gradients needed
    outputs = encoder_model(**inputs, output_hidden_states=True)
full_embeddings = outputs.hidden_states[-1]  # last-layer token embeddings
mask = inputs['attention_mask']
mean_embeddings = (full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)  # mean over non-padding tokens

decoder_inputs = torch.tensor([[tokenizer.bos_token_id] for _ in range(len(smiles))])
hidden_states = mean_embeddings[:, None]  # hidden states shape (bs, 1, -1)

gen = decoder_model.generate(
    decoder_inputs,
    encoder_hidden_states=hidden_states,
    do_sample=False,  # greedy decoding is recommended
    max_length=100,
    temperature=1.,
    early_stopping=True,
    pad_token_id=tokenizer.pad_token_id,
)
reconstructed_smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)
Model Performance
The decoder model was evaluated on a test set of 1m compounds from ZINC. Compounds were encoded with the roberta_zinc_480m model and reconstructed with the decoder model.
The following metrics are computed:
- `exact_match` - fraction of inputs reconstructed exactly
- `token_accuracy` - fraction of output tokens exactly matching input tokens (excluding padding)
- `valid_structure` - fraction of generated outputs that resolved to a valid SMILES string
- `tanimoto` - Tanimoto similarity between inputs and generated outputs, excluding invalid structures
- `cos_sim` - cosine similarity between input encoder embeddings and output encoder embeddings
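
The per-compound metrics listed above could be computed along the following lines. This is a minimal sketch, not the evaluation code used for this model: all function names are hypothetical, the position-wise definition of token accuracy is an assumption, and Tanimoto similarity is shown on plain sets of fingerprint bit indices (producing those fingerprints, and checking SMILES validity, would need a cheminformatics toolkit and is not shown).

```python
import math
from typing import Sequence, Set


def exact_match(original: str, generated: str) -> bool:
    """True when the generated SMILES string is identical to the input."""
    return original == generated


def token_accuracy(original_ids: Sequence[int], generated_ids: Sequence[int], pad_id: int) -> float:
    """Fraction of non-padding input positions whose generated token matches.

    Assumption: tokens are compared position by position, and positions the
    generated sequence does not reach count as mismatches.
    """
    positions = [i for i, tok in enumerate(original_ids) if tok != pad_id]
    if not positions:
        return 0.0
    hits = sum(
        1 for i in positions
        if i < len(generated_ids) and generated_ids[i] == original_ids[i]
    )
    return hits / len(positions)


def tanimoto(bits_a: Set[int], bits_b: Set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of fingerprint bits."""
    union = len(bits_a | bits_b)
    # Returning 0.0 for two empty sets is a choice made for this sketch.
    return len(bits_a & bits_b) / union if union else 0.0


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0
```

Averaging each function's result over all evaluated compounds would give one value per metric.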
`eval_type=full` reports metrics for the full 1m compound test set.
`eval_type=failed` reports metrics only for the subset of generated outputs that failed to exactly replicate the inputs.
| eval_type | exact_match | token_accuracy | valid_structure | tanimoto | cos_sim |
|---|---|---|---|---|---|
| full | 0.948277 | 0.990704 | 0.994278 | 0.987698 | 0.998224 |
| failed | 0.000000 | 0.820293 | 0.889372 | 0.734097 | 0.965668 |
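
To make the relationship between the two rows concrete, here is a small sketch of how per-compound results might be aggregated into `full` and `failed` summaries. The `summarize` helper and the record layout are assumptions for illustration, and the numbers are made-up toy values, not data from the evaluation.

```python
from statistics import mean
from typing import Dict, List


def summarize(records: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """Average each metric over all records and over the non-exact-match subset."""
    metric_names = ["exact_match", "token_accuracy", "cos_sim"]
    subsets = {
        "full": records,
        # "failed" keeps only compounds that were not reconstructed exactly
        "failed": [r for r in records if not r["exact_match"]],
    }
    return {
        name: {m: mean(r[m] for r in rows) for m in metric_names}
        for name, rows in subsets.items()
        if rows  # skip an empty subset rather than averaging nothing
    }


# Toy per-compound results (hypothetical values)
records = [
    {"exact_match": 1.0, "token_accuracy": 1.0, "cos_sim": 1.0},
    {"exact_match": 0.0, "token_accuracy": 0.5, "cos_sim": 0.75},
    {"exact_match": 1.0, "token_accuracy": 1.0, "cos_sim": 1.0},
    {"exact_match": 0.0, "token_accuracy": 0.75, "cos_sim": 0.5},
]

summary = summarize(records)
```

By construction, `exact_match` is always 0 in the `failed` subset, which matches the table above.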