Spaces:
Running
Running
File size: 4,726 Bytes
8b150bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple
import mols2grid
import pandas as pd
from gt4sd.algorithms import (
RegressionTransformerMolecules,
RegressionTransformerProteins,
)
from gt4sd.algorithms.core import AlgorithmConfiguration
from rdkit import Chem
from terminator.selfies import decoder
# Module-level logger named after this module. A no-op NullHandler is attached
# so the module itself never configures log output; records are still passed on
# to whatever handlers the host application installs.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def get_application(application: str) -> AlgorithmConfiguration:
    """
    Map an application name onto its Regression Transformer configuration.

    Args:
        application: Either "Molecules" or "Proteins".

    Returns:
        RegressionTransformerMolecules for "Molecules",
        RegressionTransformerProteins for "Proteins".

    Raises:
        ValueError: If any other application name is given.
    """
    # Guard-clause style: each supported name returns immediately, anything
    # else falls through to the error.
    if application == "Molecules":
        return RegressionTransformerMolecules
    if application == "Proteins":
        return RegressionTransformerProteins
    raise ValueError("Currently only models for molecules and proteins are supported")
def get_inference_dict(
    application: AlgorithmConfiguration, algorithm_version: str
) -> Dict:
    """
    Get inference dictionary for a given application and algorithm version.

    Args:
        application: callable invoked as ``application(algorithm_version=...)``
            to build the configuration (e.g. the class returned by
            ``get_application``).
        algorithm_version: algorithm version (e.g. qed)

    Returns:
        The parsed content of the ``inference.json`` file located in the
        directory returned by the configuration's ``ensure_artifacts()``.
    """
    config = application(algorithm_version=algorithm_version)
    inference_path = os.path.join(config.ensure_artifacts(), "inference.json")
    # JSON is UTF-8 by specification; reading with an explicit encoding avoids
    # depending on the platform's locale default (e.g. cp1252 on Windows).
    with open(inference_path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_rt_name(x: Dict) -> str:
    """
    Build the UI display name of a regression transformer.

    Args:
        x: dictionary with the inference parameters; the keys
            "algorithm_application" and "algorithm_version" are read.

    Returns:
        The text following the last "Transformer" in the application name
        (the whole name if it does not occur), then ": ", then the
        capitalized algorithm version.
    """
    # rpartition yields the full string as the tail when the separator is
    # absent, matching split(...)[-1].
    domain = x["algorithm_application"].rpartition("Transformer")[2]
    version = x["algorithm_version"].capitalize()
    return f"{domain}: {version}"
def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
    """
    Render a single-cell mols2grid HTML grid for a property prediction.

    Args:
        prediction: Predicted sequence; every "<name>value" chunk in it is
            turned into a tooltip entry.
        target: Target molecule; only the part after the last "|" is drawn.
        domain: Domain of the prediction ("Molecules" or "Proteins").

    Returns:
        HTML to display.

    Raises:
        ValueError: If the domain is not supported.
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    # Pick the routine that turns the raw sequence into something drawable.
    if domain == "Molecules":
        to_smiles = decoder
    else:

        def to_smiles(fasta):
            return Chem.MolToSmiles(Chem.MolFromFASTA(fasta))

    seq = target.split("|")[-1]
    try:
        seq = to_smiles(seq)
    except Exception:
        # Best effort: keep the unconverted sequence and carry on.
        logger.warning(f"Could not draw sequence {seq}")

    result = {"SMILES": [seq], "Name": ["Target"]}

    # One tooltip column per predicted property.
    for chunk in prediction.split("<")[1:]:
        pieces = chunk.split(">")
        prop_name = pieces[0]
        prop_value = pieces[1]
        result[prop_name] = f"{prop_name.capitalize()} = {prop_value}"

    frame = pd.DataFrame(result)
    grid = mols2grid.display(
        frame,
        tooltip=list(result),
        height=900,
        n_cols=1,
        name="Results",
        size=(600, 700),
    )
    return grid.data
def draw_grid_generate(
    samples: List[Tuple[str, str]],
    domain: str,
    n_cols: int = 5,
    size: Tuple[int, int] = (140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules

    Args:
        samples: The generated samples. Each item holds the sequence at
            index 0 and a property string at index 1 made of "<name>value"
            chunks (e.g. "<qed>0.5").
        domain: Domain of the prediction (molecules or proteins)
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If the domain is not supported.
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    if domain == "Proteins":
        # Convert sample by sample so that one bad sequence cannot leave
        # `smis` unassigned; a failed conversion falls back to the raw
        # sequence, mirroring the best-effort handling in draw_grid_predict.
        smis = []
        for sample in samples:
            sequence = sample[0]
            try:
                smis.append(Chem.MolToSmiles(Chem.MolFromFASTA(sequence)))
            except Exception:
                logger.warning(f"Could not convert sequence {sequence}")
                smis.append(sequence)
    else:
        smis = [sample[0] for sample in samples]

    result = defaultdict(list)
    result.update({"SMILES": smis, "Name": [f"sample_{i}" for i in range(len(smis))]})

    # Property names are read from the first sample: the text between "<" and
    # ">" of each chunk in its property string.
    properties = [s.split("<")[1] for s in samples[0][1].split(">")[:-1]]

    # Fill one tooltip column per property.
    for sample in samples:
        for prop in properties:
            # Text after the property name, minus the closing ">", up to the
            # next "<" is the numeric value.
            value = float(sample[1].split(prop)[-1][1:].split("<")[0])
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data
|