# Page residue captured with this file (not part of the module):
# jannisborn - "feat: Initial RT app" - 8b150bd
# raw / history / blame - No virus - 4.73 kB
import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple
import mols2grid
import pandas as pd
from gt4sd.algorithms import (
RegressionTransformerMolecules,
RegressionTransformerProteins,
)
from gt4sd.algorithms.core import AlgorithmConfiguration
from rdkit import Chem
from terminator.selfies import decoder
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Attach a no-op handler; output is produced only if other handlers are
# configured for this logger or its ancestors.
logger.addHandler(logging.NullHandler())
def get_application(application: str) -> AlgorithmConfiguration:
    """
    Map an application name to its Regression Transformer configuration class.

    Note that the class itself is returned, not an instance; callers
    instantiate it (e.g. with an ``algorithm_version``).

    Args:
        application: Either "Molecules" or "Proteins" (case-sensitive).

    Returns:
        RegressionTransformerMolecules or RegressionTransformerProteins.

    Raises:
        ValueError: If ``application`` is any other value.
    """
    if application == "Molecules":
        return RegressionTransformerMolecules
    if application == "Proteins":
        return RegressionTransformerProteins
    raise ValueError(
        "Currently only models for molecules and proteins are supported"
    )
def get_inference_dict(
    application: AlgorithmConfiguration, algorithm_version: str
) -> Dict:
    """
    Get inference dictionary for a given application and algorithm version.

    The configuration is instantiated with ``algorithm_version``; the value
    returned by its ``ensure_artifacts()`` is used as the directory that
    contains ``inference.json``.

    Args:
        application: Callable (configuration class) accepting an
            ``algorithm_version`` keyword, e.g. the Molecules or Proteins
            configuration.
        algorithm_version: algorithm version (e.g. qed)

    Returns:
        The parsed content of ``inference.json``.

    Raises:
        OSError: If ``inference.json`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    config = application(algorithm_version=algorithm_version)
    inference_path = os.path.join(config.ensure_artifacts(), "inference.json")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(inference_path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_rt_name(x: Dict) -> str:
    """
    Build the UI display name of a regression transformer.

    Args:
        x: inference parameters; must contain the string entries
            "algorithm_application" and "algorithm_version".

    Returns:
        "<application suffix>: <Version>", where the suffix is whatever
        follows the last "Transformer" in the application name (the whole
        name if "Transformer" does not occur) and the version is capitalized.
    """
    # rpartition()[2] is the text after the last separator, or the full
    # string when the separator is absent.
    domain_label = x["algorithm_application"].rpartition("Transformer")[2]
    version_label = x["algorithm_version"].capitalize()
    return f"{domain_label}: {version_label}"
def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
    """
    Render a single-entry mols2grid HTML grid for a property prediction.

    Args:
        prediction: Predicted sequence; parsed as "<name>value" chunks,
            each of which becomes a "Name = value" column/tooltip entry.
        target: Target molecule; only the part after the last "|" is drawn.
        domain: Domain of the prediction ("Molecules" or "Proteins").

    Returns:
        HTML to display

    Raises:
        ValueError: If ``domain`` is neither "Molecules" nor "Proteins".
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    seq = target.split("|")[-1]

    # Pick the sequence -> SMILES conversion for the domain.
    if domain == "Molecules":
        to_smiles = decoder
    else:

        def to_smiles(fasta):
            return Chem.MolToSmiles(Chem.MolFromFASTA(fasta))

    # Best effort: on failure the unconverted sequence is passed on as-is.
    try:
        seq = to_smiles(seq)
    except Exception:
        logger.warning(f"Could not draw sequence {seq}")

    result = {"SMILES": [seq], "Name": ["Target"]}

    # Add properties: every chunk after a "<" is split into name and value
    # at ">".
    for chunk in prediction.split("<")[1:]:
        pieces = chunk.split(">")
        prop_name, prop_value = pieces[0], pieces[1]
        result[prop_name] = f"{prop_name.capitalize()} = {prop_value}"

    grid = mols2grid.display(
        pd.DataFrame(result),
        tooltip=list(result),
        height=900,
        n_cols=1,
        name="Results",
        size=(600, 700),
    )
    return grid.data
def draw_grid_generate(
    samples: List[Tuple[str]], domain: str, n_cols: int = 5, size=(140, 200)
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules.

    Args:
        samples: The generated samples (with properties). Each sample is
            indexed as ``sample[0]`` (the sequence) and ``sample[1]`` (a
            property string parsed as "<name>value" chunks).
        domain: Domain of the prediction ("Molecules" or "Proteins").
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If ``domain`` is unsupported, or a property value
            cannot be parsed as a float.
        IndexError: If ``samples`` is empty (the property names are read
            from the first sample).
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    if domain == "Proteins":
        # Convert sample by sample. Previously one failing conversion left
        # `smis` unbound and the function crashed with UnboundLocalError.
        # Falling back to the raw sequence keeps every sample aligned with
        # its properties and mirrors the fallback in draw_grid_predict.
        smis = []
        for sample in samples:
            try:
                smis.append(Chem.MolToSmiles(Chem.MolFromFASTA(sample[0])))
            except Exception:
                logger.warning("Could not convert sequence %s", sample[0])
                smis.append(sample[0])
    else:
        smis = [s[0] for s in samples]

    result = defaultdict(list)
    result.update({"SMILES": smis, "Name": [f"sample_{i}" for i in range(len(smis))]})

    # Create properties: names are taken from the first sample's property
    # string, i.e. the text between each "<" and the following ">".
    properties = [s.split("<")[1] for s in samples[0][1].split(">")[:-1]]

    # Fill properties: the value is the text after the last occurrence of
    # the property name, minus the ">" that follows it, up to the next "<".
    for sample in samples:
        for prop in properties:
            value = float(sample[1].split(prop)[-1][1:].split("<")[0])
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data