Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -16,12 +16,12 @@ import streamlit as st
|
|
16 |
|
17 |
st.title('predictproduct-t5')
|
18 |
st.markdown('##### At this space, you can predict the products of reactions from their inputs.')
|
19 |
-
st.markdown('##### The code expects input_data as a string or CSV file that contains an "input" column. The format of the string or contents of the column are like "REACTANT:{reactants of the reaction}
|
20 |
-
st.markdown('##### If there
|
21 |
st.markdown('##### The output contains smiles of predicted products and sum of log-likelihood for each prediction. Predictions are ordered by their log-likelihood.(0th is the most probable product.) "valid compound" is the most probable and valid(can be recognized by RDKit) prediction.')
|
22 |
|
23 |
|
24 |
-
display_text = 'input the reaction smiles (e.g. REACTANT:
|
25 |
|
26 |
st.download_button(
|
27 |
label="Download demo_input.csv",
|
@@ -67,12 +67,12 @@ if st.button('predict'):
|
|
67 |
outputs = []
|
68 |
for idx, row in input_data.iterrows():
|
69 |
input_compound = row['input']
|
70 |
-
min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
|
71 |
inp = tokenizer(input_compound, return_tensors='pt').to(device)
|
72 |
-
output = model.generate(**inp, min_length=
|
73 |
if CFG.num_beams > 1:
|
74 |
scores = output['sequences_scores'].tolist()
|
75 |
-
output = [tokenizer.decode(i, skip_special_tokens=True).replace('
|
76 |
for ith, out in enumerate(output):
|
77 |
mol = Chem.MolFromSmiles(out.rstrip('.'))
|
78 |
if type(mol) == rdkit.Chem.rdchem.Mol:
|
@@ -118,12 +118,12 @@ if st.button('predict'):
|
|
118 |
|
119 |
else:
|
120 |
input_compound = CFG.input_data
|
121 |
-
min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
|
122 |
inp = tokenizer(input_compound, return_tensors='pt').to(device)
|
123 |
-
output = model.generate(**inp, min_length=
|
124 |
if CFG.num_beams > 1:
|
125 |
scores = output['sequences_scores'].tolist()
|
126 |
-
output = [tokenizer.decode(i, skip_special_tokens=True).replace('
|
127 |
for ith, out in enumerate(output):
|
128 |
mol = Chem.MolFromSmiles(out.rstrip('.'))
|
129 |
if type(mol) == rdkit.Chem.rdchem.Mol:
|
|
|
16 |
|
17 |
st.title('predictproduct-t5')
|
18 |
st.markdown('##### At this space, you can predict the products of reactions from their inputs.')
|
19 |
+
st.markdown('##### The code expects input_data as a string or CSV file that contains an "input" column. The format of the string or contents of the column are like "REACTANT:{reactants of the reaction}REAGENT:{reagents, catalysts, or solvents of the reaction}".')
|
20 |
+
st.markdown('##### If there is no reagent, fill the blank with a space. And if there are multiple compounds, concatenate them with "."')
|
21 |
st.markdown('##### The output contains smiles of predicted products and sum of log-likelihood for each prediction. Predictions are ordered by their log-likelihood.(0th is the most probable product.) "valid compound" is the most probable and valid(can be recognized by RDKit) prediction.')
|
22 |
|
23 |
|
24 |
+
display_text = 'input the reaction smiles (e.g. REACTANT:COC(=O)C1=CCCN(C)C1.O.[Al+3].[H-].[Li+].[Na+].[OH-]REAGENT:C1CCOC1'
|
25 |
|
26 |
st.download_button(
|
27 |
label="Download demo_input.csv",
|
|
|
67 |
outputs = []
|
68 |
for idx, row in input_data.iterrows():
|
69 |
input_compound = row['input']
|
70 |
+
# min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
|
71 |
inp = tokenizer(input_compound, return_tensors='pt').to(device)
|
72 |
+
output = model.generate(**inp, min_length=2, max_length=181, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
|
73 |
if CFG.num_beams > 1:
|
74 |
scores = output['sequences_scores'].tolist()
|
75 |
+
output = [tokenizer.decode(i, skip_special_tokens=True).replace(' ', '').rstrip('.') for i in output['sequences']]
|
76 |
for ith, out in enumerate(output):
|
77 |
mol = Chem.MolFromSmiles(out.rstrip('.'))
|
78 |
if type(mol) == rdkit.Chem.rdchem.Mol:
|
|
|
118 |
|
119 |
else:
|
120 |
input_compound = CFG.input_data
|
121 |
+
# min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
|
122 |
inp = tokenizer(input_compound, return_tensors='pt').to(device)
|
123 |
+
output = model.generate(**inp, min_length=2, max_length=181, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
|
124 |
if CFG.num_beams > 1:
|
125 |
scores = output['sequences_scores'].tolist()
|
126 |
+
output = [tokenizer.decode(i, skip_special_tokens=True).replace(' ', '').rstrip('.') for i in output['sequences']]
|
127 |
for ith, out in enumerate(output):
|
128 |
mol = Chem.MolFromSmiles(out.rstrip('.'))
|
129 |
if type(mol) == rdkit.Chem.rdchem.Mol:
|