Model Card for ReactionT5v1-yield
This is a ReactionT5 model pre-trained to predict reaction yields. You can try it in the online demo (see the Demo link below).
Model Sources
- Repository: https://github.com/sagawatatsuya/ReactionT5
- Paper: https://arxiv.org/abs/2311.06708
- Demo: https://huggingface.co/spaces/sagawa/ReactionT5_task_yield
Uses
How to Get Started with the Model
Use the code below to get started with the model.
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig, PreTrainedModel
class ReactionT5Yield(PreTrainedModel):
    """T5 encoder-decoder backbone with a small MLP head that regresses yield.

    The head combines two features:
      * the decoder hidden state after a single decoder-start token (fc1), and
      * the encoder hidden state at the first input position (fc2).
    It maps their concatenation to one scalar (fc3 -> fc4 -> fc5), which
    ``forward`` multiplies by 100.
    """

    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # The backbone is loaded from ``config._name_or_path``, i.e. the
        # checkpoint name/path that the config was loaded from.
        self.model = T5ForConditionalGeneration.from_pretrained(self.config._name_or_path)
        self.model.resize_token_embeddings(self.config.vocab_size)

        hidden = self.config.hidden_size
        half = hidden // 2
        self.fc1 = nn.Linear(hidden, half)      # decoder-state branch
        self.fc2 = nn.Linear(hidden, half)      # encoder first-token branch
        self.fc3 = nn.Linear(half * 2, hidden)  # consumes both branches concatenated
        self.fc4 = nn.Linear(hidden, hidden)
        self.fc5 = nn.Linear(hidden, 1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            self._init_weights(layer)

    def _init_weights(self, module):
        """Initialise a layer in place.

        Linear and Embedding weights are drawn from N(0, 0.01^2). Biases and the
        embedding padding row are zeroed. LayerNorm is reset to weight=1, bias=0.
        Other module types are left unchanged.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        """Predict yields for a tokenized batch.

        Args:
            inputs: mapping of keyword arguments for the T5 encoder (e.g. the
                output of the tokenizer with ``return_tensors='pt'``). It must
                contain ``input_ids``.

        Returns:
            Tensor of shape ``(batch, 1)``: the head output multiplied by 100.
        """
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]

        input_ids = inputs['input_ids']
        # Run a single decoder step seeded with the decoder start token. The
        # tensor must live on the same device as the inputs, otherwise a
        # GPU-resident model fails with a device mismatch.
        decoder_input_ids = torch.full(
            (input_ids.size(0), 1),
            self.config.decoder_start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        # NOTE(review): no encoder_attention_mask is passed, so the decoder also
        # attends to padded encoder positions. This is kept as-is to match how
        # the published weights behave.
        outputs = self.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]

        hidden = self.config.hidden_size
        output1 = self.fc1(last_hidden_states.view(-1, hidden))
        output2 = self.fc2(encoder_hidden_states[:, 0, :].view(-1, hidden))
        output = self.fc3(torch.hstack((output1, output2)))
        output = self.fc4(output)
        output = self.fc5(output)
        return output * 100
# Load the yield model and its tokenizer from the Hugging Face Hub.
model = ReactionT5Yield.from_pretrained('sagawa/ReactionT5v1-yield')
tokenizer = AutoTokenizer.from_pretrained('sagawa/ReactionT5v1-yield')
inp = tokenizer(['REACTANT:CC(C)n1ncnc1-c1cn2c(n1)-c1cnc(O)cc1OCC2.CCN(C(C)C)C(C)C.Cl.NC(=O)[C@@H]1C[C@H](F)CN1REAGENT: PRODUCT:O=C(NNC(=O)C(F)(F)F)C(F)(F)F'], return_tensors='pt')
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    print(model(inp))  # tensor([[19.1666]])
Training Details
Training Procedure
We used the Open Reaction Database (ORD) dataset for model training. The command used for training is shown below. For more information, please refer to the paper and the GitHub repository.
# Single command; each option line ends with "\" so the shell treats them as one invocation.
python train.py \
    --data_path='all_ord_reaction_uniq_with_attr_v3.tsv' \
    --pretrained_model_name_or_path='sagawa/ZINC-t5' \
    --model='t5' \
    --epochs=100 \
    --batch_size=50 \
    --max_len=400 \
    --num_workers=4 \
    --weight_decay=0.05 \
    --gradient_accumulation_steps=1 \
    --batch_scheduler \
    --print_freq=100 \
    --output_dir='./'
Results
R^2 | DFT | MFF | Yield-BERT | T5Chem | CompoundT5 | ReactionT5 (without finetuning) |
---|---|---|---|---|---|---|
Random 70/30 | 0.92 | 0.927 ± 0.007 | 0.951 ± 0.005 | 0.970 ± 0.003 | 0.971 ± 0.002 | 0.904 ± 0.0007 |
Test 1 | 0.80 | 0.851 | 0.838 | 0.811 | 0.855 | 0.919 |
Test 2 | 0.77 | 0.713 | 0.836 | 0.907 | 0.852 | 0.927 |
Test 3 | 0.64 | 0.635 | 0.738 | 0.789 | 0.712 | 0.847 |
Test 4 | 0.54 | 0.184 | 0.538 | 0.627 | 0.547 | 0.909 |
Avg. Tests 1–4 | 0.69 ± 0.104 | 0.596 ± 0.251 | 0.738 ± 0.122 | 0.785 ± 0.094 | 0.741 ± 0.126 | 0.900 ± 0.031 |
Citation
arXiv link: https://arxiv.org/abs/2311.06708
@misc{sagawa2023reactiont5,
title={ReactionT5: a large-scale pre-trained model towards application of limited reaction data},
author={Tatsuya Sagawa and Ryosuke Kojima},
year={2023},
eprint={2311.06708},
archivePrefix={arXiv},
primaryClass={physics.chem-ph}
}
- Downloads last month
- 8