jannisborn
feat: Initial RT app
8b150bd
import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple
import mols2grid
import pandas as pd
from gt4sd.algorithms import (
RegressionTransformerMolecules,
RegressionTransformerProteins,
)
from gt4sd.algorithms.core import AlgorithmConfiguration
from rdkit import Chem
from terminator.selfies import decoder
# Module-level logger, named after this module so host applications can
# configure or filter it by name.
logger = logging.getLogger(__name__)
# NullHandler discards records: this module stays silent unless the host
# application attaches its own handlers further up the logger hierarchy.
logger.addHandler(logging.NullHandler())
def get_application(application: str) -> AlgorithmConfiguration:
    """
    Resolve an application name to its Regression Transformer configuration.

    Args:
        application: Name of the application, either "Molecules" or
            "Proteins" (case-sensitive).

    Returns:
        RegressionTransformerMolecules for "Molecules" and
        RegressionTransformerProteins for "Proteins". The configuration
        class itself is returned, not an instance of it.

    Raises:
        ValueError: If the name is neither "Molecules" nor "Proteins".
    """
    if application == "Molecules":
        return RegressionTransformerMolecules
    if application == "Proteins":
        return RegressionTransformerProteins
    raise ValueError("Currently only models for molecules and proteins are supported")
def get_inference_dict(
    application: AlgorithmConfiguration, algorithm_version: str
) -> Dict:
    """
    Load the inference parameters of a given application and algorithm version.

    The application is called with ``algorithm_version`` to build a
    configuration object; ``inference.json`` is then read from the directory
    returned by that object's ``ensure_artifacts()``.

    Args:
        application: algorithm application (Molecules or Proteins), i.e. a
            callable accepting an ``algorithm_version`` keyword argument.
        algorithm_version: algorithm version (e.g. qed)

    Returns:
        A dictionary with the inference parameters, as parsed from
        ``inference.json``.

    Raises:
        OSError: If ``inference.json`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    config = application(algorithm_version=algorithm_version)
    inference_path = os.path.join(config.ensure_artifacts(), "inference.json")
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on the platform's locale-dependent default encoding.
    with open(inference_path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_rt_name(x: Dict) -> str:
    """
    Build the UI display name of a regression transformer.

    Args:
        x: dictionary with the inference parameters; the keys
            "algorithm_application" and "algorithm_version" are read.

    Returns:
        The display name: whatever follows the last "Transformer" in the
        application name, then ": ", then the capitalized algorithm version.
    """
    application_suffix = x["algorithm_application"].split("Transformer")[-1]
    version_label = x["algorithm_version"].capitalize()
    return f"{application_suffix}: {version_label}"
def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
    """
    Render a one-cell mols2grid HTML grid for a property prediction.

    Args:
        prediction: Predicted sequence; every ``<name>value`` tag found in it
            becomes a tooltip entry of the form ``Name = value``.
        target: Target molecule; the text after the last ``|`` is used as
            the sequence to draw.
        domain: Domain of the prediction ("Molecules" or "Proteins").

    Returns:
        HTML to display

    Raises:
        ValueError: If ``domain`` is neither "Molecules" nor "Proteins".
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    sequence = target.split("|")[-1]

    # Pick how the raw sequence is turned into something drawable.
    if domain == "Molecules":
        # `decoder` comes from terminator.selfies (presumably SELFIES -> SMILES).
        to_smiles = decoder
    else:

        def to_smiles(fasta):
            return Chem.MolToSmiles(Chem.MolFromFASTA(fasta))

    # Best effort: if the conversion fails, the unconverted sequence is kept.
    try:
        sequence = to_smiles(sequence)
    except Exception:
        logger.warning(f"Could not draw sequence {sequence}")

    grid_columns = {"SMILES": [sequence], "Name": ["Target"]}

    # Add properties
    for chunk in prediction.split("<")[1:]:
        fields = chunk.split(">")
        label = fields[0]
        grid_columns[label] = f"{label.capitalize()} = {fields[1]}"

    frame = pd.DataFrame(grid_columns)
    grid = mols2grid.display(
        frame,
        tooltip=list(grid_columns),
        height=900,
        n_cols=1,
        name="Results",
        size=(600, 700),
    )
    return grid.data
def draw_grid_generate(
    samples: List[Tuple[str, str]],
    domain: str,
    n_cols: int = 5,
    size: Tuple[int, int] = (140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules

    Args:
        samples: The generated samples (with properties). Each sample is
            indexed as ``(sequence, property_string)``; the property string
            is parsed as ``<name>value`` tags, e.g. ``<qed>0.5<esol>-1.2``.
        domain: Domain of the prediction (molecules or proteins)
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If ``domain`` is neither "Molecules" nor "Proteins", or
            if a property value cannot be parsed as a float.
        IndexError: If ``samples`` is empty.
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    if domain == "Proteins":
        # Convert each FASTA sequence on its own. A failed conversion falls
        # back to the raw sequence (the same best-effort policy as
        # draw_grid_predict), so `smis` is always defined and stays aligned
        # one-to-one with `samples`.
        smis = []
        for sample in samples:
            sequence = sample[0]
            try:
                smis.append(Chem.MolToSmiles(Chem.MolFromFASTA(sequence)))
            except Exception:
                logger.warning(f"Could not convert sequence {sequence}")
                smis.append(sequence)
    else:
        smis = [s[0] for s in samples]

    result = defaultdict(list)
    result.update({"SMILES": smis, "Name": [f"sample_{i}" for i in range(len(smis))]})

    # Create properties: the names are read from the first sample only, so
    # every sample is expected to carry the same set of property tags.
    properties = [s.split("<")[1] for s in samples[0][1].split(">")[:-1]]

    # Fill properties
    for sample in samples:
        for prop in properties:
            # Text after the last occurrence of the property name, minus the
            # closing ">", up to the next "<" (or the end of the string).
            value = float(sample[1].split(prop)[-1][1:].split("<")[0])
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data