
Model Card for ReactionT5-yield-prediction

This is a ReactionT5 model pre-trained to predict reaction yields. You can use the demo here.

Model Details

Model Sources

Uses

How to Get Started with the Model

Use the code below to get started with the model; from_pretrained fetches the model files from the Hugging Face Hub.

import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig, PreTrainedModel

class ReactionT5Yield(PreTrainedModel):
    config_class  = AutoConfig
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = T5ForConditionalGeneration.from_pretrained(self.config._name_or_path)
        self.model.resize_token_embeddings(self.config.vocab_size)
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
        self.fc3 = nn.Linear(self.config.hidden_size//2*2, self.config.hidden_size)
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)

        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

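    def forward(self, inputs):
        # Encode the tokenized reaction string.
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        # Run a single decoder step from the decoder start token. The start-token
        # ids are created on the same device as the inputs so this also works on GPU.
        decoder_input_ids = torch.full((inputs['input_ids'].size(0), 1),
                                       self.config.decoder_start_token_id,
                                       dtype=torch.long,
                                       device=inputs['input_ids'].device)
        outputs = self.model.decoder(input_ids=decoder_input_ids,
                                     encoder_hidden_states=encoder_hidden_states)
        last_hidden_states = outputs[0]
        # Project the decoder state and the first encoder token state,
        # concatenate them, and regress a single value.
        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size))
        output = self.fc3(torch.hstack((output1, output2)))
        output = self.fc4(output)
        output = self.fc5(output)
        # Multiply by 100 so the prediction reads as a percentage yield.
        return output*100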


model = ReactionT5Yield.from_pretrained('sagawa/ReactionT5-yield-prediction')
tokenizer = AutoTokenizer.from_pretrained('sagawa/ReactionT5-yield-prediction')
inp = tokenizer(['REACTANT:CC(C)n1ncnc1-c1cn2c(n1)-c1cnc(O)cc1OCC2.CCN(C(C)C)C(C)C.Cl.NC(=O)[C@@H]1C[C@H](F)CN1REAGENT: PRODUCT:O=C(NNC(=O)C(F)(F)F)C(F)(F)F'], return_tensors='pt')
print(model(inp)) # tensor([[19.1666]], grad_fn=<MulBackward0>)
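The example input above appears to follow a REACTANT:<smiles>REAGENT:<smiles>PRODUCT:<smiles> layout, with multiple molecules joined by "." and an empty reagent field written as a single space. The helper below is a hypothetical convenience function (build_yield_input is not part of the model repository) that assembles strings in that assumed format. It reproduces the example input exactly, but it has not been checked against the training script.

```python
from typing import Sequence


def build_yield_input(reactants: Sequence[str],
                      reagents: Sequence[str],
                      products: Sequence[str]) -> str:
    """Assemble a ReactionT5 yield-prediction input string (assumed format)."""
    # Multiple molecules in one role are joined with '.', as in SMILES.
    reactant_str = ".".join(reactants)
    # Assumption: an empty reagent field is written as a single space,
    # mirroring the 'REAGENT: PRODUCT:' fragment in the example input.
    reagent_str = ".".join(reagents) or " "
    product_str = ".".join(products)
    return f"REACTANT:{reactant_str}REAGENT:{reagent_str}PRODUCT:{product_str}"


# Rebuild the example input from its individual molecules.
example = build_yield_input(
    reactants=[
        "CC(C)n1ncnc1-c1cn2c(n1)-c1cnc(O)cc1OCC2",
        "CCN(C(C)C)C(C)C",
        "Cl",
        "NC(=O)[C@@H]1C[C@H](F)CN1",
    ],
    reagents=[],
    products=["O=C(NNC(=O)C(F)(F)F)C(F)(F)F"],
)
print(example)
```

The resulting string can be passed to the tokenizer as tokenizer([example], return_tensors='pt') before calling model(...), exactly as in the snippet above.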

Training Details

Training Procedure

We used the Open Reaction Database (ORD) dataset for model training. The command used for training is shown below. For more information, please refer to the paper and the GitHub repository.

python train.py \
  --data_path='all_ord_reaction_uniq_with_attr_v3.tsv' \
  --pretrained_model_name_or_path='sagawa/ZINC-t5' \
  --model='t5' \
  --epochs=100 \
  --batch_size=50 \
  --max_len=400 \
  --num_workers=4 \
  --weight_decay=0.05 \
  --gradient_accumulation_steps=1 \
  --batch_scheduler \
  --print_freq=100 \
  --output_dir='./'

Results

| Split (R²) | DFT | MFF | Yield-BERT | T5Chem | CompoundT5 | ReactionT5 (without finetuning) |
|---|---|---|---|---|---|---|
| Random 70/30 | 0.92 | 0.927 ± 0.007 | 0.951 ± 0.005 | 0.970 ± 0.003 | 0.971 ± 0.002 | 0.904 ± 0.0007 |
| Test 1 | 0.80 | 0.851 | 0.838 | 0.811 | 0.855 | 0.919 |
| Test 2 | 0.77 | 0.713 | 0.836 | 0.907 | 0.852 | 0.927 |
| Test 3 | 0.64 | 0.635 | 0.738 | 0.789 | 0.712 | 0.847 |
| Test 4 | 0.54 | 0.184 | 0.538 | 0.627 | 0.547 | 0.909 |
| Avg. Tests 1–4 | 0.69 ± 0.104 | 0.596 ± 0.251 | 0.738 ± 0.122 | 0.785 ± 0.094 | 0.741 ± 0.126 | 0.900 ± 0.031 |
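The Avg. Tests 1–4 row appears to report the mean and the population standard deviation (dividing by n) of the four test rows; for example, the DFT column gives 0.6875 ± 0.104, which the table lists as 0.69 ± 0.104. This convention is an inference and is not stated in the original card. A minimal sketch that recomputes the row under that assumption, for two of the columns:

```python
import statistics

# Tests 1-4 R^2 values copied from the table above (two columns shown).
TESTS_1_TO_4 = {
    "DFT": [0.80, 0.77, 0.64, 0.54],
    "ReactionT5 (without finetuning)": [0.919, 0.927, 0.847, 0.909],
}


def mean_and_spread(scores):
    """Return the mean and the population standard deviation (divides by n)."""
    return statistics.fmean(scores), statistics.pstdev(scores)


for name, scores in TESTS_1_TO_4.items():
    mean, spread = mean_and_spread(scores)
    print(f"{name}: {mean:.4f} ± {spread:.4f}")
```

For ReactionT5 this gives roughly 0.9005 ± 0.0315, close to the reported 0.900 ± 0.031; the small gap presumably comes from rounding in the per-test values.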

Citation

Model Card Authors

[More Information Needed]

Model Card Contact

[More Information Needed]
