sagawa commited on
Commit
d0939bd
1 Parent(s): b15be69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -16,12 +16,12 @@ import streamlit as st
16
 
17
  st.title('predictproduct-t5')
18
  st.markdown('##### At this space, you can predict the products of reactions from their inputs.')
19
- st.markdown('##### The code expects input_data as a string or CSV file that contains an "input" column. The format of the string or contents of the column are like "REACTANT:{reactants of the reaction}CATALYST:{catalysts of the reaction}REAGENT:{reagents of the reaction}SOLVENT:{solvent of the reaction}".')
20
- st.markdown('##### If there are no catalyst or reagent, fill the blank with a space. And if there are multiple reactants, concatenate them with "."')
21
  st.markdown('##### The output contains smiles of predicted products and sum of log-likelihood for each prediction. Predictions are ordered by their log-likelihood.(0th is the most probable product.) "valid compound" is the most probable and valid(can be recognized by RDKit) prediction.')
22
 
23
 
24
- display_text = 'input the reaction smiles (e.g. REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O)'
25
 
26
  st.download_button(
27
  label="Download demo_input.csv",
@@ -67,12 +67,12 @@ if st.button('predict'):
67
  outputs = []
68
  for idx, row in input_data.iterrows():
69
  input_compound = row['input']
70
- min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
71
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
72
- output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
73
  if CFG.num_beams > 1:
74
  scores = output['sequences_scores'].tolist()
75
- output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
76
  for ith, out in enumerate(output):
77
  mol = Chem.MolFromSmiles(out.rstrip('.'))
78
  if type(mol) == rdkit.Chem.rdchem.Mol:
@@ -118,12 +118,12 @@ if st.button('predict'):
118
 
119
  else:
120
  input_compound = CFG.input_data
121
- min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
122
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
123
- output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
124
  if CFG.num_beams > 1:
125
  scores = output['sequences_scores'].tolist()
126
- output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
127
  for ith, out in enumerate(output):
128
  mol = Chem.MolFromSmiles(out.rstrip('.'))
129
  if type(mol) == rdkit.Chem.rdchem.Mol:
 
16
 
17
  st.title('predictproduct-t5')
18
  st.markdown('##### At this space, you can predict the products of reactions from their inputs.')
19
+ st.markdown('##### The code expects input_data as a string or CSV file that contains an "input" column. The format of the string or contents of the column are like "REACTANT:{reactants of the reaction}REAGENT:{reagents, catalysts, or solvents of the reaction}".')
20
+ st.markdown('##### If there is no reagent, fill the blank with a space. And if there are multiple compounds, concatenate them with "."')
21
  st.markdown('##### The output contains smiles of predicted products and sum of log-likelihood for each prediction. Predictions are ordered by their log-likelihood.(0th is the most probable product.) "valid compound" is the most probable and valid(can be recognized by RDKit) prediction.')
22
 
23
 
24
+ display_text = 'input the reaction smiles (e.g. REACTANT:COC(=O)C1=CCCN(C)C1.O.[Al+3].[H-].[Li+].[Na+].[OH-]REAGENT:C1CCOC1'
25
 
26
  st.download_button(
27
  label="Download demo_input.csv",
 
67
  outputs = []
68
  for idx, row in input_data.iterrows():
69
  input_compound = row['input']
70
+ # min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
71
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
72
+ output = model.generate(**inp, min_length=2, max_length=181, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
73
  if CFG.num_beams > 1:
74
  scores = output['sequences_scores'].tolist()
75
+ output = [tokenizer.decode(i, skip_special_tokens=True).replace(' ', '').rstrip('.') for i in output['sequences']]
76
  for ith, out in enumerate(output):
77
  mol = Chem.MolFromSmiles(out.rstrip('.'))
78
  if type(mol) == rdkit.Chem.rdchem.Mol:
 
118
 
119
  else:
120
  input_compound = CFG.input_data
121
+ # min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
122
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
123
+ output = model.generate(**inp, min_length=2, max_length=181, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
124
  if CFG.num_beams > 1:
125
  scores = output['sequences_scores'].tolist()
126
+ output = [tokenizer.decode(i, skip_special_tokens=True).replace(' ', '').rstrip('.') for i in output['sequences']]
127
  for ith, out in enumerate(output):
128
  mol = Chem.MolFromSmiles(out.rstrip('.'))
129
  if type(mol) == rdkit.Chem.rdchem.Mol: