Spaces:
Running
Running
File size: 2,039 Bytes
5984d9a e83e5dc 4b5d582 5984d9a e83e5dc 4b5d582 e83e5dc 4b5d582 e83e5dc 5984d9a e83e5dc 5984d9a e83e5dc 5984d9a e83e5dc 5984d9a e83e5dc 5984d9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import logging
from collections import defaultdict
from typing import List, Callable
from gt4sd.properties import PropertyPredictorRegistry
from gt4sd.algorithms.prediction.paccmann.core import PaccMann, AffinityPredictor
import torch
import mols2grid
import pandas as pd
# Module-level logger. Attaching a NullHandler means records from this module
# are dropped, instead of going to logging's last-resort stderr handler, unless
# the application configures handlers itself.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def get_affinity_function(target: str) -> Callable:
    """
    Build a batch affinity predictor bound to a single protein target.

    Args:
        target: Protein target given to gt4sd's ``AffinityPredictor`` for
            every ligand in a batch.

    Returns:
        A callable that takes a list of ligands (the generated SMILES, as
        passed by ``draw_grid_generate``) and returns
        ``torch.stack(...).tolist()`` of the PaccMann samples. It presumably
        yields one value per ligand in input order; confirm against gt4sd.
        An empty list of ligands yields ``[]``.
    """

    def _predict(mols: List[str]) -> list:
        # torch.stack raises on an empty sequence, so handle the no-ligand
        # case up front instead of building a predictor for zero samples.
        if not mols:
            return []
        predictor = PaccMann(
            AffinityPredictor(protein_targets=[target] * len(mols), ligands=mols)
        )
        return torch.stack(list(predictor.sample(len(mols)))).tolist()

    return _predict
# Per-molecule property predictors keyed by the property names accepted in
# draw_grid_generate's ``properties`` argument. Each name maps to a gt4sd
# registry key; note that "sa" is served by the registry's "sas" predictor.
EVAL_DICT = {
    name: PropertyPredictorRegistry.get_property_predictor(registry_key)
    for name, registry_key in (("qed", "qed"), ("sa", "sas"))
}
def draw_grid_generate(
    samples: List[str],
    properties: List[str],
    protein_target: str,
    n_cols: int = 3,
    size=(140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules.

    Args:
        samples: The generated samples (SMILES strings).
        properties: Names of the properties to show in each molecule's
            tooltip. "affinity" is predicted against ``protein_target``.
            Every other name must be a key of ``EVAL_DICT``. Duplicates are
            ignored, and the caller's list is not modified.
        protein_target: Protein target for the "affinity" property. It may be
            empty when "affinity" is not requested.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display.

    Raises:
        ValueError: If "affinity" is requested without a protein target.
        KeyError: If a requested property is not a key of ``EVAL_DICT``.
    """
    # Work on an order-preserving, de-duplicated copy. This leaves the
    # caller's list untouched. A repeated name would otherwise append a
    # second batch to the same column and break the DataFrame construction.
    requested = list(dict.fromkeys(properties))

    result = {
        "SMILES": samples,
        "Name": [f"Generated_{i}" for i in range(len(samples))],
    }

    if "affinity" in requested:
        requested.remove("affinity")
        if not protein_target:
            raise ValueError(
                "The 'affinity' property requires a non-empty protein_target."
            )
        # Build the predictor for this call only. Storing it in the shared
        # EVAL_DICT would let another call's target leak into this result.
        result["affinity"] = get_affinity_function(protein_target)(samples)

    # One tooltip column per remaining property, one entry per sample.
    for prop in requested:
        predictor = EVAL_DICT[prop]
        result[prop] = [f"{prop} = {predictor(sample)}" for sample in samples]

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data
|