# solarafold / app.py
# Commit e1dfecd by Jhsmit: "feat: always show tooltip data"
import json
from dataclasses import dataclass
from io import StringIO
from typing import Literal, Optional, TypedDict, cast
from zipfile import ZipFile
import polars as pl
import solara
from Bio.PDB import MMCIFParser, Structure
from ipymolstar import PDBeMolstar
from ipymolstar.widget import QueryParam
from polarify import polarify
from solara.components.file_drop import FileInfo
class ColorData(TypedDict):
    """Color payload passed to ``PDBeMolstar(color_data=...)``.

    The key names match the dicts this module actually builds
    (``NO_COLOR_DATA`` and the pLDDT color data in ``load_result``), which use
    the camelCase key ``nonSelectedColor``.
    """

    data: list[QueryParam]  # per-residue selection entries, each with a "color"
    # Previously declared as ``NonSelectedColor``, which never matched the
    # ``nonSelectedColor`` key used at runtime. Only ``None`` is assigned in
    # this module.
    nonSelectedColor: None
class TooltipData(TypedDict):
    """Tooltip payload handed to ``PDBeMolstar(tooltips=...)``."""

    # One entry per residue: a selection (chain id + residue number) and its
    # "tooltip" text.
    data: list[QueryParam]
class CustomData(TypedDict):
    """Inline structure payload handed to ``PDBeMolstar(custom_data=...)``."""

    data: str  # the mmCIF file contents as decoded text
    format: Literal["cif"]  # this app only ever ships mmCIF
    binary: Literal[False]  # ``data`` is text, not a binary CIF blob
@dataclass
class AlphaFoldData:
    """All data derived from one AlphaFold3 model of the uploaded result zip."""

    name: str  # label derived from the structure file name (used for downloads)
    structure: Structure  # structure parsed from the model's mmCIF text
    atom_data: pl.DataFrame  # per atom: name (residue name), resn, chain, plddt
    residue_data: pl.DataFrame  # per residue: mean pLDDT and confidence band
    custom_data: CustomData  # mmCIF payload for the Molstar viewer
    color_data: ColorData  # per-residue pLDDT coloring queries
    tooltip_data: TooltipData  # per-residue pLDDT tooltip queries

    def write_atoms(self):
        """Serialise ``atom_data`` with ``DataFrame.write_csv`` and return the CSV text."""
        atom_table = self.atom_data
        csv_text = atom_table.write_csv()
        return csv_text
# Confidence band label -> RGB color used for per-residue pLDDT coloring.
COLOR_LUT = {
    band: {"r": red, "g": green, "b": blue}
    for band, (red, green, blue) in (
        ("very-high", (16, 109, 255)),
        ("confident", (16, 207, 241)),
        ("low", (246, 237, 18)),
        ("very-low", (239, 130, 30)),
    )
}

# Color payload without per-residue entries; the viewer receives this in
# "chain" color mode.
NO_COLOR_DATA = dict(data=[], nonSelectedColor=None)

# Tooltip payload without entries (counterpart of NO_COLOR_DATA).
NO_TOOLTIP_DATA = dict(data=[])
# Index of the model inside the uploaded zip that ``load_result`` should read;
# the sidebar "Result index" select writes to it.
result_index = solara.reactive(0)

# Upload metadata from the sidebar FileDrop; ``None`` until a file is dropped.
file_info = solara.reactive(cast(Optional[FileInfo], None))

# Single mmCIF parser instance reused by every ``load_result`` call.
# NOTE(review): this object is shared module-wide; if loads can run
# concurrently, confirm MMCIFParser is safe to reuse that way.
PARSER = MMCIFParser()
def assign_confidence(x: pl.Expr) -> pl.Expr:
    """Map a pLDDT expression to its confidence-band label expression.

    Bands (first matching condition wins):
    ``x < 50`` -> "very-low", ``x < 70`` -> "low", ``x < 90`` -> "confident",
    otherwise -> "very-high".
    """
    return (
        pl.when(x < 50)
        .then(pl.lit("very-low"))
        .when(x < 70)
        .then(pl.lit("low"))
        .when(x < 90)
        .then(pl.lit("confident"))
        .otherwise(pl.lit("very-high"))
    )
def _read_result_files(file_obj, f_idx: int) -> tuple[str, str, dict]:
    """Read the ``f_idx``-th structure and its ``full_data`` JSON from a result zip.

    Structure (``.cif``) files and ``full_data`` files are paired by their
    position in sorted order.

    Returns:
        ``(structure_file, cif_str, full_data)``: the archive member name of
        the structure, its mmCIF text decoded as UTF-8, and the parsed JSON.

    Raises:
        IndexError: if the archive has fewer than ``f_idx + 1`` matching files.
    """
    with ZipFile(file_obj) as zf:
        files = zf.namelist()
        structure_file = sorted(f for f in files if f.endswith(".cif"))[f_idx]
        json_data_file = sorted(f for f in files if "full_data" in f)[f_idx]
        with zf.open(json_data_file) as json_f:
            full_data = json.load(json_f)
        cif_str = zf.read(structure_file).decode("utf-8")
    return structure_file, cif_str, full_data


def _build_residue_queries(
    residue_df: pl.DataFrame,
) -> tuple[list[QueryParam], list[QueryParam]]:
    """Build per-residue Molstar color and tooltip queries.

    ``residue_df`` must have the columns ``chain, resn, name, mean_plddt,
    confidence`` in that order, as produced by ``load_result``.
    """
    color_query: list[QueryParam] = []
    tooltip_query: list[QueryParam] = []
    for chain, resn, _name, mean_plddt, confidence in residue_df.iter_rows():
        color_query.append(
            {
                "struct_asym_id": chain,
                "residue_number": resn,
                "color": COLOR_LUT[confidence],
            }
        )
        tooltip_query.append(
            {
                "struct_asym_id": chain,
                "residue_number": resn,
                "tooltip": f"Confidence: {confidence}; plddt: {mean_plddt:.2f}",
            }
        )
    return color_query, tooltip_query


@solara.lab.task
def load_result() -> Optional[AlphaFoldData]:
    """Load model ``result_index`` from the zip held in ``file_info``.

    Parses the model's mmCIF, pairs each atom with its chain id and pLDDT from
    the matching ``full_data`` JSON, averages pLDDT per residue, and builds the
    Molstar custom-data, color, and tooltip payloads.
    """
    structure_file, cif_str, json_load = _read_result_files(
        file_info.value["file_obj"], result_index.value
    )
    # removesuffix, not rstrip: rstrip(".cif") strips any run of trailing
    # '.', 'c', 'i' or 'f' characters and mangles names such as "fold_ic.cif".
    # The same id is used for the structure and the download name.
    structure_id = structure_file.removesuffix(".cif")
    structure = PARSER.get_structure(structure_id, StringIO(cif_str))

    # Residue name/number per atom come from the parsed structure; chain ids
    # and pLDDT come from the JSON. Both are expected to list atoms in the same
    # order (a length mismatch makes the DataFrame constructor fail).
    atoms = list(structure.get_atoms())
    atoms_df = pl.DataFrame(
        {
            "name": pl.Series(
                [atom.get_parent().resname for atom in atoms],
                dtype=pl.Categorical,
            ),
            "resn": pl.Series([atom.get_parent().id[1] for atom in atoms]),
            "chain": pl.Series(json_load["atom_chain_ids"], dtype=pl.Categorical),
            "plddt": json_load["atom_plddts"],
        }
    )
    residue_df = (
        atoms_df.group_by(["chain", "resn", "name"])
        .agg(pl.col("plddt").mean().alias("mean_plddt"))
        .sort(["chain", "resn"])
        .with_columns(
            assign_confidence(pl.col("mean_plddt"))
            .alias("confidence")
            .cast(pl.Categorical)
        )
    )

    custom_data = {
        "data": cif_str,
        "format": "cif",
        "binary": False,
    }
    color_query, tooltip_query = _build_residue_queries(residue_df)

    return AlphaFoldData(
        name=structure_id,
        structure=structure,
        atom_data=atoms_df,
        residue_data=residue_df,
        custom_data=custom_data,
        color_data={"data": color_query, "nonSelectedColor": None},
        tooltip_data={"data": tooltip_query},
    )
@solara.component
def Page():
    """Main page: upload and result controls in the sidebar, Molstar viewer in the body."""
    color_mode = solara.use_reactive("chain")
    spin = solara.use_reactive(True)
    dark_effective = solara.lab.use_dark_effective()

    def on_color_mode(value: str):
        color_mode.set(value)

    def set_result_index(value: int):
        # Picking another model immediately re-runs the loader for that index.
        result_index.set(value)
        load_result()

    solara.Title("Solarafold result viewer")
    with solara.AppBar():
        solara.lab.ThemeToggle()

    with solara.Sidebar():
        solara.FileDrop(label="Upload zip file", on_file=file_info.set, lazy=True)
        solara.Button(
            label="Load result",
            on_click=load_result,
            block=True,
            disabled=file_info.value is None,
        )
        if not load_result.not_called:
            # Selectors are locked only while a load runs, so after a failed
            # load the user can still choose a different model.
            disabled = load_result.pending
            # Downloads read ``load_result.value``; only do that once the
            # latest run finished successfully (not while pending or after an
            # error).
            no_result = not load_result.finished
            solara.Select(
                label="Result index",
                value=result_index.value,
                on_value=set_result_index,
                values=list(range(5)),
                disabled=disabled,
            )
            solara.Select(
                label="Color mode",
                values=["chain", "plddt"],
                value=color_mode.value,
                on_value=on_color_mode,
                disabled=disabled,
            )
            solara.Checkbox(label="Spin", value=spin)

            def write_atoms():
                return load_result.value.write_atoms()

            solara.FileDownload(
                write_atoms,
                filename="NA" if no_result else f"{load_result.value.name}_atoms.csv",
                children=[
                    solara.Button(
                        "Download atom plddt",
                        block=True,
                        disabled=no_result,
                    )
                ],
            )
            solara.Div(style={"height": "20px"})

            def write_residues():
                return load_result.value.residue_data.write_csv()

            solara.FileDownload(
                write_residues,
                filename="NA" if no_result else f"{load_result.value.name}_residues.csv",
                children=[
                    solara.Button(
                        "Download residue plddt",
                        block=True,
                        disabled=no_result,
                    )
                ],
            )

    if load_result.not_called:
        solara.HTML(
            tag="p",
            unsafe_innerHTML='Drag and drop an alphafold3 result .zip file to get started. You can download an example file <a href="https://www.gstatic.com/alphafoldserver/examplefold_pdb_8aw3/examplefold_pdb_8aw3.zip">here</a>.',
        )
    elif load_result.pending:
        solara.ProgressLinear(load_result.pending)
    elif load_result.error:
        # Previously a failed load left the page blank with no feedback.
        solara.Error(f"Failed to load result: {load_result.exception}")
    elif load_result.finished:
        fold_data: AlphaFoldData = load_result.value
        color_data = (
            NO_COLOR_DATA if color_mode.value == "chain" else fold_data.color_data
        )
        theme = "dark" if dark_effective else "light"
        with solara.Card():
            PDBeMolstar.element(
                height="calc(100vh - 150px)",
                custom_data=fold_data.custom_data,
                color_data=color_data,
                tooltips=fold_data.tooltip_data,
                show_water=False,
                spin=spin.value,
                theme=theme,
            ).key(f"pdbemolstar-{dark_effective}")  # remount the viewer on theme change
@solara.component
def Layout(children):
    """Wrap every page in an AppLayout whose toolbar follows the effective theme."""
    is_dark = solara.lab.use_dark_effective()
    layout = solara.AppLayout(
        children=children,
        toolbar_dark=is_dark,
        # No app-bar color; a commented-out alternative in the original used
        # `None if dark_effective else "primary"`.
        color=None,
    )
    return layout