|
import os |
|
import random |
|
import numpy as np |
|
import warnings |
|
import pandas as pd |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from torch.utils.data import Dataset, DataLoader |
|
import gc |
|
import streamlit as st |
|
|
|
# Silence every Python warning for the whole app run.
warnings.filterwarnings("ignore")


# ---- Page header ---------------------------------------------------------
st.title('ReactionT5_task_retrosynthesis')

# Usage notes rendered as markdown directly under the title.
intro_markdown = '''
##### At this space, you can predict the reactants of reactions from their products.
##### The code expects input_data as a string or CSV file that contains an "input" column.
##### The format of the string or contents of the column should be smiles generated by RDKit.
##### For multiple compounds, concatenate them with ".".
##### The output contains SMILES of predicted reactants and the sum of log-likelihood for each prediction, ordered by their log-likelihood (0th is the most probable reactant).
'''
st.markdown(intro_markdown)

# Label of the free-text SMILES box created later in CFG.
display_text = 'input the product smiles (e.g. CCN(CC)CCNC(=S)NC1CCCc2cc(C)cnc21)'

# ---- Demo file -----------------------------------------------------------
# The bundled demo CSV is parsed and re-serialised (without the index) so the
# user can download a ready-made example input.
demo_frame = pd.read_csv('demo_input.csv')
st.download_button(
    label="Download demo_input.csv",
    data=demo_frame.to_csv(index=False),
    file_name='demo_input.csv',
    mime='text/csv',
)
|
|
|
class CFG:
    """Run configuration: fixed settings plus the values of the input widgets.

    The class body is executed once per script run, so the three Streamlit
    widgets below are created (in this order) as a side effect of defining
    the class, and their current values become class attributes.
    """

    # ---- fixed settings ----
    model_name_or_path = 'sagawa/ReactionT5v2-retrosynthesis'  # passed to from_pretrained
    model = 't5'
    input_column = 'input'  # column read from an uploaded CSV
    input_max_length = 100  # tokenizer max_length used for CSV rows
    batch_size = 1
    seed = 42

    # ---- user-controlled widgets (creation order defines page layout) ----
    num_beams = st.number_input(
        label='num beams',
        min_value=1,
        max_value=10,
        value=5,
        step=1,
    )
    # Every beam is returned, so both generation arguments share one value.
    num_return_sequences = num_beams
    uploaded_file = st.file_uploader("Choose a CSV file")
    input_data = st.text_area(display_text)
|
|
|
def seed_everything(seed=42):
    """Seed the random number generators used by this script.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    PyTorch (``torch.manual_seed``) and PyTorch's CUDA generator
    (``torch.cuda.manual_seed``), stores it in the ``PYTHONHASHSEED``
    environment variable, and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
|
|
|
|
|
|
|
def prepare_input(cfg, text):
    """Tokenize one input string into the feature dict used by the dataset.

    The text is tokenized with the module-level ``tokenizer``, truncated and
    padded to ``cfg.input_max_length``.

    Args:
        cfg: Configuration object providing ``input_max_length``.
        text: The input string to encode.

    Returns:
        A dict with the keys ``"input_ids"`` and ``"attention_mask"``. Each
        value is a list holding one ``torch.long`` tensor: the first row of
        the corresponding tokenizer output. A key the tokenizer did not
        produce keeps an empty list.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=cfg.input_max_length,
        padding="max_length",
        truncation=True,
    )

    features = {"input_ids": [], "attention_mask": []}
    # Walk the expected keys rather than the tokenizer output, so an extra
    # field (e.g. token_type_ids) is ignored instead of raising KeyError.
    for key, rows in features.items():
        if key in encoded:
            # detach().clone() is the recommended way to copy an existing
            # tensor; torch.tensor(tensor) emits a copy-construct warning.
            rows.append(encoded[key][0].detach().clone().to(torch.long))
    return features
|
|
|
|
|
class ProductDataset(Dataset):
    """Dataset that tokenizes one DataFrame column row by row, on demand."""

    def __init__(self, cfg, df):
        """Store the config and the values of the configured input column.

        Args:
            cfg: Configuration object; ``cfg.input_column`` names the column
                to read, and the object is forwarded to ``prepare_input``.
            df: DataFrame containing that column.
        """
        column = df[cfg.input_column]
        self.inputs = column.values
        self.cfg = cfg

    def __len__(self):
        """Return the number of stored input values."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``prepare_input`` features for the value at ``idx``."""
        text = self.inputs[idx]
        return prepare_input(self.cfg, text)
|
|
|
|
|
def predict_single_input(input_compound):
    """Run beam-search generation for one input string.

    Uses the module-level ``tokenizer``, ``model`` and ``device``; beam
    settings come from ``CFG``.

    Args:
        input_compound: The input string to tokenize and feed to the model.

    Returns:
        The object returned by ``model.generate`` when called with
        ``return_dict_in_generate=True`` and ``output_scores=True``.
    """
    generation_options = {
        "num_beams": CFG.num_beams,
        "num_return_sequences": CFG.num_return_sequences,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    encoded = tokenizer(input_compound, return_tensors="pt").to(device)
    # Gradients are not needed for generation.
    with torch.no_grad():
        return model.generate(**encoded, **generation_options)
|
|
|
|
|
def decode_output(output):
    """Turn a generation result into cleaned strings and optional scores.

    Each entry of ``output["sequences"]`` is decoded with the module-level
    ``tokenizer`` (special tokens skipped); spaces are removed and trailing
    ``"."`` characters are stripped.

    Args:
        output: Mapping-like generation result with a ``"sequences"`` entry
            and, when ``CFG.num_beams > 1``, a ``"sequences_scores"`` entry.

    Returns:
        ``(sequences, scores)`` where ``scores`` is a list taken from
        ``output["sequences_scores"]`` when ``CFG.num_beams > 1``, and
        ``None`` otherwise.
    """
    sequences = []
    for token_ids in output["sequences"]:
        decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
        sequences.append(decoded.replace(" ", "").rstrip("."))

    # Scores are only read when more than one beam is configured.
    if not CFG.num_beams > 1:
        return sequences, None
    return sequences, output["sequences_scores"].tolist()
|
|
|
|
|
def save_single_prediction(input_compound, output, scores):
    """Build a one-row DataFrame for the predictions of a single input.

    Column headers are derived from the number of predictions and scores
    actually passed in, so the header always matches the row.

    Args:
        input_compound: The input string; stored in the ``"input"`` column.
        output: Sequence of predicted strings, best first. They fill the
            columns ``"0th"``, ``"1th"``, ...
        scores: Sequence of scores aligned with ``output``, or ``None`` /
            empty when no scores are available. When present they fill the
            columns ``"0th score"``, ``"1th score"``, ...

    Returns:
        A ``pandas.DataFrame`` with exactly one row.
    """
    predictions = list(output)
    score_values = list(scores) if scores else []

    columns = (
        ["input"]
        + [f"{i}th" for i in range(len(predictions))]
        + [f"{i}th score" for i in range(len(score_values))]
    )
    row = [input_compound, *predictions, *score_values]
    return pd.DataFrame([row], columns=columns)
|
|
|
|
|
def save_multiple_predictions(input_data, sequences, scores):
    """Build a DataFrame with one row of predictions per input row.

    ``sequences`` (and ``scores``) are flat lists in which each consecutive
    group of ``CFG.num_return_sequences`` entries belongs to one input row,
    in the same order as the rows of ``input_data``.

    Args:
        input_data: DataFrame containing the ``CFG.input_column`` column.
        sequences: Flat list of predicted strings for all inputs.
        scores: Flat list of scores aligned with ``sequences``; may be empty
            or ``None`` when no scores were collected.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"input"``, ``"0th"``, ...
        and, when scores are present, ``"0th score"``, ...
    """
    group_size = CFG.num_return_sequences
    scores = scores or []
    # Positional lookup: rows are matched by order, so the DataFrame's index
    # labels (which need not be 0..n-1) must not be used.
    inputs = input_data[CFG.input_column].tolist()

    rows = [
        [
            inputs[start // group_size],
            *sequences[start:start + group_size],
            *scores[start:start + group_size],
        ]
        for start in range(0, len(sequences), group_size)
    ]

    columns = ["input"] + [f"{i}th" for i in range(group_size)]
    if scores:
        columns += [f"{i}th score" for i in range(group_size)]

    return pd.DataFrame(rows, columns=columns)
|
|
|
|
|
if st.button('predict'):
    # ---- Validate the request before loading the model -------------------
    if CFG.uploaded_file is None:
        input_data = None
        if not CFG.input_data.strip():
            # Nothing to predict: neither a CSV nor any text was supplied.
            st.error('Please enter a product SMILES or upload a CSV file.')
            st.stop()
    else:
        input_data = pd.read_csv(CFG.uploaded_file)
        if CFG.input_column not in input_data.columns:
            st.error(f'The uploaded CSV must contain an "{CFG.input_column}" column.')
            st.stop()

    with st.spinner('Now processing. If num beams=5, this process takes about 15 seconds per reaction.'):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        seed_everything(seed=CFG.seed)

        # `tokenizer`, `model` and `device` are module-level names read by
        # prepare_input, predict_single_input and decode_output.
        tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors="pt")
        model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
        model.eval()

        if input_data is None:
            # Single-input mode: predict for the text typed into the text area.
            input_compound = CFG.input_data
            output = predict_single_input(input_compound)
            sequences, scores = decode_output(output)
            output_df = save_single_prediction(input_compound, sequences, scores)
        else:
            # CSV mode: predict for every row of the uploaded file.
            dataset = ProductDataset(CFG, input_data)
            dataloader = DataLoader(
                dataset,
                batch_size=CFG.batch_size,
                shuffle=False,
                # Load in the main process: worker processes would need to
                # pickle a dataset class defined in this script and would not
                # see the `tokenizer` created above under spawn/forkserver
                # start methods.
                num_workers=0,
                # Pinned host memory only helps when copying to a GPU.
                pin_memory=device.type == "cuda",
                drop_last=False,
            )

            all_sequences, all_scores = [], []
            for inputs in dataloader:
                # prepare_input wraps each tensor in a one-element list, so
                # the collated batch tensor sits at index 0.
                inputs = {k: v[0].to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        num_beams=CFG.num_beams,
                        num_return_sequences=CFG.num_return_sequences,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                sequences, scores = decode_output(output)
                all_sequences.extend(sequences)
                if scores:
                    all_scores.extend(scores)
                # Release the generation result before the next batch.
                del output
                torch.cuda.empty_cache()
                gc.collect()

            output_df = save_multiple_predictions(input_data, all_sequences, all_scores)

    # Serialise directly: the previous per-click `@st.cache` wrapper used a
    # deprecated API and wrapped a function that was redefined on every run.
    csv = output_df.to_csv(index=False)

    st.download_button(
        label="Download data as CSV",
        data=csv,
        file_name='output.csv',
        mime='text/csv',
    )
|
|